1use std::{
4 cmp, io,
5 marker::Unpin,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use bytes::{Buf, Bytes};
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// An I/O wrapper that hands out a buffer of bytes on reads before it starts
/// reading from the wrapped object.
///
/// Only reads are affected: every write operation is forwarded unchanged to
/// `inner`.
#[derive(Debug)]
pub struct Rewind<T> {
    // Bytes to return from reads before consulting `inner`; `None` once drained
    // or when no buffer was supplied.
    pre: Option<Bytes>,
    // The wrapped I/O object.
    inner: T,
}
19
20impl<T> Rewind<T> {
21 pub(crate) fn new(io: T) -> Self {
22 Rewind {
23 pre: None,
24 inner: io,
25 }
26 }
27
28 pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
29 Rewind {
30 pre: Some(buf),
31 inner: io,
32 }
33 }
34
35 #[cfg(test)]
36 pub(crate) fn rewind(&mut self, bs: Bytes) {
37 debug_assert!(self.pre.is_none());
38 self.pre = Some(bs);
39 }
40}
41
42impl<T> AsyncRead for Rewind<T>
43where
44 T: AsyncRead + Unpin,
45{
46 fn poll_read(
47 mut self: Pin<&mut Self>,
48 cx: &mut Context<'_>,
49 buf: &mut ReadBuf<'_>,
50 ) -> Poll<io::Result<()>> {
51 if let Some(mut prefix) = self.pre.take() {
52 if !prefix.is_empty() {
54 let copy_len = cmp::min(prefix.len(), buf.remaining());
55 buf.put_slice(&prefix[..copy_len]);
57 prefix.advance(copy_len);
58 if !prefix.is_empty() {
60 self.pre = Some(prefix);
61 }
62
63 return Poll::Ready(Ok(()));
64 }
65 }
66 Pin::new(&mut self.inner).poll_read(cx, buf)
67 }
68}
69
70impl<T> AsyncWrite for Rewind<T>
71where
72 T: AsyncWrite + Unpin,
73{
74 fn poll_write(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 buf: &[u8],
78 ) -> Poll<io::Result<usize>> {
79 Pin::new(&mut self.inner).poll_write(cx, buf)
80 }
81
82 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83 Pin::new(&mut self.inner).poll_flush(cx)
84 }
85
86 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
87 Pin::new(&mut self.inner).poll_shutdown(cx)
88 }
89
90 fn poll_write_vectored(
91 mut self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 bufs: &[io::IoSlice<'_>],
94 ) -> Poll<io::Result<usize>> {
95 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
96 }
97
98 fn is_write_vectored(&self) -> bool {
99 self.inner.is_write_vectored()
100 }
101}
102
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;

    use super::Rewind;

    #[tokio::test]
    async fn partial_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        // Read off the first two bytes of the mock stream.
        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");

        // Push them back so the stream looks as if they were never read.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        // Two bytes come from the replay buffer, the remaining three from the mock.
        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        assert_eq!(&buf, &underlying);
    }

    #[tokio::test]
    async fn full_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read1");

        // Push everything back; the next read must be served entirely from it.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // Without this check the test passed no matter what the replay returned.
        assert_eq!(&buf, &underlying);
    }

    #[tokio::test]
    async fn buffered_prefix_spans_multiple_reads() {
        // The mock has no scripted reads, so the replay buffer alone must
        // satisfy both reads below.
        let mock = tokio_test::io::Builder::new().build();
        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"hello"));

        // The caller's buffer is smaller than the prefix: the rest must be kept.
        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");
        assert_eq!(&buf, b"he");

        let mut buf = [0; 3];
        stream.read_exact(&mut buf).await.expect("read2");
        assert_eq!(&buf, b"llo");
    }

    #[tokio::test]
    async fn empty_buffered_prefix_reads_from_inner() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        // An empty prefix must be skipped rather than reported as a zero-byte read.
        let mut stream = Rewind::new_buffered(mock, Bytes::new());

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read1");

        assert_eq!(&buf, &underlying);
    }
}