1use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11 time::Duration,
12};
13
14use futures_util::{StreamExt, stream::SelectAll};
15use hyper::{Request, Response};
16use hyper_util::{
17 rt::{TokioExecutor, TokioIo},
18 server::conn::auto::Connection,
19 service::TowerToHyperService,
20};
21use pin_project_lite::pin_project;
22use thiserror::Error;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
25use tower::Service;
26use tower_http::add_extension::AddExtension;
27use tracing::Instrument;
28
29use crate::{
30 ConnectionInfo,
31 maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo},
32 proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
33 rewind::Rewind,
34 unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
35};
36
/// Upper bound on how long the handshake phase in `accept` (proxy protocol
/// followed by TLS negotiation) may take before it fails with a timeout.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);
39
/// A listening socket bundled with the service that handles the connections
/// accepted on it, plus the per-listener handshake options.
pub struct Server<S> {
    /// TLS configuration handed to the TLS acceptor; `None` until `with_tls`
    /// is called (presumably meaning plain-text connections — confirm against
    /// `MaybeTlsAcceptor`).
    tls: Option<Arc<ServerConfig>>,
    /// Flag handed to the proxy-protocol acceptor; `false` until `with_proxy`
    /// is called.
    proxy: bool,
    /// The Unix or TCP listener connections are accepted from.
    listener: UnixOrTcpListener,
    /// The service cloned for every accepted connection.
    service: S,
}
46
47impl<S> Server<S> {
48 pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
52 where
53 L: TryInto<UnixOrTcpListener>,
54 {
55 Ok(Self {
56 tls: None,
57 proxy: false,
58 listener: listener.try_into()?,
59 service,
60 })
61 }
62
63 #[must_use]
64 pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
65 Self {
66 tls: None,
67 proxy: false,
68 listener: listener.into(),
69 service,
70 }
71 }
72
73 #[must_use]
74 pub const fn with_proxy(mut self) -> Self {
75 self.proxy = true;
76 self
77 }
78
79 #[must_use]
80 pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
81 self.tls = Some(config);
82 self
83 }
84
85 pub async fn run<B>(
87 self,
88 soft_shutdown_token: CancellationToken,
89 hard_shutdown_token: CancellationToken,
90 ) where
91 S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
92 S::Future: Send + 'static,
93 S::Error: std::error::Error + Send + Sync + 'static,
94 B: http_body::Body + Send + 'static,
95 B::Data: Send,
96 B::Error: std::error::Error + Send + Sync + 'static,
97 {
98 run_servers(
99 std::iter::once(self),
100 soft_shutdown_token,
101 hard_shutdown_token,
102 )
103 .await;
104 }
105}
106
/// Reasons why an incoming connection could not be turned into a servable
/// HTTP connection.
#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
    /// The listener itself failed to hand out a connection.
    #[error("failed to accept connection from the underlying socket")]
    Socket {
        #[source]
        source: std::io::Error,
    },

    /// The TLS acceptor returned an error.
    #[error("failed to complete the TLS handshake")]
    TlsHandshake {
        #[source]
        source: std::io::Error,
    },

    /// The proxy protocol acceptor returned an error.
    #[error("failed to complete the proxy protocol handshake")]
    ProxyHandshake {
        #[source]
        source: ProxyAcceptError,
    },

    /// The handshake did not finish within `HANDSHAKE_TIMEOUT`.
    #[error("connection handshake timed out")]
    HandshakeTimeout {
        #[source]
        source: tokio::time::error::Elapsed,
    },
}
134
135impl AcceptError {
136 fn socket(source: std::io::Error) -> Self {
137 Self::Socket { source }
138 }
139
140 fn tls_handshake(source: std::io::Error) -> Self {
141 Self::TlsHandshake { source }
142 }
143
144 fn proxy_handshake(source: ProxyAcceptError) -> Self {
145 Self::ProxyHandshake { source }
146 }
147
148 fn handshake_timeout(source: tokio::time::error::Elapsed) -> Self {
149 Self::HandshakeTimeout { source }
150 }
151}
152
153#[allow(clippy::type_complexity)]
159#[tracing::instrument(
160 name = "accept",
161 skip_all,
162 fields(
163 network.protocol.name = "http",
164 network.peer.address,
165 network.peer.port,
166 ),
167 err,
168)]
169async fn accept<S, B>(
170 maybe_proxy_acceptor: &MaybeProxyAcceptor,
171 maybe_tls_acceptor: &MaybeTlsAcceptor,
172 peer_addr: SocketAddr,
173 stream: UnixOrTcpConnection,
174 service: S,
175) -> Result<
176 Connection<
177 'static,
178 TokioIo<MaybeTlsStream<Rewind<UnixOrTcpConnection>>>,
179 TowerToHyperService<AddExtension<S, ConnectionInfo>>,
180 TokioExecutor,
181 >,
182 AcceptError,
183>
184where
185 S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
186 S::Error: std::error::Error + Send + Sync + 'static,
187 S::Future: Send + 'static,
188 B: http_body::Body + Send + 'static,
189 B::Data: Send,
190 B::Error: std::error::Error + Send + Sync + 'static,
191{
192 let span = tracing::Span::current();
193
194 match peer_addr {
195 SocketAddr::Net(addr) => {
196 span.record("network.peer.address", tracing::field::display(addr.ip()));
197 span.record("network.peer.port", addr.port());
198 }
199 SocketAddr::Unix(ref addr) => {
200 span.record("network.peer.address", tracing::field::debug(addr));
201 }
202 }
203
204 tokio::time::timeout(HANDSHAKE_TIMEOUT, async move {
206 let (proxy, stream) = maybe_proxy_acceptor
207 .accept(stream)
208 .await
209 .map_err(AcceptError::proxy_handshake)?;
210
211 let stream = maybe_tls_acceptor
212 .accept(stream)
213 .await
214 .map_err(AcceptError::tls_handshake)?;
215
216 let tls = stream.tls_info();
217
218 let is_h2 = tls.as_ref().is_some_and(TlsStreamInfo::is_alpn_h2);
220
221 let info = ConnectionInfo {
222 tls,
223 proxy,
224 net_peer_addr: peer_addr.into_net(),
225 };
226
227 let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
228 if is_h2 {
229 builder = builder.http2_only();
230 }
231 builder.http1().keep_alive(true);
232
233 let service = TowerToHyperService::new(AddExtension::new(service, info));
234
235 let conn = builder
236 .serve_connection(TokioIo::new(stream), service)
237 .into_owned();
238
239 Ok(conn)
240 })
241 .instrument(span)
242 .await
243 .map_err(AcceptError::handshake_timeout)?
244}
245
pin_project! {
    /// Pairs a connection with a cancellation future so the connection can be
    /// asked to shut down gracefully once the associated token is cancelled.
    struct AbortableConnection<C> {
        // The wrapped connection, polled in place.
        #[pin]
        connection: C,
        // Resolves once the cancellation token given to `new` is cancelled.
        #[pin]
        cancellation_future: WaitForCancellationFutureOwned,
        // Ensures the graceful shutdown is requested at most once.
        did_start_shutdown: bool,
    }
}
264
265impl<C> AbortableConnection<C> {
266 fn new(connection: C, cancellation_token: CancellationToken) -> Self {
267 Self {
268 connection,
269 cancellation_future: cancellation_token.cancelled_owned(),
270 did_start_shutdown: false,
271 }
272 }
273}
274
275impl<T, S, B> Future
276 for AbortableConnection<Connection<'static, T, TowerToHyperService<S>, TokioExecutor>>
277where
278 Connection<'static, T, TowerToHyperService<S>, TokioExecutor>: Future,
279 S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
280 S::Future: Send + 'static,
281 S::Error: std::error::Error + Send + Sync,
282 T: hyper::rt::Read + hyper::rt::Write + Unpin,
283 B: http_body::Body + Send + 'static,
284 B::Data: Send,
285 B::Error: std::error::Error + Send + Sync + 'static,
286{
287 type Output = <Connection<'static, T, TowerToHyperService<S>, TokioExecutor> as Future>::Output;
288
289 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
290 let mut this = self.project();
291
292 if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
293 if !*this.did_start_shutdown {
294 *this.did_start_shutdown = true;
295 this.connection.as_mut().graceful_shutdown();
296 }
297 }
298
299 this.connection.poll(cx)
300 }
301}
302
303#[allow(clippy::too_many_lines)]
304pub async fn run_servers<S, B>(
305 listeners: impl IntoIterator<Item = Server<S>>,
306 soft_shutdown_token: CancellationToken,
307 hard_shutdown_token: CancellationToken,
308) where
309 S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
310 S::Future: Send + 'static,
311 S::Error: std::error::Error + Send + Sync + 'static,
312 B: http_body::Body + Send + 'static,
313 B::Data: Send,
314 B::Error: std::error::Error + Send + Sync + 'static,
315{
316 let _guard = soft_shutdown_token.clone().drop_guard();
319
320 let mut accept_stream: SelectAll<_> = listeners
322 .into_iter()
323 .map(|server| {
324 let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
325 let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
326 futures_util::stream::poll_fn(move |cx| {
327 let res =
328 std::task::ready!(server.listener.poll_accept(cx)).map(|(addr, stream)| {
329 (
330 maybe_proxy_acceptor,
331 maybe_tls_acceptor.clone(),
332 server.service.clone(),
333 addr,
334 stream,
335 )
336 });
337 Poll::Ready(Some(res))
338 })
339 })
340 .collect();
341
342 let mut accept_tasks = tokio::task::JoinSet::new();
344 let mut connection_tasks = tokio::task::JoinSet::new();
346
347 loop {
348 tokio::select! {
349 biased;
350
351 () = soft_shutdown_token.cancelled() => {
353 tracing::debug!("Shutting down listeners");
354 break;
355 },
356
357 res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
359 match res {
360 Some(Ok(Ok(connection))) => {
361 tracing::trace!("Accepted connection");
362 let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
363 connection_tasks.spawn(conn);
364 },
365 Some(Ok(Err(_e))) => { },
366 Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
367 None => tracing::error!("Join set was polled even though it was empty"),
368 }
369 },
370
371 res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
373 match res {
374 Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
375 Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
376 Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
377 None => tracing::error!("Join set was polled even though it was empty"),
378 }
379 },
380
381 res = accept_stream.next() => {
383 let Some(res) = res else { continue };
384
385 accept_tasks.spawn(async move {
389 let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = res
390 .map_err(AcceptError::socket)?;
391 accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await
392 });
393 },
394 };
395 }
396
397 if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
399 tracing::info!(
400 "There are {active} active connections ({pending} pending), performing a graceful shutdown. Send the shutdown signal again to force.",
401 active = connection_tasks.len(),
402 pending = accept_tasks.len(),
403 );
404
405 while !accept_tasks.is_empty() || !connection_tasks.is_empty() {
406 tokio::select! {
407 biased;
408
409 res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
411 match res {
412 Some(Ok(Ok(connection))) => {
413 tracing::trace!("Accepted connection");
414 let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token());
415 connection_tasks.spawn(conn);
416 }
417 Some(Ok(Err(_e))) => { },
418 Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
419 None => tracing::error!("Join set was polled even though it was empty"),
420 }
421 },
422
423 res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
425 match res {
426 Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
427 Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
428 Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
429 None => tracing::error!("Join set was polled even though it was empty"),
430 }
431 },
432
433 () = hard_shutdown_token.cancelled() => {
435 tracing::warn!(
436 "Forcing shutdown ({active} active connections, {pending} pending connections)",
437 active = connection_tasks.len(),
438 pending = accept_tasks.len(),
439 );
440 break;
441 },
442 }
443 }
444 }
445
446 accept_tasks.shutdown().await;
447 connection_tasks.shutdown().await;
448}