mas_listener/proxy_protocol/
v1.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{
8    net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
9    num::ParseIntError,
10    str::Utf8Error,
11};
12
13use bytes::Buf;
14use thiserror::Error;
15
/// Connection information carried by a PROXY protocol version 1 header.
///
/// `PartialEq`/`Eq` are derived so that parsed values can be compared
/// directly (every field is a plain [`SocketAddr`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocolV1Info {
    /// The header announced a TCP connection (`TCP4` or `TCP6` keyword).
    Tcp {
        /// Address and port announced as the source of the connection.
        source: SocketAddr,
        /// Address and port announced as the destination of the connection.
        destination: SocketAddr,
    },
    /// The header announced a UDP connection (`UDP4` or `UDP6` keyword).
    Udp {
        /// Address and port announced as the source of the connection.
        source: SocketAddr,
        /// Address and port announced as the destination of the connection.
        destination: SocketAddr,
    },
    /// The header used the `UNKNOWN` keyword; no addresses are available.
    Unknown,
}
28
29#[derive(Error, Debug)]
30#[error("Invalid proxy protocol header")]
31pub enum ParseError {
32    #[error("Not enough bytes provided")]
33    NotEnoughBytes,
34    NoCrLf,
35    NoProxyPreamble,
36    NoProtocol,
37    InvalidProtocol,
38    NoSourceAddress,
39    NoDestinationAddress,
40    NoSourcePort,
41    NoDestinationPort,
42    TooManyFields,
43    InvalidUtf8(#[from] Utf8Error),
44    InvalidAddress(#[from] AddrParseError),
45    InvalidPort(#[from] ParseIntError),
46}
47
48impl ParseError {
49    pub const fn not_enough_bytes(&self) -> bool {
50        matches!(self, &Self::NotEnoughBytes)
51    }
52}
53
54impl ProxyProtocolV1Info {
55    #[allow(clippy::too_many_lines)]
56    pub(super) fn parse<B>(buf: &mut B) -> Result<Self, ParseError>
57    where
58        B: Buf + AsRef<[u8]>,
59    {
60        use ParseError as E;
61        // First, check if we *possibly* have enough bytes.
62        // Minimum is 15: "PROXY UNKNOWN\r\n"
63
64        if buf.remaining() < 15 {
65            return Err(E::NotEnoughBytes);
66        }
67
68        // Let's check in the first 108 bytes if we find a CRLF
69        let Some(crlf) = buf
70            .as_ref()
71            .windows(2)
72            .take(108)
73            .position(|needle| needle == [0x0D, 0x0A])
74        else {
75            // If not, it might be because we don't have enough bytes
76            return if buf.remaining() < 108 {
77                Err(E::NotEnoughBytes)
78            } else {
79                // Else it's just invalid
80                Err(E::NoCrLf)
81            };
82        };
83
84        // Trim to everything before the CRLF
85        let bytes = &buf.as_ref()[..crlf];
86
87        let mut it = bytes.splitn(6, |c| c == &b' ');
88        // Check for the preamble
89        if it.next() != Some(b"PROXY") {
90            return Err(E::NoProxyPreamble);
91        }
92
93        let result = match it.next() {
94            Some(b"TCP4") => {
95                let source_address: Ipv4Addr =
96                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
97                let destination_address: Ipv4Addr =
98                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
99                let source_port: u16 =
100                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
101                let destination_port: u16 =
102                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
103                if it.next().is_some() {
104                    return Err(E::TooManyFields);
105                }
106
107                let source = (source_address, source_port).into();
108                let destination = (destination_address, destination_port).into();
109
110                Self::Tcp {
111                    source,
112                    destination,
113                }
114            }
115            Some(b"TCP6") => {
116                let source_address: Ipv6Addr =
117                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
118                let destination_address: Ipv6Addr =
119                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
120                let source_port: u16 =
121                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
122                let destination_port: u16 =
123                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
124                if it.next().is_some() {
125                    return Err(E::TooManyFields);
126                }
127
128                let source = (source_address, source_port).into();
129                let destination = (destination_address, destination_port).into();
130
131                Self::Tcp {
132                    source,
133                    destination,
134                }
135            }
136            Some(b"UDP4") => {
137                let source_address: Ipv4Addr =
138                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
139                let destination_address: Ipv4Addr =
140                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
141                let source_port: u16 =
142                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
143                let destination_port: u16 =
144                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
145                if it.next().is_some() {
146                    return Err(E::TooManyFields);
147                }
148
149                let source = (source_address, source_port).into();
150                let destination = (destination_address, destination_port).into();
151
152                Self::Udp {
153                    source,
154                    destination,
155                }
156            }
157            Some(b"UDP6") => {
158                let source_address: Ipv6Addr =
159                    std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
160                let destination_address: Ipv6Addr =
161                    std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
162                let source_port: u16 =
163                    std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
164                let destination_port: u16 =
165                    std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
166                if it.next().is_some() {
167                    return Err(E::TooManyFields);
168                }
169
170                let source = (source_address, source_port).into();
171                let destination = (destination_address, destination_port).into();
172
173                Self::Udp {
174                    source,
175                    destination,
176                }
177            }
178            Some(b"UNKNOWN") => Self::Unknown,
179            Some(_) => return Err(E::InvalidProtocol),
180            None => return Err(E::NoProtocol),
181        };
182
183        buf.advance(crlf + 2);
184
185        Ok(result)
186    }
187
188    #[must_use]
189    pub fn is_ipv4(&self) -> bool {
190        match self {
191            Self::Udp {
192                source,
193                destination,
194            }
195            | Self::Tcp {
196                source,
197                destination,
198            } => source.is_ipv4() && destination.is_ipv4(),
199            Self::Unknown => false,
200        }
201    }
202
203    #[must_use]
204    pub fn is_ipv6(&self) -> bool {
205        match self {
206            Self::Udp {
207                source,
208                destination,
209            }
210            | Self::Tcp {
211                source,
212                destination,
213            } => source.is_ipv6() && destination.is_ipv6(),
214            Self::Unknown => false,
215        }
216    }
217
218    #[must_use]
219    pub const fn is_tcp(&self) -> bool {
220        matches!(self, Self::Tcp { .. })
221    }
222
223    #[must_use]
224    pub const fn is_udp(&self) -> bool {
225        matches!(self, Self::Udp { .. })
226    }
227
228    #[must_use]
229    pub const fn is_unknown(&self) -> bool {
230        matches!(self, Self::Unknown)
231    }
232
233    #[must_use]
234    pub const fn source(&self) -> Option<&SocketAddr> {
235        match self {
236            Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
237            Self::Unknown => None,
238        }
239    }
240
241    #[must_use]
242    pub const fn destination(&self) -> Option<&SocketAddr> {
243        match self {
244            Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
245            Self::Unknown => None,
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a handful of valid headers, each followed by the payload
    /// `hello world`, and checks both the reported connection kind and that
    /// only the header was consumed from the buffer.
    #[test]
    fn test_parse() {
        // (input, is_tcp, is_unknown, is_ipv4, is_ipv6)
        let cases = [
            (
                b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world".as_slice(),
                true,
                false,
                true,
                false,
            ),
            (
                b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
                    .as_slice(),
                true,
                false,
                false,
                true,
            ),
            (
                b"PROXY UNKNOWN\r\nhello world".as_slice(),
                false,
                true,
                false,
                false,
            ),
            (
                // Anything after `UNKNOWN` is ignored.
                b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
                    .as_slice(),
                false,
                true,
                false,
                false,
            ),
        ];

        for (mut buf, tcp, unknown, ipv4, ipv6) in cases {
            let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();

            // The payload following the header must be left in the buffer.
            assert_eq!(buf, b"hello world");

            assert_eq!(info.is_tcp(), tcp);
            assert!(!info.is_udp());
            assert_eq!(info.is_unknown(), unknown);
            assert_eq!(info.is_ipv4(), ipv4);
            assert_eq!(info.is_ipv6(), ipv6);
        }
    }
}