mas_listener/proxy_protocol/
v1.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{
8    net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
9    num::ParseIntError,
10    str::Utf8Error,
11};
12
13use bytes::Buf;
14use thiserror::Error;
15
/// Connection information carried by a PROXY protocol version 1 header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocolV1Info {
    /// A proxied TCP connection (`TCP4` or `TCP6` in the header).
    Tcp {
        /// Source address announced in the header.
        source: SocketAddr,
        /// Destination address announced in the header.
        destination: SocketAddr,
    },
    /// A proxied UDP connection (`UDP4` or `UDP6` in the header).
    Udp {
        /// Source address announced in the header.
        source: SocketAddr,
        /// Destination address announced in the header.
        destination: SocketAddr,
    },
    /// The header announced `UNKNOWN`; no addresses are available.
    Unknown,
}
28
29#[derive(Error, Debug)]
30#[error("Invalid proxy protocol header")]
31pub enum ParseError {
32    #[error("Not enough bytes provided")]
33    NotEnoughBytes,
34    NoCrLf,
35    NoProxyPreamble,
36    NoProtocol,
37    InvalidProtocol,
38    NoSourceAddress,
39    NoDestinationAddress,
40    NoSourcePort,
41    NoDestinationPort,
42    TooManyFields,
43    InvalidUtf8(#[from] Utf8Error),
44    InvalidAddress(#[from] AddrParseError),
45    InvalidPort(#[from] ParseIntError),
46}
47
48impl ParseError {
49    pub const fn not_enough_bytes(&self) -> bool {
50        matches!(self, &Self::NotEnoughBytes)
51    }
52}
53
54impl ProxyProtocolV1Info {
55    pub(super) fn parse<B>(buf: &mut B) -> Result<Self, ParseError>
56    where
57        B: Buf + AsRef<[u8]>,
58    {
59        use ParseError as E;
60        // First, check if we *possibly* have enough bytes.
61        // Minimum is 15: "PROXY UNKNOWN\r\n"
62
63        if buf.remaining() < 15 {
64            return Err(E::NotEnoughBytes);
65        }
66
67        // Let's check in the first 108 bytes if we find a CRLF
68        let Some(crlf) = buf
69            .as_ref()
70            .windows(2)
71            .take(108)
72            .position(|needle| needle == [0x0D, 0x0A])
73        else {
74            // If not, it might be because we don't have enough bytes
75            return if buf.remaining() < 108 {
76                Err(E::NotEnoughBytes)
77            } else {
78                // Else it's just invalid
79                Err(E::NoCrLf)
80            };
81        };
82
83        // Trim to everything before the CRLF
84        let bytes = &buf.as_ref()[..crlf];
85
86        let mut it = bytes.splitn(6, |c| c == &b' ');
87        // Check for the preamble
88        if it.next() != Some(b"PROXY") {
89            return Err(E::NoProxyPreamble);
90        }
91
92        let result = match it.next() {
93            Some(b"TCP4") => {
94                let source_address: Ipv4Addr =
95                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
96                let destination_address: Ipv4Addr =
97                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
98                let source_port: u16 =
99                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
100                let destination_port: u16 =
101                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
102                if it.next().is_some() {
103                    return Err(E::TooManyFields);
104                }
105
106                let source = (source_address, source_port).into();
107                let destination = (destination_address, destination_port).into();
108
109                Self::Tcp {
110                    source,
111                    destination,
112                }
113            }
114            Some(b"TCP6") => {
115                let source_address: Ipv6Addr =
116                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
117                let destination_address: Ipv6Addr =
118                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
119                let source_port: u16 =
120                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
121                let destination_port: u16 =
122                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
123                if it.next().is_some() {
124                    return Err(E::TooManyFields);
125                }
126
127                let source = (source_address, source_port).into();
128                let destination = (destination_address, destination_port).into();
129
130                Self::Tcp {
131                    source,
132                    destination,
133                }
134            }
135            Some(b"UDP4") => {
136                let source_address: Ipv4Addr =
137                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
138                let destination_address: Ipv4Addr =
139                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
140                let source_port: u16 =
141                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
142                let destination_port: u16 =
143                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
144                if it.next().is_some() {
145                    return Err(E::TooManyFields);
146                }
147
148                let source = (source_address, source_port).into();
149                let destination = (destination_address, destination_port).into();
150
151                Self::Udp {
152                    source,
153                    destination,
154                }
155            }
156            Some(b"UDP6") => {
157                let source_address: Ipv6Addr =
158                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
159                let destination_address: Ipv6Addr =
160                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
161                let source_port: u16 =
162                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
163                let destination_port: u16 =
164                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
165                if it.next().is_some() {
166                    return Err(E::TooManyFields);
167                }
168
169                let source = (source_address, source_port).into();
170                let destination = (destination_address, destination_port).into();
171
172                Self::Udp {
173                    source,
174                    destination,
175                }
176            }
177            Some(b"UNKNOWN") => Self::Unknown,
178            Some(_) => return Err(E::InvalidProtocol),
179            None => return Err(E::NoProtocol),
180        };
181
182        buf.advance(crlf + 2);
183
184        Ok(result)
185    }
186
187    #[must_use]
188    pub fn is_ipv4(&self) -> bool {
189        match self {
190            Self::Udp {
191                source,
192                destination,
193            }
194            | Self::Tcp {
195                source,
196                destination,
197            } => source.is_ipv4() && destination.is_ipv4(),
198            Self::Unknown => false,
199        }
200    }
201
202    #[must_use]
203    pub fn is_ipv6(&self) -> bool {
204        match self {
205            Self::Udp {
206                source,
207                destination,
208            }
209            | Self::Tcp {
210                source,
211                destination,
212            } => source.is_ipv6() && destination.is_ipv6(),
213            Self::Unknown => false,
214        }
215    }
216
217    #[must_use]
218    pub const fn is_tcp(&self) -> bool {
219        matches!(self, Self::Tcp { .. })
220    }
221
222    #[must_use]
223    pub const fn is_udp(&self) -> bool {
224        matches!(self, Self::Udp { .. })
225    }
226
227    #[must_use]
228    pub const fn is_unknown(&self) -> bool {
229        matches!(self, Self::Unknown)
230    }
231
232    #[must_use]
233    pub const fn source(&self) -> Option<&SocketAddr> {
234        match self {
235            Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
236            Self::Unknown => None,
237        }
238    }
239
240    #[must_use]
241    pub const fn destination(&self) -> Option<&SocketAddr> {
242        match self {
243            Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
244            Self::Unknown => None,
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, asserting that parsing fails and that the buffer was
    /// not advanced, and returns the error.
    fn parse_err(input: &[u8]) -> ParseError {
        let mut buf = input;
        let err = ProxyProtocolV1Info::parse(&mut buf).unwrap_err();
        assert_eq!(buf, input, "buffer must not be advanced on error");
        err
    }

    #[test]
    fn test_parse() {
        let mut buf =
            b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_tcp());
        assert!(!info.is_udp());
        assert!(!info.is_unknown());
        assert!(info.is_ipv4());
        assert!(!info.is_ipv6());

        let mut buf =
            b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
            .as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_tcp());
        assert!(!info.is_udp());
        assert!(!info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(info.is_ipv6());

        let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(!info.is_tcp());
        assert!(!info.is_udp());
        assert!(info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(!info.is_ipv6());

        let mut buf =
            b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
            .as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(!info.is_tcp());
        assert!(!info.is_udp());
        assert!(info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(!info.is_ipv6());
    }

    #[test]
    fn test_parse_udp() {
        let mut buf = b"PROXY UDP4 1.2.3.4 5.6.7.8 1234 5678\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_udp());
        assert!(!info.is_tcp());
        assert!(info.is_ipv4());
        assert_eq!(
            info.source(),
            Some(&SocketAddr::from((Ipv4Addr::new(1, 2, 3, 4), 1234)))
        );
        assert_eq!(
            info.destination(),
            Some(&SocketAddr::from((Ipv4Addr::new(5, 6, 7, 8), 5678)))
        );

        let mut buf = b"PROXY UDP6 ::1 2001:db8::1 53 5353\r\n".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert!(buf.is_empty());
        assert!(info.is_udp());
        assert!(info.is_ipv6());
        assert_eq!(
            info.source(),
            Some(&SocketAddr::from((Ipv6Addr::LOCALHOST, 53)))
        );
    }

    #[test]
    fn test_parse_incomplete() {
        // Shorter than the minimal "PROXY UNKNOWN\r\n" header.
        assert!(parse_err(b"PROXY TCP4").not_enough_bytes());
        // Long enough, but the terminating CRLF has not arrived yet.
        assert!(parse_err(b"PROXY TCP4 1.2.3.4 5.6.7.8 1234").not_enough_bytes());
    }

    #[test]
    fn test_parse_invalid() {
        assert!(matches!(parse_err(&[b'A'; 200]), ParseError::NoCrLf));
        assert!(matches!(
            parse_err(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"),
            ParseError::NoProxyPreamble
        ));
        assert!(matches!(
            parse_err(b"PROXY\r\nhello world"),
            ParseError::NoProtocol
        ));
        assert!(matches!(
            parse_err(b"PROXY SCTP4 1.2.3.4 5.6.7.8 1 2\r\n"),
            ParseError::InvalidProtocol
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 1.2.3.4 5.6.7.8 1234\r\n"),
            ParseError::NoDestinationPort
        ));
        // An IPv6 address announced under the IPv4 family.
        assert!(matches!(
            parse_err(b"PROXY TCP4 ::1 ::1 80 443\r\n"),
            ParseError::InvalidAddress(_)
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 1.2.3.4 5.6.7.8 70000 443\r\n"),
            ParseError::InvalidPort(_)
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 \xff.2.3.4 5.6.7.8 1 2\r\n"),
            ParseError::InvalidUtf8(_)
        ));
    }
}