mas_listener/
maybe_tls.rs1use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio_rustls::{
15 TlsAcceptor,
16 rustls::{
17 ProtocolVersion, ServerConfig, ServerConnection, SupportedCipherSuite,
18 pki_types::CertificateDer,
19 },
20};
21
22#[derive(Debug, Clone)]
23#[non_exhaustive]
24pub struct TlsStreamInfo {
25 pub protocol_version: ProtocolVersion,
26 pub negotiated_cipher_suite: SupportedCipherSuite,
27 pub sni_hostname: Option<String>,
28 pub alpn_protocol: Option<Vec<u8>>,
29 pub peer_certificates: Option<Vec<CertificateDer<'static>>>,
30}
31
32impl TlsStreamInfo {
33 #[must_use]
34 pub fn is_alpn_h2(&self) -> bool {
35 matches!(self.alpn_protocol.as_deref(), Some(b"h2"))
36 }
37}
38
39pin_project_lite::pin_project! {
40 #[project = MaybeTlsStreamProj]
41 pub enum MaybeTlsStream<T> {
42 Secure {
43 #[pin]
44 stream: tokio_rustls::server::TlsStream<T>
45 },
46 Insecure {
47 #[pin]
48 stream: T,
49 },
50 }
51}
52
53impl<T> MaybeTlsStream<T> {
54 pub fn get_ref(&self) -> &T {
60 match self {
61 Self::Secure { stream } => stream.get_ref().0,
62 Self::Insecure { stream } => stream,
63 }
64 }
65
66 pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
70 match self {
71 Self::Secure { stream } => Some(stream.get_ref().1),
72 Self::Insecure { .. } => None,
73 }
74 }
75
76 pub fn tls_info(&self) -> Option<TlsStreamInfo> {
83 let conn = self.get_tls_connection()?;
84
85 let protocol_version = conn
88 .protocol_version()
89 .expect("TLS handshake is not done yet");
90 let negotiated_cipher_suite = conn
91 .negotiated_cipher_suite()
92 .expect("TLS handshake is not done yet");
93
94 let sni_hostname = conn.server_name().map(ToOwned::to_owned);
95 let alpn_protocol = conn.alpn_protocol().map(ToOwned::to_owned);
96 let peer_certificates = conn.peer_certificates().map(|certs| {
97 certs
98 .iter()
99 .cloned()
100 .map(CertificateDer::into_owned)
101 .collect()
102 });
103 Some(TlsStreamInfo {
104 protocol_version,
105 negotiated_cipher_suite,
106 sni_hostname,
107 alpn_protocol,
108 peer_certificates,
109 })
110 }
111}
112
113impl<T> AsyncRead for MaybeTlsStream<T>
114where
115 T: AsyncRead + AsyncWrite + Unpin,
116{
117 fn poll_read(
118 self: Pin<&mut Self>,
119 cx: &mut Context,
120 buf: &mut ReadBuf,
121 ) -> Poll<std::io::Result<()>> {
122 match self.project() {
123 MaybeTlsStreamProj::Secure { stream } => stream.poll_read(cx, buf),
124 MaybeTlsStreamProj::Insecure { stream } => stream.poll_read(cx, buf),
125 }
126 }
127}
128
129impl<T> AsyncWrite for MaybeTlsStream<T>
130where
131 T: AsyncRead + AsyncWrite + Unpin,
132{
133 fn poll_write(
134 self: Pin<&mut Self>,
135 cx: &mut Context<'_>,
136 buf: &[u8],
137 ) -> Poll<std::io::Result<usize>> {
138 match self.project() {
139 MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
140 MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
141 }
142 }
143
144 fn poll_write_vectored(
145 self: Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 bufs: &[std::io::IoSlice<'_>],
148 ) -> Poll<Result<usize, std::io::Error>> {
149 match self.project() {
150 MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
151 MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
152 }
153 }
154
155 fn is_write_vectored(&self) -> bool {
156 match self {
157 Self::Secure { stream } => stream.is_write_vectored(),
158 Self::Insecure { stream } => stream.is_write_vectored(),
159 }
160 }
161
162 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
163 match self.project() {
164 MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
165 MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
166 }
167 }
168
169 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
170 match self.project() {
171 MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
172 MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
173 }
174 }
175}
176
177#[derive(Clone)]
178pub struct MaybeTlsAcceptor {
179 tls_config: Option<Arc<ServerConfig>>,
180}
181
182impl MaybeTlsAcceptor {
183 #[must_use]
184 pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
185 Self { tls_config }
186 }
187
188 #[must_use]
189 pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
190 Self {
191 tls_config: Some(tls_config),
192 }
193 }
194
195 #[must_use]
196 pub fn new_insecure() -> Self {
197 Self { tls_config: None }
198 }
199
200 #[must_use]
201 pub const fn is_secure(&self) -> bool {
202 self.tls_config.is_some()
203 }
204
205 pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
211 where
212 T: AsyncRead + AsyncWrite + Unpin,
213 {
214 match &self.tls_config {
215 Some(config) => {
216 let acceptor = TlsAcceptor::from(config.clone());
217 let stream = acceptor.accept(stream).await?;
218 Ok(MaybeTlsStream::Secure { stream })
219 }
220 None => Ok(MaybeTlsStream::Insecure { stream }),
221 }
222 }
223}