mas_tower/tracing/
service.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use tower::Service;
8
9use super::future::TraceFuture;
10
/// A [`Service`] middleware that runs every call to the wrapped service inside a span.
///
/// The span for each request is produced by `make_span`. `on_response` and
/// `on_error` are handed to the returned [`TraceFuture`] so it can enrich the
/// span with the inner service's outcome. Both enrichers default to `()`,
/// which is presumably a no-op enricher (NOTE(review): confirm in
/// `super::enrich_span`).
#[derive(Debug, Clone)]
pub struct TraceService<S, MakeSpan, OnResponse = (), OnError = ()> {
    /// The wrapped service whose calls are traced.
    inner: S,
    /// Builds the span for each incoming request.
    make_span: MakeSpan,
    /// Enricher for the inner service's successful response.
    on_response: OnResponse,
    /// Enricher for the inner service's error.
    on_error: OnError,
}
18
19impl<S, MakeSpan, OnResponse, OnError> TraceService<S, MakeSpan, OnResponse, OnError> {
20    /// Create a new [`TraceService`].
21    #[must_use]
22    pub fn new(inner: S, make_span: MakeSpan, on_response: OnResponse, on_error: OnError) -> Self {
23        Self {
24            inner,
25            make_span,
26            on_response,
27            on_error,
28        }
29    }
30}
31
32impl<R, S, MakeSpan, OnResponse, OnError> Service<R>
33    for TraceService<S, MakeSpan, OnResponse, OnError>
34where
35    S: Service<R>,
36    MakeSpan: super::make_span::MakeSpan<R>,
37    OnResponse: super::enrich_span::EnrichSpan<S::Response> + Clone,
38    OnError: super::enrich_span::EnrichSpan<S::Error> + Clone,
39{
40    type Response = S::Response;
41    type Error = S::Error;
42    type Future = TraceFuture<S::Future, OnResponse, OnError>;
43
44    fn poll_ready(
45        &mut self,
46        cx: &mut std::task::Context<'_>,
47    ) -> std::task::Poll<Result<(), Self::Error>> {
48        self.inner.poll_ready(cx)
49    }
50
51    fn call(&mut self, request: R) -> Self::Future {
52        let span = self.make_span.make_span(&request);
53        let guard = span.enter();
54        let inner = self.inner.call(request);
55        drop(guard);
56
57        TraceFuture::new(inner, span, self.on_response.clone(), self.on_error.clone())
58    }
59}