mas_http/
reqwest.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use std::{
7    str::FromStr,
8    sync::{Arc, LazyLock},
9    time::Duration,
10};
11
12use futures_util::FutureExt as _;
13use headers::{ContentLength, HeaderMapExt as _, UserAgent};
14use hyper_util::client::legacy::connect::{
15    HttpInfo,
16    dns::{GaiResolver, Name},
17};
18use opentelemetry::{
19    KeyValue,
20    metrics::{Histogram, UpDownCounter},
21};
22use opentelemetry_http::HeaderInjector;
23use opentelemetry_semantic_conventions::{
24    attribute::{HTTP_REQUEST_BODY_SIZE, HTTP_RESPONSE_BODY_SIZE},
25    metric::{HTTP_CLIENT_ACTIVE_REQUESTS, HTTP_CLIENT_REQUEST_DURATION},
26    trace::{
27        ERROR_TYPE, HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, NETWORK_LOCAL_ADDRESS,
28        NETWORK_LOCAL_PORT, NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT, NETWORK_TRANSPORT,
29        NETWORK_TYPE, SERVER_ADDRESS, SERVER_PORT, URL_FULL, URL_SCHEME, USER_AGENT_ORIGINAL,
30    },
31};
32use rustls_platform_verifier::ConfigVerifierExt;
33use tokio::time::Instant;
34use tower::{BoxError, Service as _};
35use tracing::Instrument;
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37
38use crate::METER;
39
40static USER_AGENT: &str = concat!("matrix-authentication-service/", env!("CARGO_PKG_VERSION"));
41
42static HTTP_REQUESTS_DURATION_HISTOGRAM: LazyLock<Histogram<u64>> = LazyLock::new(|| {
43    METER
44        .u64_histogram(HTTP_CLIENT_REQUEST_DURATION)
45        .with_unit("ms")
46        .with_description("Duration of HTTP client requests")
47        .build()
48});
49
50static HTTP_REQUESTS_IN_FLIGHT: LazyLock<UpDownCounter<i64>> = LazyLock::new(|| {
51    METER
52        .i64_up_down_counter(HTTP_CLIENT_ACTIVE_REQUESTS)
53        .with_unit("{requests}")
54        .with_description("Number of HTTP client requests in flight")
55        .build()
56});
57
58struct TracingResolver {
59    inner: GaiResolver,
60}
61
62impl TracingResolver {
63    fn new() -> Self {
64        let inner = GaiResolver::new();
65        Self { inner }
66    }
67}
68
69impl reqwest::dns::Resolve for TracingResolver {
70    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
71        let span = tracing::info_span!("dns.resolve", name = name.as_str());
72        let inner = &mut self.inner.clone();
73        Box::pin(
74            inner
75                .call(Name::from_str(name.as_str()).unwrap())
76                .map(|result| {
77                    result
78                        .map(|addrs| -> reqwest::dns::Addrs { Box::new(addrs) })
79                        .map_err(|err| -> BoxError { Box::new(err) })
80                })
81                .instrument(span),
82        )
83    }
84}
85
86/// Create a new [`reqwest::Client`] with sane parameters
87///
88/// # Panics
89///
90/// Panics if the client fails to build, which should never happen
91#[must_use]
92pub fn client() -> reqwest::Client {
93    // TODO: can/should we limit in-flight requests?
94    let tls_config = rustls::ClientConfig::with_platform_verifier();
95    reqwest::Client::builder()
96        .dns_resolver(Arc::new(TracingResolver::new()))
97        .use_preconfigured_tls(tls_config)
98        .user_agent(USER_AGENT)
99        .timeout(Duration::from_secs(60))
100        .connect_timeout(Duration::from_secs(30))
101        .read_timeout(Duration::from_secs(30))
102        .build()
103        .expect("failed to create HTTP client")
104}
105
106async fn send_traced(
107    request: reqwest::RequestBuilder,
108) -> Result<reqwest::Response, reqwest::Error> {
109    let start = Instant::now();
110    let (client, request) = request.build_split();
111    let mut request = request?;
112
113    let headers = request.headers();
114    let server_address = request.url().host_str().map(ToOwned::to_owned);
115    let server_port = request.url().port_or_known_default();
116    let scheme = request.url().scheme().to_owned();
117    let user_agent = headers
118        .typed_get::<UserAgent>()
119        .map(tracing::field::display);
120    let content_length = headers.typed_get().map(|ContentLength(len)| len);
121    let method = request.method().to_string();
122
123    // Create a new span for the request
124    let span = tracing::info_span!(
125        "http.client.request",
126        "otel.kind" = "client",
127        "otel.status_code" = tracing::field::Empty,
128        { HTTP_REQUEST_METHOD } = method,
129        { URL_FULL } = %request.url(),
130        { HTTP_RESPONSE_STATUS_CODE } = tracing::field::Empty,
131        { SERVER_ADDRESS } = server_address,
132        { SERVER_PORT } = server_port,
133        { HTTP_REQUEST_BODY_SIZE } = content_length,
134        { HTTP_RESPONSE_BODY_SIZE } = tracing::field::Empty,
135        { NETWORK_TRANSPORT } = "tcp",
136        { NETWORK_TYPE } = tracing::field::Empty,
137        { NETWORK_LOCAL_ADDRESS } = tracing::field::Empty,
138        { NETWORK_LOCAL_PORT } = tracing::field::Empty,
139        { NETWORK_PEER_ADDRESS } = tracing::field::Empty,
140        { NETWORK_PEER_PORT } = tracing::field::Empty,
141        { USER_AGENT_ORIGINAL } = user_agent,
142        "rust.error" = tracing::field::Empty,
143    );
144
145    // Inject the span context into the request headers
146    let context = span.context();
147    opentelemetry::global::get_text_map_propagator(|propagator| {
148        let mut injector = HeaderInjector(request.headers_mut());
149        propagator.inject_context(&context, &mut injector);
150    });
151
152    let mut metrics_labels = vec![
153        KeyValue::new(HTTP_REQUEST_METHOD, method.clone()),
154        KeyValue::new(URL_SCHEME, scheme),
155    ];
156
157    if let Some(server_address) = server_address {
158        metrics_labels.push(KeyValue::new(SERVER_ADDRESS, server_address));
159    }
160
161    if let Some(server_port) = server_port {
162        metrics_labels.push(KeyValue::new(SERVER_PORT, i64::from(server_port)));
163    }
164
165    HTTP_REQUESTS_IN_FLIGHT.add(1, &metrics_labels);
166    async move {
167        let span = tracing::Span::current();
168        let result = client.execute(request).await;
169
170        // XXX: We *could* loose this if the future is dropped before this, but let's
171        // not worry about it for now. Ideally we would use a `Drop` guard to decrement
172        // the counter
173        HTTP_REQUESTS_IN_FLIGHT.add(-1, &metrics_labels);
174
175        let duration = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
176        let result = match result {
177            Ok(response) => {
178                span.record("otel.status_code", "OK");
179                span.record(HTTP_RESPONSE_STATUS_CODE, response.status().as_u16());
180
181                if let Some(ContentLength(content_length)) = response.headers().typed_get() {
182                    span.record(HTTP_RESPONSE_BODY_SIZE, content_length);
183                }
184
185                if let Some(http_info) = response.extensions().get::<HttpInfo>() {
186                    let local = http_info.local_addr();
187                    let peer = http_info.remote_addr();
188                    let family = if local.is_ipv4() { "ipv4" } else { "ipv6" };
189                    span.record(NETWORK_TYPE, family);
190                    span.record(NETWORK_LOCAL_ADDRESS, local.ip().to_string());
191                    span.record(NETWORK_LOCAL_PORT, local.port());
192                    span.record(NETWORK_PEER_ADDRESS, peer.ip().to_string());
193                    span.record(NETWORK_PEER_PORT, peer.port());
194                } else {
195                    tracing::warn!("No HttpInfo injected in response extensions");
196                }
197
198                metrics_labels.push(KeyValue::new(
199                    HTTP_RESPONSE_STATUS_CODE,
200                    i64::from(response.status().as_u16()),
201                ));
202
203                Ok(response)
204            }
205            Err(err) => {
206                span.record("otel.status_code", "ERROR");
207                span.record("rust.error", &err as &dyn std::error::Error);
208
209                metrics_labels.push(KeyValue::new(ERROR_TYPE, "NO_RESPONSE"));
210
211                Err(err)
212            }
213        };
214
215        HTTP_REQUESTS_DURATION_HISTOGRAM.record(duration, &metrics_labels);
216
217        result
218    }
219    .instrument(span)
220    .await
221}
222
223/// An extension trait implemented for [`reqwest::RequestBuilder`] to send a
224/// request with a tracing span, and span context propagated.
225pub trait RequestBuilderExt {
226    /// Send the request with a tracing span, and span context propagated.
227    fn send_traced(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>> + Send;
228}
229
230impl RequestBuilderExt for reqwest::RequestBuilder {
231    fn send_traced(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>> + Send {
232        send_traced(self)
233    }
234}