mas_tower/
trace_context.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use http::Request;
8use opentelemetry::propagation::Injector;
9use opentelemetry_http::HeaderInjector;
10use tower::{Layer, Service};
11use tracing::Span;
12use tracing_opentelemetry::OpenTelemetrySpanExt;
13
14/// A trait to get an [`Injector`] from a request.
15trait AsInjector {
16    type Injector<'a>: Injector
17    where
18        Self: 'a;
19
20    fn as_injector(&mut self) -> Self::Injector<'_>;
21}
22
23impl<B> AsInjector for Request<B> {
24    type Injector<'a>
25        = HeaderInjector<'a>
26    where
27        Self: 'a;
28
29    fn as_injector(&mut self) -> Self::Injector<'_> {
30        HeaderInjector(self.headers_mut())
31    }
32}
33
/// Tower [`Layer`] that wraps a service in a [`TraceContextService`], so that
/// each request passing through it gets the current trace context injected.
#[derive(Debug, Clone, Copy, Default)]
pub struct TraceContextLayer {
    // A private field prevents code outside this module from building the
    // struct with a literal; presumably this keeps room to add fields later.
    _private: (),
}

impl TraceContextLayer {
    /// Build a [`TraceContextLayer`]; equivalent to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }
}
47
48impl<S> Layer<S> for TraceContextLayer {
49    type Service = TraceContextService<S>;
50
51    fn layer(&self, inner: S) -> Self::Service {
52        TraceContextService::new(inner)
53    }
54}
55
/// A [`Service`] that injects the current trace context into each request
/// before forwarding it to the wrapped `inner` service.
#[derive(Debug, Clone)]
pub struct TraceContextService<S> {
    /// The wrapped service that receives the (context-enriched) requests.
    inner: S,
}

impl<S> TraceContextService<S> {
    /// Create a new [`TraceContextService`] wrapping `inner`.
    ///
    /// Marked `#[must_use]` (like `TraceContextLayer::new`), since building
    /// the wrapper and discarding it has no effect.
    #[must_use]
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}
68
69impl<S, R> Service<R> for TraceContextService<S>
70where
71    S: Service<R>,
72    R: AsInjector,
73{
74    type Response = S::Response;
75    type Error = S::Error;
76    type Future = S::Future;
77
78    fn poll_ready(
79        &mut self,
80        cx: &mut std::task::Context<'_>,
81    ) -> std::task::Poll<Result<(), Self::Error>> {
82        self.inner.poll_ready(cx)
83    }
84
85    fn call(&mut self, mut req: R) -> Self::Future {
86        // Get the `opentelemetry` context out of the `tracing` span.
87        let context = Span::current().context();
88
89        // Inject the trace context into the request. The block is there to ensure that
90        // the injector is dropped before calling the inner service, to avoid borrowing
91        // issues.
92        {
93            let mut injector = req.as_injector();
94            opentelemetry::global::get_text_map_propagator(|propagator| {
95                propagator.inject_context(&context, &mut injector);
96            });
97        }
98
99        self.inner.call(req)
100    }
101}