1use std::{
12 pin::Pin,
13 task::{Context, Poll, ready},
14};
15
16use tokio::{
17 io::{AsyncRead, AsyncWrite},
18 net::{TcpListener, TcpStream, UnixListener, UnixStream},
19};
20
/// A socket address that is either a Unix domain socket address or a
/// network (IP) socket address.
pub enum SocketAddr {
    /// Address of a Unix domain socket, as reported by tokio.
    Unix(tokio::net::unix::SocketAddr),
    /// Address of a network socket (IP address and port).
    Net(std::net::SocketAddr),
}
25
26impl From<tokio::net::unix::SocketAddr> for SocketAddr {
27 fn from(value: tokio::net::unix::SocketAddr) -> Self {
28 Self::Unix(value)
29 }
30}
31
32impl From<std::net::SocketAddr> for SocketAddr {
33 fn from(value: std::net::SocketAddr) -> Self {
34 Self::Net(value)
35 }
36}
37
38impl std::fmt::Debug for SocketAddr {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 Self::Unix(l) => std::fmt::Debug::fmt(l, f),
42 Self::Net(l) => std::fmt::Debug::fmt(l, f),
43 }
44 }
45}
46
47impl SocketAddr {
48 #[must_use]
49 pub fn into_net(self) -> Option<std::net::SocketAddr> {
50 match self {
51 Self::Net(socket) => Some(socket),
52 Self::Unix(_) => None,
53 }
54 }
55
56 #[must_use]
57 pub fn into_unix(self) -> Option<tokio::net::unix::SocketAddr> {
58 match self {
59 Self::Net(_) => None,
60 Self::Unix(socket) => Some(socket),
61 }
62 }
63
64 #[must_use]
65 pub const fn as_net(&self) -> Option<&std::net::SocketAddr> {
66 match self {
67 Self::Net(socket) => Some(socket),
68 Self::Unix(_) => None,
69 }
70 }
71
72 #[must_use]
73 pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> {
74 match self {
75 Self::Net(_) => None,
76 Self::Unix(socket) => Some(socket),
77 }
78 }
79}
80
/// A listening socket that is either a tokio Unix domain socket listener or
/// a tokio TCP listener.
pub enum UnixOrTcpListener {
    /// Listener bound to a Unix domain socket.
    Unix(UnixListener),
    /// Listener bound to a TCP socket.
    Tcp(TcpListener),
}
85
86impl From<UnixListener> for UnixOrTcpListener {
87 fn from(listener: UnixListener) -> Self {
88 Self::Unix(listener)
89 }
90}
91
92impl From<TcpListener> for UnixOrTcpListener {
93 fn from(listener: TcpListener) -> Self {
94 Self::Tcp(listener)
95 }
96}
97
98impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
99 type Error = std::io::Error;
100
101 fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
102 listener.set_nonblocking(true)?;
103 Ok(Self::Unix(UnixListener::from_std(listener)?))
104 }
105}
106
107impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
108 type Error = std::io::Error;
109
110 fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
111 listener.set_nonblocking(true)?;
112 Ok(Self::Tcp(TcpListener::from_std(listener)?))
113 }
114}
115
116impl UnixOrTcpListener {
117 pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
124 match self {
125 Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
126 Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
127 }
128 }
129
130 pub const fn is_unix(&self) -> bool {
131 matches!(self, Self::Unix(_))
132 }
133
134 pub const fn is_tcp(&self) -> bool {
135 matches!(self, Self::Tcp(_))
136 }
137
138 pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
149 match self {
150 Self::Unix(listener) => {
151 let (stream, remote_addr) = listener.accept().await?;
152
153 let socket = socket2::SockRef::from(&stream);
154 socket.set_keepalive(true)?;
155 socket.set_nodelay(true)?;
156
157 Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
158 }
159 Self::Tcp(listener) => {
160 let (stream, remote_addr) = listener.accept().await?;
161
162 let socket = socket2::SockRef::from(&stream);
163 socket.set_keepalive(true)?;
164 socket.set_nodelay(true)?;
165
166 Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
167 }
168 }
169 }
170
171 pub fn poll_accept(
182 &self,
183 cx: &mut Context<'_>,
184 ) -> Poll<Result<(SocketAddr, UnixOrTcpConnection), std::io::Error>> {
185 match self {
186 Self::Unix(listener) => {
187 let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
188
189 let socket = socket2::SockRef::from(&stream);
190 socket.set_keepalive(true)?;
191 socket.set_nodelay(true)?;
192
193 Poll::Ready(Ok((
194 remote_addr.into(),
195 UnixOrTcpConnection::Unix { stream },
196 )))
197 }
198 Self::Tcp(listener) => {
199 let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
200
201 let socket = socket2::SockRef::from(&stream);
202 socket.set_keepalive(true)?;
203 socket.set_nodelay(true)?;
204
205 Poll::Ready(Ok((
206 remote_addr.into(),
207 UnixOrTcpConnection::Tcp { stream },
208 )))
209 }
210 }
211 }
212}
213
// Declares the `UnixOrTcpConnection` enum, a connection backed by either a
// Unix domain stream or a TCP stream. The `#[project = ...]` attribute names
// the generated pin projection type `UnixOrTcpConnectionProj`, which the
// `AsyncRead`/`AsyncWrite` implementations match on to reach the inner stream.
pin_project_lite::pin_project! {
    #[project = UnixOrTcpConnectionProj]
    pub enum UnixOrTcpConnection {
        // Connection over a Unix domain socket.
        Unix {
            #[pin]
            stream: UnixStream,
        },

        // Connection over a TCP socket.
        Tcp {
            #[pin]
            stream: TcpStream,
        },
    }
}
228
229impl From<TcpStream> for UnixOrTcpConnection {
230 fn from(stream: TcpStream) -> Self {
231 Self::Tcp { stream }
232 }
233}
234
235impl UnixOrTcpConnection {
236 pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
243 match self {
244 Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
245 Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
246 }
247 }
248
249 pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
256 match self {
257 Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
258 Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
259 }
260 }
261}
262
263impl AsyncRead for UnixOrTcpConnection {
264 fn poll_read(
265 self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 buf: &mut tokio::io::ReadBuf<'_>,
268 ) -> Poll<std::io::Result<()>> {
269 match self.project() {
270 UnixOrTcpConnectionProj::Unix { stream } => stream.poll_read(cx, buf),
271 UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_read(cx, buf),
272 }
273 }
274}
275
276impl AsyncWrite for UnixOrTcpConnection {
277 fn poll_write(
278 self: Pin<&mut Self>,
279 cx: &mut Context<'_>,
280 buf: &[u8],
281 ) -> Poll<Result<usize, std::io::Error>> {
282 match self.project() {
283 UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write(cx, buf),
284 UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write(cx, buf),
285 }
286 }
287
288 fn poll_write_vectored(
289 self: Pin<&mut Self>,
290 cx: &mut Context<'_>,
291 bufs: &[std::io::IoSlice<'_>],
292 ) -> Poll<Result<usize, std::io::Error>> {
293 match self.project() {
294 UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write_vectored(cx, bufs),
295 UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write_vectored(cx, bufs),
296 }
297 }
298
299 fn is_write_vectored(&self) -> bool {
300 match self {
301 UnixOrTcpConnection::Unix { stream } => stream.is_write_vectored(),
302 UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
303 }
304 }
305
306 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
307 match self.project() {
308 UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
309 UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
310 }
311 }
312
313 fn poll_shutdown(
314 self: Pin<&mut Self>,
315 cx: &mut Context<'_>,
316 ) -> Poll<Result<(), std::io::Error>> {
317 match self.project() {
318 UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
319 UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
320 }
321 }
322}