mas_listener/proxy_protocol/maybe.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use tokio::io::AsyncRead;
8
9use super::{ProxyAcceptor, ProxyProtocolV1Info, acceptor::ProxyAcceptError};
10use crate::rewind::Rewind;
11
12#[derive(Clone, Copy)]
13pub struct MaybeProxyAcceptor {
14    acceptor: Option<ProxyAcceptor>,
15}
16
17impl MaybeProxyAcceptor {
18    #[must_use]
19    pub const fn new(proxied: bool) -> Self {
20        let acceptor = if proxied {
21            Some(ProxyAcceptor::new())
22        } else {
23            None
24        };
25
26        Self { acceptor }
27    }
28
29    #[must_use]
30    pub const fn new_proxied(acceptor: ProxyAcceptor) -> Self {
31        Self {
32            acceptor: Some(acceptor),
33        }
34    }
35
36    #[must_use]
37    pub const fn new_unproxied() -> Self {
38        Self { acceptor: None }
39    }
40
41    #[must_use]
42    pub const fn is_proxied(&self) -> bool {
43        self.acceptor.is_some()
44    }
45
46    /// Accept a connection and do the proxy protocol handshake
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if the proxy protocol handshake failed
51    pub async fn accept<T>(
52        &self,
53        stream: T,
54    ) -> Result<(Option<ProxyProtocolV1Info>, Rewind<T>), ProxyAcceptError>
55    where
56        T: AsyncRead + Unpin,
57    {
58        if let Some(acceptor) = self.acceptor {
59            let (info, stream) = acceptor.accept(stream).await?;
60            Ok((Some(info), stream))
61        } else {
62            let stream = Rewind::new(stream);
63            Ok((None, stream))
64        }
65    }
66}