// mas_tower/metrics/duration.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::time::Instant;
8
9use opentelemetry::{KeyValue, metrics::Histogram};
10use pin_project_lite::pin_project;
11use tower::{Layer, Service};
12
13use crate::{METER, MetricsAttributes, utils::FnWrapper};
14
15/// A [`Layer`] that records the duration of requests in milliseconds.
16#[derive(Clone, Debug)]
17pub struct DurationRecorderLayer<OnRequest = (), OnResponse = (), OnError = ()> {
18    histogram: Histogram<u64>,
19    on_request: OnRequest,
20    on_response: OnResponse,
21    on_error: OnError,
22}
23
24impl DurationRecorderLayer {
25    /// Create a new [`DurationRecorderLayer`].
26    #[must_use]
27    pub fn new(name: &'static str) -> Self {
28        let histogram = METER.u64_histogram(name).build();
29        Self {
30            histogram,
31            on_request: (),
32            on_response: (),
33            on_error: (),
34        }
35    }
36}
37
38impl<OnRequest, OnResponse, OnError> DurationRecorderLayer<OnRequest, OnResponse, OnError> {
39    /// Set the [`MetricsAttributes`] to use on request.
40    #[must_use]
41    pub fn on_request<NewOnRequest>(
42        self,
43        on_request: NewOnRequest,
44    ) -> DurationRecorderLayer<NewOnRequest, OnResponse, OnError> {
45        DurationRecorderLayer {
46            histogram: self.histogram,
47            on_request,
48            on_response: self.on_response,
49            on_error: self.on_error,
50        }
51    }
52
53    #[must_use]
54    pub fn on_request_fn<F, T>(
55        self,
56        on_request: F,
57    ) -> DurationRecorderLayer<FnWrapper<F>, OnResponse, OnError>
58    where
59        F: Fn(&T) -> Vec<KeyValue>,
60    {
61        self.on_request(FnWrapper(on_request))
62    }
63
64    /// Set the [`MetricsAttributes`] to use on response.
65    #[must_use]
66    pub fn on_response<NewOnResponse>(
67        self,
68        on_response: NewOnResponse,
69    ) -> DurationRecorderLayer<OnRequest, NewOnResponse, OnError> {
70        DurationRecorderLayer {
71            histogram: self.histogram,
72            on_request: self.on_request,
73            on_response,
74            on_error: self.on_error,
75        }
76    }
77
78    #[must_use]
79    pub fn on_response_fn<F, T>(
80        self,
81        on_response: F,
82    ) -> DurationRecorderLayer<OnRequest, FnWrapper<F>, OnError>
83    where
84        F: Fn(&T) -> Vec<KeyValue>,
85    {
86        self.on_response(FnWrapper(on_response))
87    }
88
89    /// Set the [`MetricsAttributes`] to use on error.
90    #[must_use]
91    pub fn on_error<NewOnError>(
92        self,
93        on_error: NewOnError,
94    ) -> DurationRecorderLayer<OnRequest, OnResponse, NewOnError> {
95        DurationRecorderLayer {
96            histogram: self.histogram,
97            on_request: self.on_request,
98            on_response: self.on_response,
99            on_error,
100        }
101    }
102
103    #[must_use]
104    pub fn on_error_fn<F, T>(
105        self,
106        on_error: F,
107    ) -> DurationRecorderLayer<OnRequest, OnResponse, FnWrapper<F>>
108    where
109        F: Fn(&T) -> Vec<KeyValue>,
110    {
111        self.on_error(FnWrapper(on_error))
112    }
113}
114
115impl<S, OnRequest, OnResponse, OnError> Layer<S>
116    for DurationRecorderLayer<OnRequest, OnResponse, OnError>
117where
118    OnRequest: Clone,
119    OnResponse: Clone,
120    OnError: Clone,
121{
122    type Service = DurationRecorderService<S, OnRequest, OnResponse, OnError>;
123
124    fn layer(&self, inner: S) -> Self::Service {
125        DurationRecorderService {
126            inner,
127            histogram: self.histogram.clone(),
128            on_request: self.on_request.clone(),
129            on_response: self.on_response.clone(),
130            on_error: self.on_error.clone(),
131        }
132    }
133}
134
135/// A middleware that records the duration of requests in milliseconds.
136#[derive(Clone, Debug)]
137pub struct DurationRecorderService<S, OnRequest = (), OnResponse = (), OnError = ()> {
138    inner: S,
139    histogram: Histogram<u64>,
140    on_request: OnRequest,
141    on_response: OnResponse,
142    on_error: OnError,
143}
144
145pin_project! {
146    /// The future returned by the [`DurationRecorderService`].
147    pub struct DurationRecorderFuture<F, OnResponse = (), OnError = ()> {
148        #[pin]
149        inner: F,
150
151        start: Instant,
152        histogram: Histogram<u64>,
153        attributes_from_request: Vec<KeyValue>,
154        from_response: OnResponse,
155        from_error: OnError,
156    }
157}
158
159impl<F, R, E, OnResponse, OnError> Future for DurationRecorderFuture<F, OnResponse, OnError>
160where
161    F: Future<Output = Result<R, E>>,
162    OnResponse: MetricsAttributes<R>,
163    OnError: MetricsAttributes<E>,
164{
165    type Output = F::Output;
166
167    fn poll(
168        self: std::pin::Pin<&mut Self>,
169        cx: &mut std::task::Context<'_>,
170    ) -> std::task::Poll<Self::Output> {
171        let this = self.project();
172        let result = std::task::ready!(this.inner.poll(cx));
173
174        // Measure the duration of the request.
175        let duration = this.start.elapsed();
176        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
177
178        // Collect the attributes from the request, response and error.
179        let mut attributes = this.attributes_from_request.clone();
180        match &result {
181            Ok(response) => {
182                attributes.extend(this.from_response.attributes(response));
183            }
184            Err(error) => {
185                attributes.extend(this.from_error.attributes(error));
186            }
187        }
188
189        this.histogram.record(duration_ms, &attributes);
190        std::task::Poll::Ready(result)
191    }
192}
193
194impl<S, R, OnRequest, OnResponse, OnError> Service<R>
195    for DurationRecorderService<S, OnRequest, OnResponse, OnError>
196where
197    S: Service<R>,
198    OnRequest: MetricsAttributes<R>,
199    OnResponse: MetricsAttributes<S::Response> + Clone,
200    OnError: MetricsAttributes<S::Error> + Clone,
201{
202    type Response = S::Response;
203    type Error = S::Error;
204    type Future = DurationRecorderFuture<S::Future, OnResponse, OnError>;
205
206    fn poll_ready(
207        &mut self,
208        cx: &mut std::task::Context<'_>,
209    ) -> std::task::Poll<Result<(), Self::Error>> {
210        self.inner.poll_ready(cx)
211    }
212
213    fn call(&mut self, request: R) -> Self::Future {
214        let start = Instant::now();
215        let attributes_from_request = self.on_request.attributes(&request).collect();
216        let inner = self.inner.call(request);
217
218        DurationRecorderFuture {
219            inner,
220            start,
221            histogram: self.histogram.clone(),
222            attributes_from_request,
223            from_response: self.on_response.clone(),
224            from_error: self.on_error.clone(),
225        }
226    }
227}