mas_listener/proxy_protocol/
v1.rs1use std::{
8 net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
9 num::ParseIntError,
10 str::Utf8Error,
11};
12
13use bytes::Buf;
14use thiserror::Error;
15
/// Connection details carried by a PROXY protocol v1 (human-readable) header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocolV1Info {
    /// A proxied TCP connection (`TCP4` or `TCP6` in the header).
    Tcp {
        /// Source address and port, as reported in the header.
        source: SocketAddr,
        /// Destination address and port, as reported in the header.
        destination: SocketAddr,
    },
    /// A proxied UDP connection (`UDP4` or `UDP6` in the header).
    Udp {
        /// Source address and port, as reported in the header.
        source: SocketAddr,
        /// Destination address and port, as reported in the header.
        destination: SocketAddr,
    },
    /// `PROXY UNKNOWN`: no usable address information was provided.
    Unknown,
}
28
29#[derive(Error, Debug)]
30#[error("Invalid proxy protocol header")]
31pub enum ParseError {
32 #[error("Not enough bytes provided")]
33 NotEnoughBytes,
34 NoCrLf,
35 NoProxyPreamble,
36 NoProtocol,
37 InvalidProtocol,
38 NoSourceAddress,
39 NoDestinationAddress,
40 NoSourcePort,
41 NoDestinationPort,
42 TooManyFields,
43 InvalidUtf8(#[from] Utf8Error),
44 InvalidAddress(#[from] AddrParseError),
45 InvalidPort(#[from] ParseIntError),
46}
47
48impl ParseError {
49 pub const fn not_enough_bytes(&self) -> bool {
50 matches!(self, &Self::NotEnoughBytes)
51 }
52}
53
54impl ProxyProtocolV1Info {
55 #[allow(clippy::too_many_lines)]
56 pub(super) fn parse<B>(buf: &mut B) -> Result<Self, ParseError>
57 where
58 B: Buf + AsRef<[u8]>,
59 {
60 use ParseError as E;
61 if buf.remaining() < 15 {
65 return Err(E::NotEnoughBytes);
66 }
67
68 let Some(crlf) = buf
70 .as_ref()
71 .windows(2)
72 .take(108)
73 .position(|needle| needle == [0x0D, 0x0A])
74 else {
75 return if buf.remaining() < 108 {
77 Err(E::NotEnoughBytes)
78 } else {
79 Err(E::NoCrLf)
81 };
82 };
83
84 let bytes = &buf.as_ref()[..crlf];
86
87 let mut it = bytes.splitn(6, |c| c == &b' ');
88 if it.next() != Some(b"PROXY") {
90 return Err(E::NoProxyPreamble);
91 }
92
93 let result = match it.next() {
94 Some(b"TCP4") => {
95 let source_address: Ipv4Addr =
96 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
97 let destination_address: Ipv4Addr =
98 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
99 let source_port: u16 =
100 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
101 let destination_port: u16 =
102 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
103 if it.next().is_some() {
104 return Err(E::TooManyFields);
105 }
106
107 let source = (source_address, source_port).into();
108 let destination = (destination_address, destination_port).into();
109
110 Self::Tcp {
111 source,
112 destination,
113 }
114 }
115 Some(b"TCP6") => {
116 let source_address: Ipv6Addr =
117 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
118 let destination_address: Ipv6Addr =
119 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
120 let source_port: u16 =
121 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
122 let destination_port: u16 =
123 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
124 if it.next().is_some() {
125 return Err(E::TooManyFields);
126 }
127
128 let source = (source_address, source_port).into();
129 let destination = (destination_address, destination_port).into();
130
131 Self::Tcp {
132 source,
133 destination,
134 }
135 }
136 Some(b"UDP4") => {
137 let source_address: Ipv4Addr =
138 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
139 let destination_address: Ipv4Addr =
140 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
141 let source_port: u16 =
142 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
143 let destination_port: u16 =
144 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
145 if it.next().is_some() {
146 return Err(E::TooManyFields);
147 }
148
149 let source = (source_address, source_port).into();
150 let destination = (destination_address, destination_port).into();
151
152 Self::Udp {
153 source,
154 destination,
155 }
156 }
157 Some(b"UDP6") => {
158 let source_address: Ipv6Addr =
159 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
160 let destination_address: Ipv6Addr =
161 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
162 let source_port: u16 =
163 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
164 let destination_port: u16 =
165 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
166 if it.next().is_some() {
167 return Err(E::TooManyFields);
168 }
169
170 let source = (source_address, source_port).into();
171 let destination = (destination_address, destination_port).into();
172
173 Self::Udp {
174 source,
175 destination,
176 }
177 }
178 Some(b"UNKNOWN") => Self::Unknown,
179 Some(_) => return Err(E::InvalidProtocol),
180 None => return Err(E::NoProtocol),
181 };
182
183 buf.advance(crlf + 2);
184
185 Ok(result)
186 }
187
188 #[must_use]
189 pub fn is_ipv4(&self) -> bool {
190 match self {
191 Self::Udp {
192 source,
193 destination,
194 }
195 | Self::Tcp {
196 source,
197 destination,
198 } => source.is_ipv4() && destination.is_ipv4(),
199 Self::Unknown => false,
200 }
201 }
202
203 #[must_use]
204 pub fn is_ipv6(&self) -> bool {
205 match self {
206 Self::Udp {
207 source,
208 destination,
209 }
210 | Self::Tcp {
211 source,
212 destination,
213 } => source.is_ipv6() && destination.is_ipv6(),
214 Self::Unknown => false,
215 }
216 }
217
218 #[must_use]
219 pub const fn is_tcp(&self) -> bool {
220 matches!(self, Self::Tcp { .. })
221 }
222
223 #[must_use]
224 pub const fn is_udp(&self) -> bool {
225 matches!(self, Self::Udp { .. })
226 }
227
228 #[must_use]
229 pub const fn is_unknown(&self) -> bool {
230 matches!(self, Self::Unknown)
231 }
232
233 #[must_use]
234 pub const fn source(&self) -> Option<&SocketAddr> {
235 match self {
236 Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
237 Self::Unknown => None,
238 }
239 }
240
241 #[must_use]
242 pub const fn destination(&self) -> Option<&SocketAddr> {
243 match self {
244 Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
245 Self::Unknown => None,
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` expecting failure, and checks that the buffer was not advanced.
    fn parse_err(input: &[u8]) -> ParseError {
        let mut buf = input;
        let err = ProxyProtocolV1Info::parse(&mut buf).unwrap_err();
        assert_eq!(buf, input, "the buffer must not be advanced on error");
        err
    }

    /// Shorthand for building an expected socket address in assertions.
    fn addr(s: &str) -> std::net::SocketAddr {
        s.parse().unwrap()
    }

    #[test]
    fn test_parse() {
        let mut buf =
            b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_tcp());
        assert!(!info.is_udp());
        assert!(!info.is_unknown());
        assert!(info.is_ipv4());
        assert!(!info.is_ipv6());

        let mut buf =
            b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
            .as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_tcp());
        assert!(!info.is_udp());
        assert!(!info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(info.is_ipv6());

        let mut buf = b"PROXY UNKNOWN\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(!info.is_tcp());
        assert!(!info.is_udp());
        assert!(info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(!info.is_ipv6());

        let mut buf =
            b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world"
            .as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(!info.is_tcp());
        assert!(!info.is_udp());
        assert!(info.is_unknown());
        assert!(!info.is_ipv4());
        assert!(!info.is_ipv6());
    }

    #[test]
    fn test_parse_udp() {
        let mut buf = b"PROXY UDP4 192.0.2.1 198.51.100.2 5353 53\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_udp());
        assert!(!info.is_tcp());
        assert!(!info.is_unknown());
        assert!(info.is_ipv4());
        assert!(!info.is_ipv6());
        assert_eq!(info.source(), Some(&addr("192.0.2.1:5353")));
        assert_eq!(info.destination(), Some(&addr("198.51.100.2:53")));

        let mut buf = b"PROXY UDP6 2001:db8::1 2001:db8::2 5353 53\r\nhello world".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert_eq!(buf, b"hello world");
        assert!(info.is_udp());
        assert!(!info.is_tcp());
        assert!(!info.is_ipv4());
        assert!(info.is_ipv6());
        assert_eq!(info.source(), Some(&addr("[2001:db8::1]:5353")));
        assert_eq!(info.destination(), Some(&addr("[2001:db8::2]:53")));
    }

    #[test]
    fn test_addresses() {
        let mut buf = b"PROXY TCP4 192.0.2.1 198.51.100.2 56324 443\r\n".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert!(buf.is_empty());
        assert_eq!(info.source(), Some(&addr("192.0.2.1:56324")));
        assert_eq!(info.destination(), Some(&addr("198.51.100.2:443")));

        let mut buf = b"PROXY UNKNOWN\r\n".as_slice();
        let info = ProxyProtocolV1Info::parse(&mut buf).unwrap();
        assert!(buf.is_empty());
        assert_eq!(info.source(), None);
        assert_eq!(info.destination(), None);
    }

    #[test]
    fn test_parse_errors() {
        // Too short to decide: either below the minimum size or the CRLF has not arrived.
        assert!(parse_err(b"PROXY TCP4").not_enough_bytes());
        assert!(parse_err(b"PROXY TCP4 192.0.2.1 198.51").not_enough_bytes());

        assert!(matches!(parse_err(&[b'a'; 200]), ParseError::NoCrLf));
        assert!(matches!(
            parse_err(b"HELLO TCP4 192.0.2.1 198.51.100.2 1 2\r\n"),
            ParseError::NoProxyPreamble
        ));
        assert!(matches!(
            parse_err(b"PROXY\r\nsome more data"),
            ParseError::NoProtocol
        ));
        assert!(matches!(
            parse_err(b"PROXY SCTP 192.0.2.1 198.51.100.2 1 2\r\n"),
            ParseError::InvalidProtocol
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 192.0.2.1\r\n"),
            ParseError::NoDestinationAddress
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 192.0.2.1 198.51.100.2 1\r\n"),
            ParseError::NoDestinationPort
        ));
        // An IPv6 address is not accepted for a TCP4 connection.
        assert!(matches!(
            parse_err(b"PROXY TCP4 2001:db8::1 198.51.100.2 1 2\r\n"),
            ParseError::InvalidAddress(_)
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 192.0.2.1 198.51.100.2 65536 2\r\n"),
            ParseError::InvalidPort(_)
        ));
        assert!(matches!(
            parse_err(b"PROXY TCP4 \xff 198.51.100.2 1 2\r\n"),
            ParseError::InvalidUtf8(_)
        ));
    }
}