// mas_listener/maybe_tls.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use std::{
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio_rustls::{
15    TlsAcceptor,
16    rustls::{
17        ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
18        pki_types::CertificateDer,
19    },
20};
21
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct TlsStreamInfo {
25    pub protocol_version: ProtocolVersion,
26    pub negotiated_cipher_suite: SupportedCipherSuite,
27    pub sni_hostname: Option<String>,
28    pub alpn_protocol: Option<Vec<u8>>,
29    pub peer_certificates: Option<Vec<CertificateDer<'static>>>,
30}
31
32impl TlsStreamInfo {
33    #[must_use]
34    pub fn is_alpn_h2(&self) -> bool {
35        matches!(self.alpn_protocol.as_deref(), Some(b"h2"))
36    }
37}
38
pin_project_lite::pin_project! {
    #[project = MaybeTlsStreamProj]
    /// A server-side connection that is either wrapped in TLS or used as-is.
    ///
    /// [`AsyncRead`] and [`AsyncWrite`] are implemented by forwarding every
    /// call to whichever stream the current variant holds.
    pub enum MaybeTlsStream<T> {
        // TLS-wrapped connection; `MaybeTlsAcceptor::accept` builds this
        // variant when the acceptor has a TLS config.
        Secure {
            // Only `#[pin]` is accepted on fields by `pin_project!`, hence
            // plain comments here instead of doc comments.
            #[pin]
            stream: tokio_rustls::server::TlsStream<T>
        },
        // Plain connection with no TLS layer.
        Insecure {
            #[pin]
            stream: T,
        },
    }
}
52
53impl<T> MaybeTlsStream<T> {
54    /// Get a reference to the underlying IO stream
55    ///
56    /// Returns [`None`] if the stream closed before the TLS handshake finished.
57    /// It is guaranteed to return [`Some`] value after the handshake finished,
58    /// or if it is a non-TLS connection.
59    pub fn get_ref(&self) -> &T {
60        match self {
61            Self::Secure { stream } => stream.get_ref().0,
62            Self::Insecure { stream } => stream,
63        }
64    }
65
66    /// Get a ref to the [`ServerConnection`] of the establish TLS stream.
67    ///
68    /// Returns [`None`] for non-TLS connections.
69    pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
70        match self {
71            Self::Secure { stream } => Some(stream.get_ref().1),
72            Self::Insecure { .. } => None,
73        }
74    }
75
76    /// Gather informations about the TLS connection. Returns `None` if the
77    /// stream is not a TLS stream.
78    ///
79    /// # Panics
80    ///
81    /// Panics if the TLS handshake is not done yet, which should never happen
82    pub fn tls_info(&self) -> Option<TlsStreamInfo> {
83        let conn = self.get_tls_connection()?;
84
85        // SAFETY: we're getting the protocol version and cipher suite *after* the
86        // handshake, so this should never lead to a panic
87        let protocol_version = conn
88            .protocol_version()
89            .expect("TLS handshake is not done yet");
90        let negotiated_cipher_suite = conn
91            .negotiated_cipher_suite()
92            .expect("TLS handshake is not done yet");
93
94        let sni_hostname = conn.server_name().map(ToOwned::to_owned);
95        let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
96        let peer_certificates = conn.peer_certificates().map(|certs| {
97            certs
98                .iter()
99                .cloned()
100                .map(CertificateDer::into_owned)
101                .collect()
102        });
103        Some(TlsStreamInfo {
104            protocol_version,
105            negotiated_cipher_suite,
106            sni_hostname,
107            alpn_protocol,
108            peer_certificates,
109        })
110    }
111}
112
113impl<T> AsyncRead for MaybeTlsStream<T>
114where
115    T: AsyncRead + AsyncWrite + Unpin,
116{
117    fn poll_read(
118        self: Pin<&mut Self>,
119        cx: &mut Context,
120        buf: &mut ReadBuf,
121    ) -> Poll<std::io::Result<()>> {
122        match self.project() {
123            MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf),
124            MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf),
125        }
126    }
127}
128
129impl<T> AsyncWrite for MaybeTlsStream<T>
130where
131    T: AsyncRead + AsyncWrite + Unpin,
132{
133    fn poll_write(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136        buf: &[u8],
137    ) -> Poll<std::io::Result<usize>> {
138        match self.project() {
139            MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
140            MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
141        }
142    }
143
144    fn poll_write_vectored(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        bufs: &[std::io::IoSlice<'_>],
148    ) -> Poll<Result<usize, std::io::Error>> {
149        match self.project() {
150            MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
151            MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
152        }
153    }
154
155    fn is_write_vectored(&self) -> bool {
156        match self {
157            Self::Secure { stream } => stream.is_write_vectored(),
158            Self::Insecure { stream } => stream.is_write_vectored(),
159        }
160    }
161
162    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
163        match self.project() {
164            MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
165            MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
166        }
167    }
168
169    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
170        match self.project() {
171            MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
172            MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
173        }
174    }
175}
176
177#[derive(Clone)]
178pub struct MaybeTlsAcceptor {
179    tls_config: Option<Arc<ServerConfig>>,
180}
181
182impl MaybeTlsAcceptor {
183    #[must_use]
184    pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
185        Self { tls_config }
186    }
187
188    #[must_use]
189    pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
190        Self {
191            tls_config: Some(tls_config),
192        }
193    }
194
195    #[must_use]
196    pub fn new_insecure() -> Self {
197        Self { tls_config: None }
198    }
199
200    #[must_use]
201    pub const fn is_secure(&self) -> bool {
202        self.tls_config.is_some()
203    }
204
205    /// Accept a connection and do the TLS handshake
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the TLS handshake failed
210    pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
211    where
212        T: AsyncRead + AsyncWrite + Unpin,
213    {
214        match &self.tls_config {
215            Some(config) => {
216                let acceptor = TlsAcceptor::from(config.clone());
217                let stream = acceptor.accept(stream).await?;
218                Ok(MaybeTlsStream::Secure { stream })
219            }
220            None => Ok(MaybeTlsStream::Insecure { stream }),
221        }
222    }
223}