mas_listener/
unix_or_tcp.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! A listener which can listen on either TCP sockets or on UNIX domain sockets
8
9// TODO: Unlink the UNIX socket on drop?
10
11use std::{
12    pin::Pin,
13    task::{Context, Poll, ready},
14};
15
16use tokio::{
17    io::{AsyncRead, AsyncWrite},
18    net::{TcpListener, TcpStream, UnixListener, UnixStream},
19};
20
21pub enum SocketAddr {
22    Unix(tokio::net::unix::SocketAddr),
23    Net(std::net::SocketAddr),
24}
25
26impl From<tokio::net::unix::SocketAddr> for SocketAddr {
27    fn from(value: tokio::net::unix::SocketAddr) -> Self {
28        Self::Unix(value)
29    }
30}
31
32impl From<std::net::SocketAddr> for SocketAddr {
33    fn from(value: std::net::SocketAddr) -> Self {
34        Self::Net(value)
35    }
36}
37
38impl std::fmt::Debug for SocketAddr {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Self::Unix(l) => std::fmt::Debug::fmt(l, f),
42            Self::Net(l) => std::fmt::Debug::fmt(l, f),
43        }
44    }
45}
46
47impl SocketAddr {
48    #[must_use]
49    pub fn into_net(self) -> Option<std::net::SocketAddr> {
50        match self {
51            Self::Net(socket) => Some(socket),
52            Self::Unix(_) => None,
53        }
54    }
55
56    #[must_use]
57    pub fn into_unix(self) -> Option<tokio::net::unix::SocketAddr> {
58        match self {
59            Self::Net(_) => None,
60            Self::Unix(socket) => Some(socket),
61        }
62    }
63
64    #[must_use]
65    pub const fn as_net(&self) -> Option<&std::net::SocketAddr> {
66        match self {
67            Self::Net(socket) => Some(socket),
68            Self::Unix(_) => None,
69        }
70    }
71
72    #[must_use]
73    pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> {
74        match self {
75            Self::Net(_) => None,
76            Self::Unix(socket) => Some(socket),
77        }
78    }
79}
80
81pub enum UnixOrTcpListener {
82    Unix(UnixListener),
83    Tcp(TcpListener),
84}
85
86impl From<UnixListener> for UnixOrTcpListener {
87    fn from(listener: UnixListener) -> Self {
88        Self::Unix(listener)
89    }
90}
91
92impl From<TcpListener> for UnixOrTcpListener {
93    fn from(listener: TcpListener) -> Self {
94        Self::Tcp(listener)
95    }
96}
97
98impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
99    type Error = std::io::Error;
100
101    fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
102        listener.set_nonblocking(true)?;
103        Ok(Self::Unix(UnixListener::from_std(listener)?))
104    }
105}
106
107impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
108    type Error = std::io::Error;
109
110    fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
111        listener.set_nonblocking(true)?;
112        Ok(Self::Tcp(TcpListener::from_std(listener)?))
113    }
114}
115
116impl UnixOrTcpListener {
117    /// Get the local address of the listener
118    ///
119    /// # Errors
120    ///
121    /// Returns an error on rare cases where the underlying [`TcpListener`] or
122    /// [`UnixListener`] couldn't provide the local address
123    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
124        match self {
125            Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
126            Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
127        }
128    }
129
130    pub const fn is_unix(&self) -> bool {
131        matches!(self, Self::Unix(_))
132    }
133
134    pub const fn is_tcp(&self) -> bool {
135        matches!(self, Self::Tcp(_))
136    }
137
138    /// Accept an incoming connection
139    ///
140    /// # Cancel safety
141    ///
142    /// This function is safe to cancel, as both [`UnixListener::accept`] and
143    /// [`TcpListener::accept`] are safe to cancel.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the underlying socket couldn't accept the connection
148    pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
149        match self {
150            Self::Unix(listener) => {
151                let (stream, remote_addr) = listener.accept().await?;
152
153                let socket = socket2::SockRef::from(&stream);
154                socket.set_keepalive(true)?;
155                socket.set_nodelay(true)?;
156
157                Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
158            }
159            Self::Tcp(listener) => {
160                let (stream, remote_addr) = listener.accept().await?;
161
162                let socket = socket2::SockRef::from(&stream);
163                socket.set_keepalive(true)?;
164                socket.set_nodelay(true)?;
165
166                Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
167            }
168        }
169    }
170
171    /// Poll for an incoming connection
172    ///
173    /// # Cancel safety
174    ///
175    /// This function is safe to cancel, as both [`UnixListener::poll_accept`]
176    /// and [`TcpListener::poll_accept`] are safe to cancel.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the underlying socket couldn't accept the connection
181    pub fn poll_accept(
182        &self,
183        cx: &mut Context<'_>,
184    ) -> Poll<Result<(SocketAddr, UnixOrTcpConnection), std::io::Error>> {
185        match self {
186            Self::Unix(listener) => {
187                let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
188
189                let socket = socket2::SockRef::from(&stream);
190                socket.set_keepalive(true)?;
191                socket.set_nodelay(true)?;
192
193                Poll::Ready(Ok((
194                    remote_addr.into(),
195                    UnixOrTcpConnection::Unix { stream },
196                )))
197            }
198            Self::Tcp(listener) => {
199                let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
200
201                let socket = socket2::SockRef::from(&stream);
202                socket.set_keepalive(true)?;
203                socket.set_nodelay(true)?;
204
205                Poll::Ready(Ok((
206                    remote_addr.into(),
207                    UnixOrTcpConnection::Tcp { stream },
208                )))
209            }
210        }
211    }
212}
213
214pin_project_lite::pin_project! {
215    #[project = UnixOrTcpConnectionProj]
216    pub enum UnixOrTcpConnection {
217        Unix {
218            #[pin]
219            stream: UnixStream,
220        },
221
222        Tcp {
223            #[pin]
224            stream: TcpStream,
225        },
226    }
227}
228
229impl From<TcpStream> for UnixOrTcpConnection {
230    fn from(stream: TcpStream) -> Self {
231        Self::Tcp { stream }
232    }
233}
234
235impl UnixOrTcpConnection {
236    /// Get the local address of the stream
237    ///
238    /// # Errors
239    ///
240    /// Returns an error on rare cases where the underlying [`TcpStream`] or
241    /// [`UnixStream`] couldn't provide the local address
242    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
243        match self {
244            Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
245            Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
246        }
247    }
248
249    /// Get the remote address of the stream
250    ///
251    /// # Errors
252    ///
253    /// Returns an error on rare cases where the underlying [`TcpStream`] or
254    /// [`UnixStream`] couldn't provide the remote address
255    pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
256        match self {
257            Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
258            Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
259        }
260    }
261}
262
263impl AsyncRead for UnixOrTcpConnection {
264    fn poll_read(
265        self: Pin<&mut Self>,
266        cx: &mut Context<'_>,
267        buf: &mut tokio::io::ReadBuf<'_>,
268    ) -> Poll<std::io::Result<()>> {
269        match self.project() {
270            UnixOrTcpConnectionProj::Unix { stream } => stream.poll_read(cx, buf),
271            UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_read(cx, buf),
272        }
273    }
274}
275
276impl AsyncWrite for UnixOrTcpConnection {
277    fn poll_write(
278        self: Pin<&mut Self>,
279        cx: &mut Context<'_>,
280        buf: &[u8],
281    ) -> Poll<Result<usize, std::io::Error>> {
282        match self.project() {
283            UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write(cx, buf),
284            UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write(cx, buf),
285        }
286    }
287
288    fn poll_write_vectored(
289        self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291        bufs: &[std::io::IoSlice<'_>],
292    ) -> Poll<Result<usize, std::io::Error>> {
293        match self.project() {
294            UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write_vectored(cx, bufs),
295            UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write_vectored(cx, bufs),
296        }
297    }
298
299    fn is_write_vectored(&self) -> bool {
300        match self {
301            UnixOrTcpConnection::Unix { stream } => stream.is_write_vectored(),
302            UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
303        }
304    }
305
306    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
307        match self.project() {
308            UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
309            UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
310        }
311    }
312
313    fn poll_shutdown(
314        self: Pin<&mut Self>,
315        cx: &mut Context<'_>,
316    ) -> Poll<Result<(), std::io::Error>> {
317        match self.project() {
318            UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
319            UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
320        }
321    }
322}