mas_listener/server.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use futures_util::{StreamExt, stream::SelectAll};
use hyper::{Request, Response};
use hyper_util::{
    rt::{TokioExecutor, TokioIo},
    server::conn::auto::Connection,
    service::TowerToHyperService,
};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tower::Service;
use tower_http::add_extension::AddExtension;
use tracing::Instrument;

use crate::{
    ConnectionInfo,
    maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo},
    proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
    rewind::Rewind,
    unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
};

/// The timeout for the handshake to complete
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);

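/// A server accepting connections on a TCP or UNIX socket, optionally
/// terminating TLS and decoding the PROXY protocol header
///
/// A minimal usage sketch, assuming the `From<tokio::net::TcpListener>`
/// conversion on [`UnixOrTcpListener`]; the `hello` handler is hypothetical:
///
/// ```ignore
/// use std::convert::Infallible;
///
/// use bytes::Bytes;
/// use http_body_util::Full;
/// use hyper::{Request, Response};
/// use tokio_util::sync::CancellationToken;
///
/// // Any `tower::Service` satisfying the bounds on `run` works here
/// async fn hello(
///     _req: Request<hyper::body::Incoming>,
/// ) -> Result<Response<Full<Bytes>>, Infallible> {
///     Ok(Response::new(Full::new(Bytes::from_static(b"hello"))))
/// }
///
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
/// let server = Server::new(listener, tower::service_fn(hello));
///
/// // Cancelling the first token starts a graceful shutdown, the second forces it
/// server
///     .run(CancellationToken::new(), CancellationToken::new())
///     .await;
/// ```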
pub struct Server<S> {
    tls: Option<Arc<ServerConfig>>,
    proxy: bool,
    listener: UnixOrTcpListener,
    service: S,
}

impl<S> Server<S> {
    /// # Errors
    ///
    /// Returns an error if the listener couldn't be converted via [`TryInto`]
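    ///
    /// A sketch with a `std` listener (assuming the `TryFrom<std::net::TcpListener>`
    /// impl on [`UnixOrTcpListener`], which expects a non-blocking socket):
    ///
    /// ```ignore
    /// let listener = std::net::TcpListener::bind("127.0.0.1:8080")?;
    /// listener.set_nonblocking(true)?;
    /// let server = Server::try_new(listener, service)?;
    /// ```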
    pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
    where
        L: TryInto<UnixOrTcpListener>,
    {
        Ok(Self {
            tls: None,
            proxy: false,
            listener: listener.try_into()?,
            service,
        })
    }

    #[must_use]
    pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
        Self {
            tls: None,
            proxy: false,
            listener: listener.into(),
            service,
        }
    }

    #[must_use]
    pub const fn with_proxy(mut self) -> Self {
        self.proxy = true;
        self
    }

    #[must_use]
    pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
        self.tls = Some(config);
        self
    }

    /// Run a single server
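    ///
    /// Cancelling `soft_shutdown_token` stops the listener and gracefully
    /// drains in-flight connections; cancelling `hard_shutdown_token`
    /// afterwards aborts whatever is still running.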
    pub async fn run<B>(
        self,
        soft_shutdown_token: CancellationToken,
        hard_shutdown_token: CancellationToken,
    ) where
        S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: std::error::Error + Send + Sync + 'static,
        B: http_body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        run_servers(
            std::iter::once(self),
            soft_shutdown_token,
            hard_shutdown_token,
        )
        .await;
    }
}

#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
    #[error("failed to accept connection from the underlying socket")]
    Socket {
        #[source]
        source: std::io::Error,
    },

    #[error("failed to complete the TLS handshake")]
    TlsHandshake {
        #[source]
        source: std::io::Error,
    },

    #[error("failed to complete the proxy protocol handshake")]
    ProxyHandshake {
        #[source]
        source: ProxyAcceptError,
    },

    #[error("connection handshake timed out")]
    HandshakeTimeout {
        #[source]
        source: tokio::time::error::Elapsed,
    },
}

impl AcceptError {
    fn socket(source: std::io::Error) -> Self {
        Self::Socket { source }
    }

    fn tls_handshake(source: std::io::Error) -> Self {
        Self::TlsHandshake { source }
    }

    fn proxy_handshake(source: ProxyAcceptError) -> Self {
        Self::ProxyHandshake { source }
    }

    fn handshake_timeout(source: tokio::time::error::Elapsed) -> Self {
        Self::HandshakeTimeout { source }
    }
}

/// Accept a connection and do the proxy protocol and TLS handshakes
///
/// Returns an error if the proxy protocol or TLS handshake failed.
/// On success, returns the connection, which the caller should spawn onto a
/// task to be served.
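///
/// A sketch of the intended call pattern (names hypothetical):
///
/// ```ignore
/// let conn = accept(&proxy_acceptor, &tls_acceptor, peer_addr, stream, service).await?;
/// tokio::spawn(conn);
/// ```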
#[allow(clippy::type_complexity)]
#[tracing::instrument(
    name = "accept",
    skip_all,
    fields(
        network.protocol.name = "http",
        network.peer.address,
        network.peer.port,
    ),
    err,
)]
async fn accept<S, B>(
    maybe_proxy_acceptor: &MaybeProxyAcceptor,
    maybe_tls_acceptor: &MaybeTlsAcceptor,
    peer_addr: SocketAddr,
    stream: UnixOrTcpConnection,
    service: S,
) -> Result<
    Connection<
        'static,
        TokioIo<MaybeTlsStream<Rewind<UnixOrTcpConnection>>>,
        TowerToHyperService<AddExtension<S, ConnectionInfo>>,
        TokioExecutor,
    >,
    AcceptError,
>
where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let span = tracing::Span::current();

    match peer_addr {
        SocketAddr::Net(addr) => {
            span.record("network.peer.address", tracing::field::display(addr.ip()));
            span.record("network.peer.port", addr.port());
        }
        SocketAddr::Unix(ref addr) => {
            span.record("network.peer.address", tracing::field::debug(addr));
        }
    }

    // Wrap the connection acceptance logic in a timeout
    tokio::time::timeout(HANDSHAKE_TIMEOUT, async move {
        let (proxy, stream) = maybe_proxy_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::proxy_handshake)?;

        let stream = maybe_tls_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::tls_handshake)?;

        let tls = stream.tls_info();

        // Figure out if it's HTTP/2 based on the negotiated ALPN info
        let is_h2 = tls.as_ref().is_some_and(TlsStreamInfo::is_alpn_h2);

        let info = ConnectionInfo {
            tls,
            proxy,
            net_peer_addr: peer_addr.into_net(),
        };

        let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        if is_h2 {
            builder = builder.http2_only();
        }
        builder.http1().keep_alive(true);

        let service = TowerToHyperService::new(AddExtension::new(service, info));

        let conn = builder
            .serve_connection(TokioIo::new(stream), service)
            .into_owned();

        Ok(conn)
    })
    .instrument(span)
    .await
    .map_err(AcceptError::handshake_timeout)?
}

pin_project! {
    /// A wrapper around a connection that can be aborted when a shutdown signal is received.
    ///
    /// This works by pairing each connection with a future waiting on a
    /// [`CancellationToken`]. Before polling the underlying connection, we poll that
    /// future, and the first time it completes we ask the connection to start a
    /// graceful shutdown.
    ///
    /// Polling the cancellation future also registers the task for wakeup, so the
    /// connection gets polled again when the shutdown signal arrives and the graceful
    /// shutdown actually begins.
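    ///
    /// A sketch of the intended usage, mirroring `run_servers` below:
    ///
    /// ```ignore
    /// let conn = AbortableConnection::new(connection, shutdown_token.child_token());
    /// join_set.spawn(conn);
    /// ```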
    struct AbortableConnection<C> {
        #[pin]
        connection: C,
        #[pin]
        cancellation_future: WaitForCancellationFutureOwned,
        did_start_shutdown: bool,
    }
}

impl<C> AbortableConnection<C> {
    fn new(connection: C, cancellation_token: CancellationToken) -> Self {
        Self {
            connection,
            cancellation_future: cancellation_token.cancelled_owned(),
            did_start_shutdown: false,
        }
    }
}

impl<T, S, B> Future
    for AbortableConnection<Connection<'static, T, TowerToHyperService<S>, TokioExecutor>>
where
    Connection<'static, T, TowerToHyperService<S>, TokioExecutor>: Future,
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync,
    T: hyper::rt::Read + hyper::rt::Write + Unpin,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    type Output = <Connection<'static, T, TowerToHyperService<S>, TokioExecutor> as Future>::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
            if !*this.did_start_shutdown {
                *this.did_start_shutdown = true;
                this.connection.as_mut().graceful_shutdown();
            }
        }

        this.connection.poll(cx)
    }
}

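/// Run a set of servers, sharing the shutdown signals
///
/// A minimal sketch of driving several listeners at once; `server_a` and
/// `server_b` are hypothetical [`Server`] values, and the Ctrl+C wiring is an
/// assumption rather than part of this crate:
///
/// ```ignore
/// let soft = CancellationToken::new();
/// let hard = CancellationToken::new();
///
/// // First Ctrl+C starts a graceful shutdown, a second one forces it
/// tokio::spawn({
///     let (soft, hard) = (soft.clone(), hard.clone());
///     async move {
///         tokio::signal::ctrl_c().await.unwrap();
///         soft.cancel();
///         tokio::signal::ctrl_c().await.unwrap();
///         hard.cancel();
///     }
/// });
///
/// run_servers([server_a, server_b], soft, hard).await;
/// ```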
#[allow(clippy::too_many_lines)]
pub async fn run_servers<S, B>(
    listeners: impl IntoIterator<Item = Server<S>>,
    soft_shutdown_token: CancellationToken,
    hard_shutdown_token: CancellationToken,
) where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    // This guard on the shutdown token is to ensure that if this task crashes for
    // any reason, the server will shut down
    let _guard = soft_shutdown_token.clone().drop_guard();

    // Create a stream of accepted connections out of the listeners
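    // Each listener becomes an endless stream of accepted sockets, paired with
    // the acceptors and a clone of the service; `SelectAll` multiplexes all of
    // those streams into a single one polled by the loop below.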
    let mut accept_stream: SelectAll<_> = listeners
        .into_iter()
        .map(|server| {
            let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
            let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
            futures_util::stream::poll_fn(move |cx| {
                let res =
                    std::task::ready!(server.listener.poll_accept(cx)).map(|(addr, stream)| {
                        (
                            maybe_proxy_acceptor,
                            maybe_tls_acceptor.clone(),
                            server.service.clone(),
                            addr,
                            stream,
                        )
                    });
                Poll::Ready(Some(res))
            })
        })
        .collect();

    // A JoinSet which collects connections that are being accepted
    let mut accept_tasks = tokio::task::JoinSet::new();
    // A JoinSet which collects connections that are being served
    let mut connection_tasks = tokio::task::JoinSet::new();

    loop {
        tokio::select! {
            biased;

            // First look for the shutdown signal
            () = soft_shutdown_token.cancelled() => {
                tracing::debug!("Shutting down listeners");
                break;
            },

            // Poll on the JoinSet to collect connections to serve
            res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
                match res {
                    Some(Ok(Ok(connection))) => {
                        tracing::trace!("Accepted connection");
                        let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
                        connection_tasks.spawn(conn);
                    },
                    Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            // Poll on the JoinSet to collect finished connections
            res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
                match res {
                    Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
                    Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            // Look for connections to accept
            res = accept_stream.next() => {
                let Some(res) = res else { continue };

                // Spawn the connection in the set, so we don't have to wait for the handshake to
                // accept the next connection. This also lets us keep track of active connections
                // and wait on them during a graceful shutdown
                accept_tasks.spawn(async move {
                    let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res
                        .map_err(AcceptError::socket)?;
                    accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
                });
            },
        };
    }

    // Wait for connections to clean up
    if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
        tracing::info!(
            "There are {active} active connections ({pending} pending), performing a graceful shutdown. Send the shutdown signal again to force.",
            active = connection_tasks.len(),
            pending = accept_tasks.len(),
        );

        while !accept_tasks.is_empty() || !connection_tasks.is_empty() {
            tokio::select! {
                biased;

                // Poll on the JoinSet to collect connections to serve
                res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
                    match res {
                        Some(Ok(Ok(connection))) => {
                            tracing::trace!("Accepted connection");
                            let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
                            connection_tasks.spawn(conn);
                        }
                        Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                // Poll on the JoinSet to collect finished connections
                res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
                    match res {
                        Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
                        Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                // Handle when we are asked to hard shutdown
                () = hard_shutdown_token.cancelled() => {
                    tracing::warn!(
                        "Forcing shutdown ({active} active connections, {pending} pending connections)",
                        active = connection_tasks.len(),
                        pending = accept_tasks.len(),
                    );
                    break;
                },
            }
        }
    }

    accept_tasks.shutdown().await;
    connection_tasks.shutdown().await;
}