diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index cd3be09d..551a870a 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -5,6 +5,7 @@ import nanoid import websockets +from opentelemetry import propagate from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, @@ -170,7 +171,13 @@ async def _establish_new_connection( try: uri_and_metadata = await self._uri_and_metadata_factory() - ws = await websockets.connect(uri_and_metadata["uri"], max_size=None) + otel_headers: dict[str, str] = {} + propagate.inject(otel_headers) + ws = await websockets.connect( + uri_and_metadata["uri"], + max_size=None, + extra_headers=otel_headers, + ) session_id = ( self.generate_nanoid() if not old_session diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 64974fc3..67739d3e 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -3,6 +3,7 @@ from typing import Mapping import websockets +from opentelemetry import context, propagate from websockets.exceptions import ConnectionClosed from websockets.server import WebSocketServerProtocol @@ -68,34 +69,44 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None: logger.debug( "River server started establishing session with ws: %s", websocket.id ) - grace_ms = self._transport_options.handshake_timeout_ms + + # Extract OTel context (traceparent, tracestate, baggage) from the + # WebSocket HTTP upgrade request headers and make it the ambient + # context for the lifetime of this connection. + otel_context = propagate.extract(websocket.request_headers) + token = context.attach(otel_context) + try: - session = await asyncio.wait_for( - self._handshake_to_get_session(websocket), - grace_ms / 1000, # wait_for unit is seconds - ) - if not session: + grace_ms = self._transport_options.handshake_timeout_ms + try: + session = await asyncio.wait_for( + self._handshake_to_get_session(websocket), + grace_ms / 1000, # wait_for unit is seconds + ) + if not session: + return + except asyncio.TimeoutError: + logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket") + await websocket.close() return - except asyncio.TimeoutError: - logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket") - await websocket.close() - return - except asyncio.CancelledError: - logger.error("Handshake cancelled, closing websocket") - await websocket.close() - return - logger.debug("River server session established, start serving messages") + except asyncio.CancelledError: + logger.error("Handshake cancelled, closing websocket") + await websocket.close() + return + logger.debug("River server session established, start serving messages") - try: - # Session serve will be closed in two cases - # 1. websocket is closed - # 2. exception thrown - # session should be kept in order to be reused by the reconnect within the - # grace period. - await session.serve() - except ConnectionClosed: - logger.debug("ConnectionClosed while serving", exc_info=True) - # We don't have to close the websocket here, it is already closed. - except Exception: - logger.exception("River transport error in server %s", self._server_id) - await websocket.close() + try: + # Session serve will be closed in two cases + # 1. websocket is closed + # 2. exception thrown + # session should be kept in order to be reused by the reconnect within + # the grace period. + await session.serve() + except ConnectionClosed: + logger.debug("ConnectionClosed while serving", exc_info=True) + # We don't have to close the websocket here, it is already closed. + except Exception: + logger.exception("River transport error in server %s", self._server_id) + await websocket.close() + finally: + context.detach(token) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index f7bfe4ed..879e9ee5 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -23,6 +23,7 @@ import websockets.asyncio.client from aiochannel import Channel, ChannelEmpty, ChannelFull from aiochannel.errors import ChannelClosed +from opentelemetry import propagate from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import ValidationError @@ -1133,9 +1134,12 @@ async def _do_ensure_connected[HandshakeMetadata]( ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() + otel_headers: dict[str, str] = {} + propagate.inject(otel_headers) ws = await websockets.asyncio.client.connect( uri_and_metadata["uri"], max_size=None, + additional_headers=otel_headers, ) transition_connecting(ws) diff --git a/tests/v1/test_opentelemetry.py b/tests/v1/test_opentelemetry.py index c47b5418..219d332e 100644 --- a/tests/v1/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -1,16 +1,25 @@ import contextlib +import logging from datetime import timedelta -from typing import AsyncGenerator, AsyncIterator, Iterator +from typing import AsyncGenerator, AsyncIterator, Iterator, Literal import grpc import grpc.aio import pytest +from opentelemetry import baggage, context, propagate, trace +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.server import serve from replit_river.client import Client +from replit_river.client_transport import UriAndMetadata from replit_river.error_schema import RiverError, RiverException -from replit_river.rpc import stream_method_handler +from replit_river.rpc import rpc_method_handler, stream_method_handler +from replit_river.server import Server +from replit_river.transport_options import TransportOptions from tests.conftest import ( HandlerMapping, deserialize_error, @@ -219,3 +228,219 @@ async def stream_data() -> AsyncGenerator[str, None]: assert len(spans) == 1 assert spans[0].name == "river.client.stream.test_service.stream_method" assert spans[0].status.status_code == StatusCode.OK + + +# ===== OTel context propagation via WebSocket HTTP upgrade headers ===== + + +# A handler that reads OTel baggage from the ambient context and returns it. +async def baggage_echo_handler(request: str, ctx: grpc.aio.ServicerContext) -> str: + all_baggage = baggage.get_all() + # Return baggage as a comma-separated "key=value" string + return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items())) + + +baggage_echo_handlers: HandlerMapping = { + ("test_service", "baggage_echo"): ( + "rpc", + rpc_method_handler( + baggage_echo_handler, deserialize_request, serialize_response + ), + ) +} + + +@pytest.fixture +def _enable_baggage_propagator() -> Iterator[None]: + """Temporarily install a composite propagator that includes both + W3C TraceContext and W3C Baggage propagation so that + ``propagate.inject()`` / ``propagate.extract()`` handle the + ``baggage`` HTTP header.""" + previous = propagate.get_global_textmap() + propagate.set_global_textmap( + CompositePropagator( + [ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ] + ) + ) + yield + propagate.set_global_textmap(previous) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_baggage_propagated_via_ws_headers( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, +) -> None: + """Verify that OTel baggage set on the client side is propagated to the + server via the WebSocket HTTP upgrade request headers.""" + + # Set baggage in the ambient OTel context *before* the client connects, + # so that ``propagate.inject()`` (called inside ``websockets.connect()``) + # includes the ``baggage`` header. + ctx = baggage.set_baggage("test-key", "test-value") + ctx = baggage.set_baggage("another-key", "another-value", context=ctx) + token = context.attach(ctx) + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + # The handler returns sorted "key=value" pairs + assert response == "another-key=another-value,test-key=test-value" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + context.detach(token) + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_no_baggage_when_none_set( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, +) -> None: + """Verify that when no baggage is set, the server sees empty baggage.""" + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}]) +@pytest.mark.usefixtures("_enable_baggage_propagator") +async def test_traceparent_propagated_via_ws_headers( + no_logging_error: NoErrors, + server: Server, + transport_options: TransportOptions, + span_exporter: InMemorySpanExporter, +) -> None: + """Verify that when a span is active on the client, the traceparent + header is sent on the WS upgrade and the server-side context inherits + the trace.""" + tracer = trace.get_tracer(__name__) + + with tracer.start_as_current_span("client-operation"): + # Also set some baggage + ctx = baggage.set_baggage("trace-test", "yes") + token = context.attach(ctx) + + binding = None + try: + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1 + socket = sockets[0] + + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + response = await client.send_rpc( + "test_service", + "baggage_echo", + "ignored", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + # Verify baggage was propagated + assert response == "trace-test=yes" + finally: + logging.debug("Start closing test client") + await client.close() + finally: + context.detach(token) + logging.debug("Start closing test server") + if binding: + binding.close() + await server.close() + if binding: + await binding.wait_closed()