mas_listener/
rewind.rs

1// Taken from hyper@0.14.20, src/common/io/rewind.rs
2
3use std::{
4    cmp, io,
5    marker::Unpin,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use bytes::{Buf, Bytes};
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
13/// Combine a buffer with an IO, rewinding reads to use the buffer.
14#[derive(Debug)]
15pub struct Rewind<T> {
16    pre: Option<Bytes>,
17    inner: T,
18}
19
20impl<T> Rewind<T> {
21    pub(crate) fn new(io: T) -> Self {
22        Rewind {
23            pre: None,
24            inner: io,
25        }
26    }
27
28    pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
29        Rewind {
30            pre: Some(buf),
31            inner: io,
32        }
33    }
34
35    #[cfg(test)]
36    pub(crate) fn rewind(&mut self, bs: Bytes) {
37        debug_assert!(self.pre.is_none());
38        self.pre = Some(bs);
39    }
40}
41
42impl<T> AsyncRead for Rewind<T>
43where
44    T: AsyncRead + Unpin,
45{
46    fn poll_read(
47        mut self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49        buf: &mut ReadBuf<'_>,
50    ) -> Poll<io::Result<()>> {
51        if let Some(mut prefix) = self.pre.take() {
52            // If there are no remaining bytes, let the bytes get dropped.
53            if !prefix.is_empty() {
54                let copy_len = cmp::min(prefix.len(), buf.remaining());
55                // TODO: There should be a way to do following two lines cleaner...
56                buf.put_slice(&prefix[..copy_len]);
57                prefix.advance(copy_len);
58                // Put back what's left
59                if !prefix.is_empty() {
60                    self.pre = Some(prefix);
61                }
62
63                return Poll::Ready(Ok(()));
64            }
65        }
66        Pin::new(&mut self.inner).poll_read(cx, buf)
67    }
68}
69
70impl<T> AsyncWrite for Rewind<T>
71where
72    T: AsyncWrite + Unpin,
73{
74    fn poll_write(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78    ) -> Poll<io::Result<usize>> {
79        Pin::new(&mut self.inner).poll_write(cx, buf)
80    }
81
82    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83        Pin::new(&mut self.inner).poll_flush(cx)
84    }
85
86    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
87        Pin::new(&mut self.inner).poll_shutdown(cx)
88    }
89
90    fn poll_write_vectored(
91        mut self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        bufs: &[io::IoSlice<'_>],
94    ) -> Poll<io::Result<usize>> {
95        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
96    }
97
98    fn is_write_vectored(&self) -> bool {
99        self.inner.is_write_vectored()
100    }
101}
102
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;

    use super::Rewind;

    /// Rewinding a partially-consumed stream replays the consumed bytes and
    /// then continues with the rest of the underlying data.
    #[tokio::test]
    async fn partial_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        // Read off some bytes, ensuring we filled `buf`.
        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // At this point we should have read everything that was in the mock stream.
        assert_eq!(&buf, &underlying);
    }

    /// Rewinding everything that was read replays the whole stream from the
    /// rewind buffer alone.
    #[tokio::test]
    async fn full_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read1");

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // The replayed bytes must match the original data, in order.
        assert_eq!(&buf, &underlying);
    }

    /// A prefix supplied at construction is read before the inner IO.
    #[tokio::test]
    async fn buffered_prefix_is_read_first() {
        let mock = tokio_test::io::Builder::new().read(b"llo").build();

        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"he"));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read");

        assert_eq!(&buf, b"hello");
    }

    /// An empty prefix is skipped and reads go straight to the inner IO.
    #[tokio::test]
    async fn empty_prefix_falls_through() {
        let mock = tokio_test::io::Builder::new().read(b"hello").build();

        let mut stream = Rewind::new_buffered(mock, Bytes::new());

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read");

        assert_eq!(&buf, b"hello");
    }

    /// A prefix larger than the caller's buffer is handed out across several
    /// reads without touching the inner IO.
    #[tokio::test]
    async fn prefix_larger_than_read_buffer() {
        let mock = tokio_test::io::Builder::new().build();

        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"hello"));

        let mut first = [0; 2];
        stream.read_exact(&mut first).await.expect("read1");
        assert_eq!(&first, b"he");

        let mut rest = [0; 3];
        stream.read_exact(&mut rest).await.expect("read2");
        assert_eq!(&rest, b"llo");
    }
}