From 634f270812c739236e6febefc63f900eba72910b Mon Sep 17 00:00:00 2001 From: lhchavez Date: Sat, 28 Feb 2026 15:48:09 -0800 Subject: [PATCH 1/3] feat: propagate OTel context via WebSocket HTTP upgrade headers Propagate traceparent, tracestate, and baggage through the WebSocket connection using standard W3C HTTP headers on the upgrade request, matching how any HTTP-based service would propagate OTel context. Client side (v1 + v2): - Use propagate.inject() to capture the current OTel context into a headers dict, then pass it as extra_headers/additional_headers to websockets.connect(). Server side: - In Server.serve(), use propagate.extract() on websocket.request_headers to restore the OTel context, then attach it as the ambient context for the lifetime of the connection. --- src/replit_river/client_transport.py | 9 +- src/replit_river/server.py | 71 +++++---- src/replit_river/v2/session.py | 4 + tests/v1/test_opentelemetry.py | 227 ++++++++++++++++++++++++++- 4 files changed, 280 insertions(+), 31 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index cd3be09d..551a870a 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -5,6 +5,7 @@ import nanoid import websockets +from opentelemetry import propagate from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, @@ -170,7 +171,13 @@ async def _establish_new_connection( try: uri_and_metadata = await self._uri_and_metadata_factory() - ws = await websockets.connect(uri_and_metadata["uri"], max_size=None) + otel_headers: dict[str, str] = {} + propagate.inject(otel_headers) + ws = await websockets.connect( + uri_and_metadata["uri"], + max_size=None, + extra_headers=otel_headers, + ) session_id = ( self.generate_nanoid() if not old_session diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 64974fc3..198b8fe3 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py 
@@ -3,6 +3,7 @@ from typing import Mapping import websockets +from opentelemetry import context, propagate from websockets.exceptions import ConnectionClosed from websockets.server import WebSocketServerProtocol @@ -68,34 +69,48 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None: logger.debug( "River server started establishing session with ws: %s", websocket.id ) - grace_ms = self._transport_options.handshake_timeout_ms + + # Extract OTel context (traceparent, tracestate, baggage) from the + # WebSocket HTTP upgrade request headers and make it the ambient + # context for the lifetime of this connection. + otel_context = propagate.extract(websocket.request_headers) + token = context.attach(otel_context) + try: - session = await asyncio.wait_for( - self._handshake_to_get_session(websocket), - grace_ms / 1000, # wait_for unit is seconds - ) - if not session: + grace_ms = self._transport_options.handshake_timeout_ms + try: + session = await asyncio.wait_for( + self._handshake_to_get_session(websocket), + grace_ms / 1000, # wait_for unit is seconds + ) + if not session: + return + except asyncio.TimeoutError: + logger.error( + f"Handshake timeout after {grace_ms}ms, closing websocket" + ) + await websocket.close() return - except asyncio.TimeoutError: - logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket") - await websocket.close() - return - except asyncio.CancelledError: - logger.error("Handshake cancelled, closing websocket") - await websocket.close() - return - logger.debug("River server session established, start serving messages") + except asyncio.CancelledError: + logger.error("Handshake cancelled, closing websocket") + await websocket.close() + return + logger.debug("River server session established, start serving messages") - try: - # Session serve will be closed in two cases - # 1. websocket is closed - # 2. exception thrown - # session should be kept in order to be reused by the reconnect within the - # grace period. 
- await session.serve() - except ConnectionClosed: - logger.debug("ConnectionClosed while serving", exc_info=True) - # We don't have to close the websocket here, it is already closed. - except Exception: - logger.exception("River transport error in server %s", self._server_id) - await websocket.close() + try: + # Session serve will be closed in two cases + # 1. websocket is closed + # 2. exception thrown + # session should be kept in order to be reused by the reconnect within + # the grace period. + await session.serve() + except ConnectionClosed: + logger.debug("ConnectionClosed while serving", exc_info=True) + # We don't have to close the websocket here, it is already closed. + except Exception: + logger.exception( + "River transport error in server %s", self._server_id + ) + await websocket.close() + finally: + context.detach(token) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index f7bfe4ed..879e9ee5 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -23,6 +23,7 @@ import websockets.asyncio.client from aiochannel import Channel, ChannelEmpty, ChannelFull from aiochannel.errors import ChannelClosed +from opentelemetry import propagate from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import ValidationError @@ -1133,9 +1134,12 @@ async def _do_ensure_connected[HandshakeMetadata]( ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() + otel_headers: dict[str, str] = {} + propagate.inject(otel_headers) ws = await websockets.asyncio.client.connect( uri_and_metadata["uri"], max_size=None, + additional_headers=otel_headers, ) transition_connecting(ws) diff --git a/tests/v1/test_opentelemetry.py b/tests/v1/test_opentelemetry.py index c47b5418..efd317e9 100644 --- a/tests/v1/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -1,16 +1,25 @@ import contextlib 
+import logging from datetime import timedelta -from typing import AsyncGenerator, AsyncIterator, Iterator +from typing import AsyncGenerator, AsyncIterator, Iterator, Literal import grpc import grpc.aio import pytest +from opentelemetry import baggage, context, propagate, trace +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.server import serve from replit_river.client import Client +from replit_river.client_transport import UriAndMetadata from replit_river.error_schema import RiverError, RiverException -from replit_river.rpc import stream_method_handler +from replit_river.rpc import rpc_method_handler, stream_method_handler +from replit_river.server import Server +from replit_river.transport_options import TransportOptions from tests.conftest import ( HandlerMapping, deserialize_error, @@ -219,3 +228,217 @@ async def stream_data() -> AsyncGenerator[str, None]: assert len(spans) == 1 assert spans[0].name == "river.client.stream.test_service.stream_method" assert spans[0].status.status_code == StatusCode.OK + + +# ===== OTel context propagation via WebSocket HTTP upgrade headers ===== + + +# A handler that reads OTel baggage from the ambient context and returns it. 
+async def baggage_echo_handler( + request: str, ctx: grpc.aio.ServicerContext +) -> str: + all_baggage = baggage.get_all() + # Return baggage as a comma-separated "key=value" string + return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items())) + + +baggage_echo_handlers: HandlerMapping = { + ("test_service", "baggage_echo"): ( + "rpc", + rpc_method_handler(baggage_echo_handler, deserialize_request, serialize_response), + ) +} + + +@pytest.fixture +def _enable_baggage_propagator(): + """Temporarily install a composite propagator that includes both + W3C TraceContext and W3C Baggage propagation so that + ``propagate.inject()`` / ``propagate.extract()`` handle the + ``baggage`` HTTP header.""" + previous = propagate.get_global_textmap() + propagate.set_global_textmap( + CompositePropagator([ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ]) + ) + yield + propagate.set_global_textmap(previous) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_baggage_propagated_via_ws_headers( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, +) -> None: + """Verify that OTel baggage set on the client side is propagated to the + server via the WebSocket HTTP upgrade request headers.""" + + # Set baggage in the ambient OTel context *before* the client connects, + # so that ``propagate.inject()`` (called just before ``websockets.connect()``) + # includes the ``baggage`` header. 
+ ctx = baggage.set_baggage("test-key", "test-value") + ctx = baggage.set_baggage("another-key", "another-value", context=ctx) + token = context.attach(ctx) + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + # The handler returns sorted "key=value" pairs + assert response == "another-key=another-value,test-key=test-value" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + context.detach(token) + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_no_baggage_when_none_set( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, +) -> None: + """Verify that when no baggage is set, the server sees empty baggage.""" + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + 
client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_traceparent_propagated_via_ws_headers( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, + span_exporter: InMemorySpanExporter, +) -> None: + """Verify that when a span is active on the client, the traceparent + header is sent on the WS upgrade and the server-side context inherits + the trace.""" + tracer = trace.get_tracer(__name__) + + with tracer.start_as_current_span("client-operation") as client_span: + # Also set some baggage + ctx = baggage.set_baggage("trace-test", "yes") + token = context.attach(ctx) + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + # Verify baggage was propagated + assert response == 
"trace-test=yes" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + context.detach(token) + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed() From 7d250ab4296df8f9f5c44b3424de1c1aa827b33c Mon Sep 17 00:00:00 2001 From: lhchavez Date: Sat, 28 Feb 2026 16:18:13 -0800 Subject: [PATCH 2/3] fix: address formatting issues (ruff format + unused variable) --- src/replit_river/server.py | 8 ++------ tests/v1/test_opentelemetry.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 198b8fe3..67739d3e 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -86,9 +86,7 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None: if not session: return except asyncio.TimeoutError: - logger.error( - f"Handshake timeout after {grace_ms}ms, closing websocket" - ) + logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket") await websocket.close() return except asyncio.CancelledError: @@ -108,9 +106,7 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None: logger.debug("ConnectionClosed while serving", exc_info=True) # We don't have to close the websocket here, it is already closed. except Exception: - logger.exception( - "River transport error in server %s", self._server_id - ) + logger.exception("River transport error in server %s", self._server_id) await websocket.close() finally: context.detach(token) diff --git a/tests/v1/test_opentelemetry.py b/tests/v1/test_opentelemetry.py index efd317e9..a9316fa5 100644 --- a/tests/v1/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -234,9 +234,7 @@ async def stream_data() -> AsyncGenerator[str, None]: # A handler that reads OTel baggage from the ambient context and returns it. 
-async def baggage_echo_handler( - request: str, ctx: grpc.aio.ServicerContext -) -> str: +async def baggage_echo_handler(request: str, ctx: grpc.aio.ServicerContext) -> str: all_baggage = baggage.get_all() # Return baggage as a comma-separated "key=value" string return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items())) @@ -245,7 +243,9 @@ async def baggage_echo_handler( baggage_echo_handlers: HandlerMapping = { ("test_service", "baggage_echo"): ( "rpc", - rpc_method_handler(baggage_echo_handler, deserialize_request, serialize_response), + rpc_method_handler( + baggage_echo_handler, deserialize_request, serialize_response + ), ) } @@ -258,10 +258,12 @@ def _enable_baggage_propagator(): ``baggage`` HTTP header.""" previous = propagate.get_global_textmap() propagate.set_global_textmap( - CompositePropagator([ - TraceContextTextMapPropagator(), - W3CBaggagePropagator(), - ]) + CompositePropagator( + [ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ] + ) ) yield propagate.set_global_textmap(previous) @@ -395,7 +397,7 @@ async def test_traceparent_propagated_via_ws_headers( the trace.""" tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("client-operation") as client_span: + with tracer.start_as_current_span("client-operation"): # Also set some baggage ctx = baggage.set_baggage("trace-test", "yes") token = context.attach(ctx) From f27bd0b99278094a4a44791d55bbbfadf410a3a4 Mon Sep 17 00:00:00 2001 From: lhchavez Date: Sat, 28 Feb 2026 16:23:22 -0800 Subject: [PATCH 3/3] fix: add return type annotation to _enable_baggage_propagator fixture --- tests/v1/test_opentelemetry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/test_opentelemetry.py b/tests/v1/test_opentelemetry.py index a9316fa5..219d332e 100644 --- a/tests/v1/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -251,7 +251,7 @@ async def baggage_echo_handler(request: str, ctx: grpc.aio.ServicerContext) -> s @pytest.fixture -def 
_enable_baggage_propagator(): +def _enable_baggage_propagator() -> Iterator[None]: """Temporarily install a composite propagator that includes both W3C TraceContext and W3C Baggage propagation so that ``propagate.inject()`` / ``propagate.extract()`` handle the