1use std::io;
5use std::io::IoSlice;
6use std::pin::Pin;
7use std::task::Context;
8use std::task::Poll;
9use tokio::io::AsyncRead;
10use tokio::io::AsyncWrite;
11use tokio::io::ReadBuf;
12use tokio_rustls::server::TlsStream;
13
14pub(crate) enum ServerIo<IO> {
15 Io(IO),
16 TlsIo(Box<TlsStream<IO>>),
17}
18
19impl<IO> ServerIo<IO> {
20 pub(crate) fn new_io(io: IO) -> Self {
21 Self::Io(io)
22 }
23
24 pub(crate) fn new_tls_io(io: TlsStream<IO>) -> Self {
25 Self::TlsIo(Box::new(io))
26 }
27
28 pub(crate) fn peer_certs(
29 &self,
30 ) -> Option<std::sync::Arc<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>>> {
31 match self {
32 Self::Io(_) => None,
33 Self::TlsIo(io) => {
34 let (_inner, session) = io.get_ref();
35
36 session
37 .peer_certificates()
38 .map(|certs| certs.to_owned().into())
39 }
40 }
41 }
42}
43
44impl<IO> AsyncRead for ServerIo<IO>
45where
46 IO: AsyncWrite + AsyncRead + Unpin,
47{
48 fn poll_read(
49 mut self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 buf: &mut ReadBuf<'_>,
52 ) -> Poll<io::Result<()>> {
53 match &mut *self {
54 Self::Io(io) => Pin::new(io).poll_read(cx, buf),
55 Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
56 }
57 }
58}
59
60impl<IO> AsyncWrite for ServerIo<IO>
61where
62 IO: AsyncWrite + AsyncRead + Unpin,
63{
64 fn poll_write(
65 mut self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 buf: &[u8],
68 ) -> Poll<io::Result<usize>> {
69 match &mut *self {
70 Self::Io(io) => Pin::new(io).poll_write(cx, buf),
71 Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
72 }
73 }
74
75 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
76 match &mut *self {
77 Self::Io(io) => Pin::new(io).poll_flush(cx),
78 Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
79 }
80 }
81
82 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83 match &mut *self {
84 Self::Io(io) => Pin::new(io).poll_shutdown(cx),
85 Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
86 }
87 }
88
89 fn poll_write_vectored(
90 mut self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 bufs: &[IoSlice<'_>],
93 ) -> Poll<Result<usize, io::Error>> {
94 match &mut *self {
95 Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
96 Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
97 }
98 }
99
100 fn is_write_vectored(&self) -> bool {
101 match self {
102 Self::Io(io) => io.is_write_vectored(),
103 Self::TlsIo(io) => io.is_write_vectored(),
104 }
105 }
106}