1use std::time::Instant;
8
9use opentelemetry::{KeyValue, metrics::Histogram};
10use pin_project_lite::pin_project;
11use tower::{Layer, Service};
12
13use crate::{METER, MetricsAttributes, utils::FnWrapper};
14
15#[derive(Clone, Debug)]
17pub struct DurationRecorderLayer<OnRequest = (), OnResponse = (), OnError = ()> {
18 histogram: Histogram<u64>,
19 on_request: OnRequest,
20 on_response: OnResponse,
21 on_error: OnError,
22}
23
24impl DurationRecorderLayer {
25 #[must_use]
27 pub fn new(name: &'static str) -> Self {
28 let histogram = METER.u64_histogram(name).build();
29 Self {
30 histogram,
31 on_request: (),
32 on_response: (),
33 on_error: (),
34 }
35 }
36}
37
38impl<OnRequest, OnResponse, OnError> DurationRecorderLayer<OnRequest, OnResponse, OnError> {
39 #[must_use]
41 pub fn on_request<NewOnRequest>(
42 self,
43 on_request: NewOnRequest,
44 ) -> DurationRecorderLayer<NewOnRequest, OnResponse, OnError> {
45 DurationRecorderLayer {
46 histogram: self.histogram,
47 on_request,
48 on_response: self.on_response,
49 on_error: self.on_error,
50 }
51 }
52
53 #[must_use]
54 pub fn on_request_fn<F, T>(
55 self,
56 on_request: F,
57 ) -> DurationRecorderLayer<FnWrapper<F>, OnResponse, OnError>
58 where
59 F: Fn(&T) -> Vec<KeyValue>,
60 {
61 self.on_request(FnWrapper(on_request))
62 }
63
64 #[must_use]
66 pub fn on_response<NewOnResponse>(
67 self,
68 on_response: NewOnResponse,
69 ) -> DurationRecorderLayer<OnRequest, NewOnResponse, OnError> {
70 DurationRecorderLayer {
71 histogram: self.histogram,
72 on_request: self.on_request,
73 on_response,
74 on_error: self.on_error,
75 }
76 }
77
78 #[must_use]
79 pub fn on_response_fn<F, T>(
80 self,
81 on_response: F,
82 ) -> DurationRecorderLayer<OnRequest, FnWrapper<F>, OnError>
83 where
84 F: Fn(&T) -> Vec<KeyValue>,
85 {
86 self.on_response(FnWrapper(on_response))
87 }
88
89 #[must_use]
91 pub fn on_error<NewOnError>(
92 self,
93 on_error: NewOnError,
94 ) -> DurationRecorderLayer<OnRequest, OnResponse, NewOnError> {
95 DurationRecorderLayer {
96 histogram: self.histogram,
97 on_request: self.on_request,
98 on_response: self.on_response,
99 on_error,
100 }
101 }
102
103 #[must_use]
104 pub fn on_error_fn<F, T>(
105 self,
106 on_error: F,
107 ) -> DurationRecorderLayer<OnRequest, OnResponse, FnWrapper<F>>
108 where
109 F: Fn(&T) -> Vec<KeyValue>,
110 {
111 self.on_error(FnWrapper(on_error))
112 }
113}
114
115impl<S, OnRequest, OnResponse, OnError> Layer<S>
116 for DurationRecorderLayer<OnRequest, OnResponse, OnError>
117where
118 OnRequest: Clone,
119 OnResponse: Clone,
120 OnError: Clone,
121{
122 type Service = DurationRecorderService<S, OnRequest, OnResponse, OnError>;
123
124 fn layer(&self, inner: S) -> Self::Service {
125 DurationRecorderService {
126 inner,
127 histogram: self.histogram.clone(),
128 on_request: self.on_request.clone(),
129 on_response: self.on_response.clone(),
130 on_error: self.on_error.clone(),
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
137pub struct DurationRecorderService<S, OnRequest = (), OnResponse = (), OnError = ()> {
138 inner: S,
139 histogram: Histogram<u64>,
140 on_request: OnRequest,
141 on_response: OnResponse,
142 on_error: OnError,
143}
144
145pin_project! {
146 pub struct DurationRecorderFuture<F, OnResponse = (), OnError = ()> {
148 #[pin]
149 inner: F,
150
151 start: Instant,
152 histogram: Histogram<u64>,
153 attributes_from_request: Vec<KeyValue>,
154 from_response: OnResponse,
155 from_error: OnError,
156 }
157}
158
159impl<F, R, E, OnResponse, OnError> Future for DurationRecorderFuture<F, OnResponse, OnError>
160where
161 F: Future<Output = Result<R, E>>,
162 OnResponse: MetricsAttributes<R>,
163 OnError: MetricsAttributes<E>,
164{
165 type Output = F::Output;
166
167 fn poll(
168 self: std::pin::Pin<&mut Self>,
169 cx: &mut std::task::Context<'_>,
170 ) -> std::task::Poll<Self::Output> {
171 let this = self.project();
172 let result = std::task::ready!(this.inner.poll(cx));
173
174 let duration = this.start.elapsed();
176 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
177
178 let mut attributes = this.attributes_from_request.clone();
180 match &result {
181 Ok(response) => {
182 attributes.extend(this.from_response.attributes(response));
183 }
184 Err(error) => {
185 attributes.extend(this.from_error.attributes(error));
186 }
187 }
188
189 this.histogram.record(duration_ms, &attributes);
190 std::task::Poll::Ready(result)
191 }
192}
193
194impl<S, R, OnRequest, OnResponse, OnError> Service<R>
195 for DurationRecorderService<S, OnRequest, OnResponse, OnError>
196where
197 S: Service<R>,
198 OnRequest: MetricsAttributes<R>,
199 OnResponse: MetricsAttributes<S::Response> + Clone,
200 OnError: MetricsAttributes<S::Error> + Clone,
201{
202 type Response = S::Response;
203 type Error = S::Error;
204 type Future = DurationRecorderFuture<S::Future, OnResponse, OnError>;
205
206 fn poll_ready(
207 &mut self,
208 cx: &mut std::task::Context<'_>,
209 ) -> std::task::Poll<Result<(), Self::Error>> {
210 self.inner.poll_ready(cx)
211 }
212
213 fn call(&mut self, request: R) -> Self::Future {
214 let start = Instant::now();
215 let attributes_from_request = self.on_request.attributes(&request).collect();
216 let inner = self.inner.call(request);
217
218 DurationRecorderFuture {
219 inner,
220 start,
221 histogram: self.histogram.clone(),
222 attributes_from_request,
223 from_response: self.on_response.clone(),
224 from_error: self.on_error.clone(),
225 }
226 }
227}