From 7823369db960ce7af3c106e0491789f85d085ffa Mon Sep 17 00:00:00 2001 From: lhchavez Date: Sat, 28 Feb 2026 15:37:34 -0800 Subject: [PATCH] feat: propagate OTel tracing context from client to server through websocket Add server-side trace context extraction and span creation so that distributed traces flow end-to-end through River websocket connections. Changes: - Add TransportMessageTracingGetter to extract traceparent/tracestate from incoming TransportMessages (counterpart to existing Setter) - Extract trace context in ServerSession._open_stream_and_call_handler and create a SERVER span that is a child of the client's CLIENT span - Run handler within the extracted context so downstream code inherits the trace - Update and expand tests to verify server spans, trace propagation, span attributes, and independent traces for concurrent RPCs --- src/replit_river/rpc.py | 32 ++- src/replit_river/server_session.py | 53 ++++- tests/v1/test_opentelemetry.py | 302 +++++++++++++++++++++++++++-- 3 files changed, 369 insertions(+), 18 deletions(-) diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 678e8bf6..26153808 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -20,7 +20,7 @@ import grpc from aiochannel import Channel, ChannelClosed -from opentelemetry.propagators.textmap import Setter +from opentelemetry.propagators.textmap import Getter, Setter from pydantic import BaseModel, ConfigDict, Field from replit_river.error_schema import ( @@ -126,6 +126,36 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None: logger.warning("unknown trace propagation key", extra={"key": key}) +class TransportMessageTracingGetter(Getter[TransportMessage]): + """ + Handles extracting tracing context from an incoming transport message. + """ + + def get(self, carrier: TransportMessage, key: str) -> list[str] | None: + if not carrier.tracing: + return None + match key: + case "traceparent": + value = carrier.tracing.traceparent + case "tracestate": + value = carrier.tracing.tracestate + case _: + return None + if not value: + return None + return [value] + + def keys(self, carrier: TransportMessage) -> list[str]: + if not carrier.tracing: + return [] + keys: list[str] = [] + if carrier.tracing.traceparent: + keys.append("traceparent") + if carrier.tracing.tracestate: + keys.append("tracestate") + return keys + + class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]): """Represents a gRPC-compatible ServicerContext for River interop.""" diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index c397e900..904e680f 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -4,6 +4,8 @@ import websockets from aiochannel import Channel, ChannelClosed +from opentelemetry import context, trace +from opentelemetry.trace import SpanKind from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed @@ -24,6 +26,7 @@ STREAM_OPEN_BIT, GenericRpcHandlerBuilder, TransportMessage, + TransportMessageTracingGetter, TransportMessageTracingSetter, ) @@ -32,9 +35,11 @@ logger = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() +trace_getter = TransportMessageTracingGetter() class ServerSession(Session): @@ -216,6 +221,23 @@ async def _open_stream_and_call_handler( "upload-stream", # subscription "stream", ) + + # Extract trace context from the incoming message and create a server span. + extracted_context = trace_propagator.extract( + carrier=msg, getter=trace_getter + ) + span = tracer.start_span( + f"river.server.{method_type}.{msg.serviceName}.{msg.procedureName}", + context=extracted_context, + kind=SpanKind.SERVER, + ) + span.set_attribute("river.service_name", msg.serviceName) + span.set_attribute("river.procedure_name", msg.procedureName) + span.set_attribute("river.method_type", method_type) + span.set_attribute("river.stream_id", msg.streamId) + span.set_attribute("river.client_id", msg.from_) + handler_ctx = trace.set_span_in_context(span, extracted_context) + # New channel pair. input_stream: Channel[Any] = Channel( MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 @@ -231,9 +253,13 @@ async def _open_stream_and_call_handler( await input_stream.put(msg.payload) except (RuntimeError, ChannelClosed) as e: raise InvalidMessageException(e) from e - # Start the handler. + # Start the handler with the extracted trace context. self._task_manager.create_task( - handler_func(msg.from_, input_stream, output_stream), tg + self._run_handler_with_tracing( + handler_func, msg.from_, input_stream, output_stream, + span, handler_ctx, + ), + tg, ) self._task_manager.create_task( self._send_responses_from_output_stream( @@ -243,6 +269,29 @@ async def _open_stream_and_call_handler( ) return input_stream + async def _run_handler_with_tracing( + self, + handler_func: GenericRpcHandlerBuilder, + peer: str, + input_stream: Channel[Any], + output_stream: Channel[Any], + span: trace.Span, + handler_ctx: context.Context, + ) -> None: + """Run an RPC handler within the extracted trace context, ending the span + when the handler completes.""" + token = context.attach(handler_ctx) + try: + await handler_func(peer, input_stream, output_stream) + span.set_status(trace.StatusCode.OK) + except Exception as e: + span.set_status(trace.StatusCode.ERROR, str(e)) + span.record_exception(e) + raise + finally: + span.end() + context.detach(token) + async def _send_responses_from_output_stream( self, stream_id: str, diff --git a/tests/v1/test_opentelemetry.py b/tests/v1/test_opentelemetry.py index c47b5418..fb34b4eb 100644 --- a/tests/v1/test_opentelemetry.py +++ b/tests/v1/test_opentelemetry.py @@ -6,7 +6,7 @@ import grpc.aio import pytest from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from opentelemetry.trace import StatusCode +from opentelemetry.trace import SpanKind, StatusCode from replit_river.client import Client from replit_river.error_schema import RiverError, RiverException @@ -44,8 +44,13 @@ async def test_rpc_method_span( ) assert response == "Hello, Alice!" spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.rpc.test_service.rpc_method" + assert len(spans) == 2 + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + assert client_spans[0].name == "river.client.rpc.test_service.rpc_method" + assert server_spans[0].name == "river.server.rpc.test_service.rpc_method" @pytest.mark.asyncio @@ -70,8 +75,16 @@ async def upload_data() -> AsyncGenerator[str, None]: ) assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.upload.test_service.upload_method" + assert len(spans) == 2 + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + assert client_spans[0].name == "river.client.upload.test_service.upload_method" + assert ( + server_spans[0].name + == "river.server.upload-stream.test_service.upload_method" + ) @pytest.mark.asyncio @@ -91,8 +104,19 @@ async def test_subscription_method_span( assert "Subscription message" in response spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.subscription.test_service.subscription_method" + assert len(spans) == 2 + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + assert ( + client_spans[0].name + == "river.client.subscription.test_service.subscription_method" + ) + assert ( + server_spans[0].name + == "river.server.subscription-stream.test_service.subscription_method" + ) @pytest.mark.asyncio @@ -126,8 +150,13 @@ async def stream_data() -> AsyncGenerator[str, None]: ] spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.stream.test_service.stream_method" + assert len(spans) == 2 + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + assert client_spans[0].name == "river.client.stream.test_service.stream_method" + assert server_spans[0].name == "river.server.stream.test_service.stream_method" async def stream_error_handler( @@ -180,9 +209,20 @@ async def stream_data() -> AsyncGenerator[str, None]: assert isinstance(responses[0], RiverError) spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.stream.test_service.stream_method_error" - assert spans[0].status.status_code == StatusCode.ERROR + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + assert ( + client_spans[0].name + == "river.client.stream.test_service.stream_method_error" + ) + assert client_spans[0].status.status_code == StatusCode.ERROR + assert ( + server_spans[0].name + == "river.server.stream.test_service.stream_method_error" + ) + assert server_spans[0].status.status_code == StatusCode.OK @pytest.mark.asyncio @@ -216,6 +256,238 @@ async def stream_data() -> AsyncGenerator[str, None]: ] spans = span_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == "river.client.stream.test_service.stream_method" - assert spans[0].status.status_code == StatusCode.OK + client_spans = [s for s in spans if "client" in s.name] + assert len(client_spans) == 1 + assert client_spans[0].name == "river.client.stream.test_service.stream_method" + assert client_spans[0].status.status_code == StatusCode.OK + + +# ===== Trace propagation tests ===== + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_rpc_trace_propagation( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that the server span is a child of the client span (same trace).""" + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "Hello, Alice!" + + spans = span_exporter.get_finished_spans() + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + + client_span = client_spans[0] + server_span = server_spans[0] + + # Both spans should share the same trace ID + assert client_span.context.trace_id == server_span.context.trace_id + + # Server span should be a child of the client span + assert server_span.parent is not None + assert server_span.parent.span_id == client_span.context.span_id + + # Verify span kinds + assert client_span.kind == SpanKind.CLIENT + assert server_span.kind == SpanKind.SERVER + + # Both should be OK + assert client_span.status.status_code == StatusCode.OK + assert server_span.status.status_code == StatusCode.OK + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_subscription}]) +async def test_subscription_trace_propagation( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that trace context propagates for subscriptions.""" + async for response in client.send_subscription( + "test_service", + "subscription_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + ): + pass + + spans = span_exporter.get_finished_spans() + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + + client_span = client_spans[0] + server_span = server_spans[0] + + assert client_span.context.trace_id == server_span.context.trace_id + assert server_span.parent is not None + assert server_span.parent.span_id == client_span.context.span_id + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_upload}]) +async def test_upload_trace_propagation( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that trace context propagates for uploads.""" + + async def upload_data() -> AsyncGenerator[str, None]: + yield "Data 1" + yield "Data 2" + + response = await client.send_upload( + "test_service", + "upload_method", + "Initial Data", + upload_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ) + assert response == "Uploaded: Initial Data, Data 1, Data 2" + + spans = span_exporter.get_finished_spans() + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + + client_span = client_spans[0] + server_span = server_spans[0] + + assert client_span.context.trace_id == server_span.context.trace_id + assert server_span.parent is not None + assert server_span.parent.span_id == client_span.context.span_id + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_stream}]) +async def test_stream_trace_propagation( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that trace context propagates for bidirectional streams.""" + + async def stream_data() -> AsyncGenerator[str, None]: + yield "Stream 1" + yield "Stream 2" + + responses = [] + async for response in client.send_stream( + "test_service", + "stream_method", + "Initial Stream Data", + stream_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ): + responses.append(response) + + assert len(responses) == 3 # Initial + 2 stream items + + spans = span_exporter.get_finished_spans() + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 1 + assert len(server_spans) == 1 + + client_span = client_spans[0] + server_span = server_spans[0] + + assert client_span.context.trace_id == server_span.context.trace_id + assert server_span.parent is not None + assert server_span.parent.span_id == client_span.context.span_id + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_server_span_has_attributes( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that server spans have the expected attributes.""" + await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + + spans = span_exporter.get_finished_spans() + server_spans = [s for s in spans if "server" in s.name] + assert len(server_spans) == 1 + server_span = server_spans[0] + + assert server_span.attributes is not None + assert server_span.attributes.get("river.service_name") == "test_service" + assert server_span.attributes.get("river.procedure_name") == "rpc_method" + assert server_span.attributes.get("river.method_type") == "rpc" + assert server_span.attributes.get("river.client_id") == "test_client" + assert server_span.attributes.get("river.stream_id") is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_multiple_rpcs_have_independent_traces( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + """Test that independent RPCs create independent traces.""" + await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + await client.send_rpc( + "test_service", + "rpc_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 4 # 2 client + 2 server + + client_spans = [s for s in spans if "client" in s.name] + server_spans = [s for s in spans if "server" in s.name] + assert len(client_spans) == 2 + assert len(server_spans) == 2 + + # Each RPC should have its own trace ID + trace_id_1 = client_spans[0].context.trace_id + trace_id_2 = client_spans[1].context.trace_id + assert trace_id_1 != trace_id_2 + + # Each server span should have matching trace IDs with its corresponding client span + for server_span in server_spans: + assert server_span.parent is not None + matching_client = [ + c + for c in client_spans + if c.context.span_id == server_span.parent.span_id + ] + assert len(matching_client) == 1 + assert server_span.context.trace_id == matching_client[0].context.trace_id