Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import nanoid
import websockets
from opentelemetry import propagate
from pydantic import ValidationError
from websockets import (
WebSocketCommonProtocol,
Expand Down Expand Up @@ -170,7 +171,13 @@ async def _establish_new_connection(

try:
uri_and_metadata = await self._uri_and_metadata_factory()
ws = await websockets.connect(uri_and_metadata["uri"], max_size=None)
otel_headers: dict[str, str] = {}
propagate.inject(otel_headers)
ws = await websockets.connect(
uri_and_metadata["uri"],
max_size=None,
extra_headers=otel_headers,
)
session_id = (
self.generate_nanoid()
if not old_session
Expand Down
67 changes: 39 additions & 28 deletions src/replit_river/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Mapping

import websockets
from opentelemetry import context, propagate
from websockets.exceptions import ConnectionClosed
from websockets.server import WebSocketServerProtocol

Expand Down Expand Up @@ -68,34 +69,44 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None:
logger.debug(
"River server started establishing session with ws: %s", websocket.id
)
grace_ms = self._transport_options.handshake_timeout_ms

# Extract OTel context (traceparent, tracestate, baggage) from the
# WebSocket HTTP upgrade request headers and make it the ambient
# context for the lifetime of this connection.
otel_context = propagate.extract(websocket.request_headers)
token = context.attach(otel_context)

try:
session = await asyncio.wait_for(
self._handshake_to_get_session(websocket),
grace_ms / 1000, # wait_for unit is seconds
)
if not session:
grace_ms = self._transport_options.handshake_timeout_ms
try:
session = await asyncio.wait_for(
self._handshake_to_get_session(websocket),
grace_ms / 1000, # wait_for unit is seconds
)
if not session:
return
except asyncio.TimeoutError:
logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket")
await websocket.close()
return
except asyncio.TimeoutError:
logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket")
await websocket.close()
return
except asyncio.CancelledError:
logger.error("Handshake cancelled, closing websocket")
await websocket.close()
return
logger.debug("River server session established, start serving messages")
except asyncio.CancelledError:
logger.error("Handshake cancelled, closing websocket")
await websocket.close()
return
logger.debug("River server session established, start serving messages")

try:
# Session serve will be closed in two cases
# 1. websocket is closed
# 2. exception thrown
# session should be kept in order to be reused by the reconnect within the
# grace period.
await session.serve()
except ConnectionClosed:
logger.debug("ConnectionClosed while serving", exc_info=True)
# We don't have to close the websocket here, it is already closed.
except Exception:
logger.exception("River transport error in server %s", self._server_id)
await websocket.close()
try:
# Session serve will be closed in two cases
# 1. websocket is closed
# 2. exception thrown
# session should be kept in order to be reused by the reconnect within
# the grace period.
await session.serve()
except ConnectionClosed:
logger.debug("ConnectionClosed while serving", exc_info=True)
# We don't have to close the websocket here, it is already closed.
except Exception:
logger.exception("River transport error in server %s", self._server_id)
await websocket.close()
finally:
context.detach(token)
4 changes: 4 additions & 0 deletions src/replit_river/v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import websockets.asyncio.client
from aiochannel import Channel, ChannelEmpty, ChannelFull
from aiochannel.errors import ChannelClosed
from opentelemetry import propagate
from opentelemetry.trace import Span, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import ValidationError
Expand Down Expand Up @@ -1133,9 +1134,12 @@ async def _do_ensure_connected[HandshakeMetadata](
ws: ClientConnection | None = None
try:
uri_and_metadata = await uri_and_metadata_factory()
otel_headers: dict[str, str] = {}
propagate.inject(otel_headers)
ws = await websockets.asyncio.client.connect(
uri_and_metadata["uri"],
max_size=None,
additional_headers=otel_headers,
)
transition_connecting(ws)

Expand Down
229 changes: 227 additions & 2 deletions tests/v1/test_opentelemetry.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import contextlib
import logging
from datetime import timedelta
from typing import AsyncGenerator, AsyncIterator, Iterator
from typing import AsyncGenerator, AsyncIterator, Iterator, Literal

import grpc
import grpc.aio
import pytest
from opentelemetry import baggage, context, propagate, trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.propagators.composite import CompositePropagator
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from websockets.server import serve

from replit_river.client import Client
from replit_river.client_transport import UriAndMetadata
from replit_river.error_schema import RiverError, RiverException
from replit_river.rpc import stream_method_handler
from replit_river.rpc import rpc_method_handler, stream_method_handler
from replit_river.server import Server
from replit_river.transport_options import TransportOptions
from tests.conftest import (
HandlerMapping,
deserialize_error,
Expand Down Expand Up @@ -219,3 +228,219 @@ async def stream_data() -> AsyncGenerator[str, None]:
assert len(spans) == 1
assert spans[0].name == "river.client.stream.test_service.stream_method"
assert spans[0].status.status_code == StatusCode.OK


# ===== OTel context propagation via WebSocket HTTP upgrade headers =====


# A handler that reads OTel baggage from the ambient context and returns it.
async def baggage_echo_handler(request: str, ctx: grpc.aio.ServicerContext) -> str:
all_baggage = baggage.get_all()
# Return baggage as a comma-separated "key=value" string
return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items()))


baggage_echo_handlers: HandlerMapping = {
("test_service", "baggage_echo"): (
"rpc",
rpc_method_handler(
baggage_echo_handler, deserialize_request, serialize_response
),
)
}


@pytest.fixture
def _enable_baggage_propagator() -> Iterator[None]:
"""Temporarily install a composite propagator that includes both
W3C TraceContext and W3C Baggage propagation so that
``propagate.inject()`` / ``propagate.extract()`` handle the
``baggage`` HTTP header."""
previous = propagate.get_global_textmap()
propagate.set_global_textmap(
CompositePropagator(
[
TraceContextTextMapPropagator(),
W3CBaggagePropagator(),
]
)
)
yield
propagate.set_global_textmap(previous)


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
@pytest.mark.usefixtures("_enable_baggage_propagator")
async def test_baggage_propagated_via_ws_headers(
no_logging_error: NoErrors,
server: Server,
transport_options: TransportOptions,
) -> None:
"""Verify that OTel baggage set on the client side is propagated to the
server via the WebSocket HTTP upgrade request headers."""

# Set baggage in the ambient OTel context *before* the client connects,
# so that ``propagate.inject()`` (called inside ``websockets.connect()``)
# includes the ``baggage`` header.
ctx = baggage.set_baggage("test-key", "test-value")
ctx = baggage.set_baggage("another-key", "another-value", context=ctx)
token = context.attach(ctx)

binding = None
try:
binding = await serve(server.serve, "127.0.0.1")
sockets = list(binding.sockets)
assert len(sockets) == 1
socket = sockets[0]

async def websocket_uri_factory() -> UriAndMetadata[None]:
return {
"uri": "ws://%s:%d" % socket.getsockname(),
"metadata": None,
}

client: Client[Literal[None]] = Client[None](
uri_and_metadata_factory=websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
)
try:
response = await client.send_rpc(
"test_service",
"baggage_echo",
"ignored",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
# The handler returns sorted "key=value" pairs
assert response == "another-key=another-value,test-key=test-value"
finally:
logging.debug("Start closing test client")
await client.close()
finally:
context.detach(token)
logging.debug("Start closing test server")
if binding:
binding.close()
await server.close()
if binding:
await binding.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
@pytest.mark.usefixtures("_enable_baggage_propagator")
async def test_no_baggage_when_none_set(
no_logging_error: NoErrors,
server: Server,
transport_options: TransportOptions,
) -> None:
"""Verify that when no baggage is set, the server sees empty baggage."""

binding = None
try:
binding = await serve(server.serve, "127.0.0.1")
sockets = list(binding.sockets)
assert len(sockets) == 1
socket = sockets[0]

async def websocket_uri_factory() -> UriAndMetadata[None]:
return {
"uri": "ws://%s:%d" % socket.getsockname(),
"metadata": None,
}

client: Client[Literal[None]] = Client[None](
uri_and_metadata_factory=websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
)
try:
response = await client.send_rpc(
"test_service",
"baggage_echo",
"ignored",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == ""
finally:
logging.debug("Start closing test client")
await client.close()
finally:
logging.debug("Start closing test server")
if binding:
binding.close()
await server.close()
if binding:
await binding.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
@pytest.mark.usefixtures("_enable_baggage_propagator")
async def test_traceparent_propagated_via_ws_headers(
no_logging_error: NoErrors,
server: Server,
transport_options: TransportOptions,
span_exporter: InMemorySpanExporter,
) -> None:
"""Verify that when a span is active on the client, the traceparent
header is sent on the WS upgrade and the server-side context inherits
the trace."""
tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("client-operation"):
# Also set some baggage
ctx = baggage.set_baggage("trace-test", "yes")
token = context.attach(ctx)

binding = None
try:
binding = await serve(server.serve, "127.0.0.1")
sockets = list(binding.sockets)
assert len(sockets) == 1
socket = sockets[0]

async def websocket_uri_factory() -> UriAndMetadata[None]:
return {
"uri": "ws://%s:%d" % socket.getsockname(),
"metadata": None,
}

client: Client[Literal[None]] = Client[None](
uri_and_metadata_factory=websocket_uri_factory,
client_id="test_client",
server_id="test_server",
transport_options=transport_options,
)
try:
response = await client.send_rpc(
"test_service",
"baggage_echo",
"ignored",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
# Verify baggage was propagated
assert response == "trace-test=yes"
finally:
logging.debug("Start closing test client")
await client.close()
finally:
context.detach(token)
logging.debug("Start closing test server")
if binding:
binding.close()
await server.close()
if binding:
await binding.wait_closed()
Loading