mas_listener/proxy_protocol/
v1.rs1use std::{
8 net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
9 num::ParseIntError,
10 str::Utf8Error,
11};
12
13use bytes::Buf;
14use thiserror::Error;
15
/// Connection information decoded from a PROXY protocol v1 header line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocolV1Info {
    /// Produced for a `TCP4` or `TCP6` header.
    Tcp {
        /// Source address and port carried in the header.
        source: SocketAddr,
        /// Destination address and port carried in the header.
        destination: SocketAddr,
    },
    /// Produced for a `UDP4` or `UDP6` header.
    Udp {
        /// Source address and port carried in the header.
        source: SocketAddr,
        /// Destination address and port carried in the header.
        destination: SocketAddr,
    },
    /// Produced for an `UNKNOWN` header; carries no address information.
    Unknown,
}
28
29#[derive(Error, Debug)]
30#[error("Invalid proxy protocol header")]
31pub enum ParseError {
32 #[error("Not enough bytes provided")]
33 NotEnoughBytes,
34 NoCrLf,
35 NoProxyPreamble,
36 NoProtocol,
37 InvalidProtocol,
38 NoSourceAddress,
39 NoDestinationAddress,
40 NoSourcePort,
41 NoDestinationPort,
42 TooManyFields,
43 InvalidUtf8(#[from] Utf8Error),
44 InvalidAddress(#[from] AddrParseError),
45 InvalidPort(#[from] ParseIntError),
46}
47
48impl ParseError {
49 pub const fn not_enough_bytes(&self) -> bool {
50 matches!(self, &Self::NotEnoughBytes)
51 }
52}
53
54impl ProxyProtocolV1Info {
55 pub(super) fn parse<B>(buf: &mut B) -> Result<Self, ParseError>
56 where
57 B: Buf + AsRef<[u8]>,
58 {
59 use ParseError as E;
60 if buf.remaining() < 15 {
64 return Err(E::NotEnoughBytes);
65 }
66
67 let Some(crlf) = buf
69 .as_ref()
70 .windows(2)
71 .take(108)
72 .position(|needle| needle == [0x0D, 0x0A])
73 else {
74 return if buf.remaining() < 108 {
76 Err(E::NotEnoughBytes)
77 } else {
78 Err(E::NoCrLf)
80 };
81 };
82
83 let bytes = &buf.as_ref()[..crlf];
85
86 let mut it = bytes.splitn(6, |c| c == &b' ');
87 if it.next() != Some(b"PROXY") {
89 return Err(E::NoProxyPreamble);
90 }
91
92 let result = match it.next() {
93 Some(b"TCP4") => {
94 let source_address: Ipv4Addr =
95 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
96 let destination_address: Ipv4Addr =
97 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
98 let source_port: u16 =
99 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
100 let destination_port: u16 =
101 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
102 if it.next().is_some() {
103 return Err(E::TooManyFields);
104 }
105
106 let source = (source_address, source_port).into();
107 let destination = (destination_address, destination_port).into();
108
109 Self::Tcp {
110 source,
111 destination,
112 }
113 }
114 Some(b"TCP6") => {
115 let source_address: Ipv6Addr =
116 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
117 let destination_address: Ipv6Addr =
118 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
119 let source_port: u16 =
120 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
121 let destination_port: u16 =
122 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
123 if it.next().is_some() {
124 return Err(E::TooManyFields);
125 }
126
127 let source = (source_address, source_port).into();
128 let destination = (destination_address, destination_port).into();
129
130 Self::Tcp {
131 source,
132 destination,
133 }
134 }
135 Some(b"UDP4") => {
136 let source_address: Ipv4Addr =
137 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
138 let destination_address: Ipv4Addr =
139 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
140 let source_port: u16 =
141 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
142 let destination_port: u16 =
143 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
144 if it.next().is_some() {
145 return Err(E::TooManyFields);
146 }
147
148 let source = (source_address, source_port).into();
149 let destination = (destination_address, destination_port).into();
150
151 Self::Udp {
152 source,
153 destination,
154 }
155 }
156 Some(b"UDP6") => {
157 let source_address: Ipv6Addr =
158 std::str::from_utf8(it.next().ok_or(E::NoSourceAddress)?)?.parse()?;
159 let destination_address: Ipv6Addr =
160 std::str::from_utf8(it.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
161 let source_port: u16 =
162 std::str::from_utf8(it.next().ok_or(E::NoSourcePort)?)?.parse()?;
163 let destination_port: u16 =
164 std::str::from_utf8(it.next().ok_or(E::NoDestinationPort)?)?.parse()?;
165 if it.next().is_some() {
166 return Err(E::TooManyFields);
167 }
168
169 let source = (source_address, source_port).into();
170 let destination = (destination_address, destination_port).into();
171
172 Self::Udp {
173 source,
174 destination,
175 }
176 }
177 Some(b"UNKNOWN") => Self::Unknown,
178 Some(_) => return Err(E::InvalidProtocol),
179 None => return Err(E::NoProtocol),
180 };
181
182 buf.advance(crlf + 2);
183
184 Ok(result)
185 }
186
187 #[must_use]
188 pub fn is_ipv4(&self) -> bool {
189 match self {
190 Self::Udp {
191 source,
192 destination,
193 }
194 | Self::Tcp {
195 source,
196 destination,
197 } => source.is_ipv4() && destination.is_ipv4(),
198 Self::Unknown => false,
199 }
200 }
201
202 #[must_use]
203 pub fn is_ipv6(&self) -> bool {
204 match self {
205 Self::Udp {
206 source,
207 destination,
208 }
209 | Self::Tcp {
210 source,
211 destination,
212 } => source.is_ipv6() && destination.is_ipv6(),
213 Self::Unknown => false,
214 }
215 }
216
217 #[must_use]
218 pub const fn is_tcp(&self) -> bool {
219 matches!(self, Self::Tcp { .. })
220 }
221
222 #[must_use]
223 pub const fn is_udp(&self) -> bool {
224 matches!(self, Self::Udp { .. })
225 }
226
227 #[must_use]
228 pub const fn is_unknown(&self) -> bool {
229 matches!(self, Self::Unknown)
230 }
231
232 #[must_use]
233 pub const fn source(&self) -> Option<&SocketAddr> {
234 match self {
235 Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
236 Self::Unknown => None,
237 }
238 }
239
240 #[must_use]
241 pub const fn destination(&self) -> Option<&SocketAddr> {
242 match self {
243 Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
244 Self::Unknown => None,
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, checks that exactly the `hello world` payload is left
    /// in the buffer afterwards, and hands back the parsed header.
    fn parse_expecting_payload(input: &[u8]) -> ProxyProtocolV1Info {
        let mut rest = input;
        let parsed = ProxyProtocolV1Info::parse(&mut rest).unwrap();
        assert_eq!(rest, b"hello world");
        parsed
    }

    /// Checks the classification predicates, listed in the order
    /// `[is_tcp, is_udp, is_unknown, is_ipv4, is_ipv6]`.
    fn assert_kind(info: &ProxyProtocolV1Info, expected: [bool; 5]) {
        let actual = [
            info.is_tcp(),
            info.is_udp(),
            info.is_unknown(),
            info.is_ipv4(),
            info.is_ipv6(),
        ];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse() {
        let tcp4 = parse_expecting_payload(
            b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
        );
        assert_kind(&tcp4, [true, false, false, true, false]);

        let tcp6 = parse_expecting_payload(
            b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world",
        );
        assert_kind(&tcp6, [true, false, false, false, true]);

        let bare_unknown = parse_expecting_payload(b"PROXY UNKNOWN\r\nhello world");
        assert_kind(&bare_unknown, [false, false, true, false, false]);

        let padded_unknown = parse_expecting_payload(
            b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world",
        );
        assert_kind(&padded_unknown, [false, false, true, false, false]);
    }
}