mas_listener/proxy_protocol/
acceptor.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use bytes::BytesMut;
8use thiserror::Error;
9use tokio::io::{AsyncRead, AsyncReadExt};
10
11use super::ProxyProtocolV1Info;
12use crate::rewind::Rewind;
13
/// Accepts streams that begin with a PROXY protocol preamble.
///
/// The struct carries no state; the private unit field only prevents it from
/// being built with a struct literal outside this module, so callers go
/// through [`ProxyAcceptor::new`] or [`Default`].
#[derive(Debug, Default, Clone, Copy)]
pub struct ProxyAcceptor {
    // Zero-sized marker field; keeps construction private to this module.
    _private: (),
}
18
19#[derive(Debug, Error)]
20#[error(transparent)]
21pub enum ProxyAcceptError {
22    Parse(#[from] super::v1::ParseError),
23    Read(#[from] std::io::Error),
24}
25
26impl ProxyAcceptor {
27    #[must_use]
28    pub const fn new() -> Self {
29        Self { _private: () }
30    }
31
32    /// Accept a proxy-protocol stream
33    ///
34    /// # Errors
35    ///
36    /// Returns an error on read error on the underlying stream, or when the
37    /// proxy protocol preamble couldn't be parsed
38    pub async fn accept<T>(
39        &self,
40        mut stream: T,
41    ) -> Result<(ProxyProtocolV1Info, Rewind<T>), ProxyAcceptError>
42    where
43        T: AsyncRead + Unpin,
44    {
45        let mut buf = BytesMut::new();
46        let info = loop {
47            stream.read_buf(&mut buf).await?;
48
49            match ProxyProtocolV1Info::parse(&mut buf) {
50                Ok(info) => break info,
51                Err(e) if e.not_enough_bytes() => {}
52                Err(e) => return Err(e.into()),
53            }
54        };
55
56        let stream = Rewind::new_buffered(stream, buf.into());
57
58        Ok((info, stream))
59    }
60}