1use std::io;
5use std::io::IoSlice;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9use tokio_rustls::server::TlsStream;
10
11pub(crate) enum ServerIo<IO> {
12 Io(IO),
13 TlsIo(Box<TlsStream<IO>>),
14}
15
16impl<IO> ServerIo<IO> {
17 pub(crate) fn new_io(io: IO) -> Self {
18 Self::Io(io)
19 }
20
21 pub(crate) fn new_tls_io(io: TlsStream<IO>) -> Self {
22 Self::TlsIo(Box::new(io))
23 }
24
25 pub(crate) fn peer_certs(
26 &self,
27 ) -> Option<std::sync::Arc<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>>> {
28 match self {
29 Self::Io(_) => None,
30 Self::TlsIo(io) => {
31 let (_inner, session) = io.get_ref();
32
33 session
34 .peer_certificates()
35 .map(|certs| certs.to_owned().into())
36 }
37 }
38 }
39}
40
41impl<IO> AsyncRead for ServerIo<IO>
42where
43 IO: AsyncWrite + AsyncRead + Unpin,
44{
45 fn poll_read(
46 mut self: Pin<&mut Self>,
47 cx: &mut Context<'_>,
48 buf: &mut ReadBuf<'_>,
49 ) -> Poll<io::Result<()>> {
50 match &mut *self {
51 Self::Io(io) => Pin::new(io).poll_read(cx, buf),
52 Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
53 }
54 }
55}
56
57impl<IO> AsyncWrite for ServerIo<IO>
58where
59 IO: AsyncWrite + AsyncRead + Unpin,
60{
61 fn poll_write(
62 mut self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &[u8],
65 ) -> Poll<io::Result<usize>> {
66 match &mut *self {
67 Self::Io(io) => Pin::new(io).poll_write(cx, buf),
68 Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
69 }
70 }
71
72 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73 match &mut *self {
74 Self::Io(io) => Pin::new(io).poll_flush(cx),
75 Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
76 }
77 }
78
79 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
80 match &mut *self {
81 Self::Io(io) => Pin::new(io).poll_shutdown(cx),
82 Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
83 }
84 }
85
86 fn poll_write_vectored(
87 mut self: Pin<&mut Self>,
88 cx: &mut Context<'_>,
89 bufs: &[IoSlice<'_>],
90 ) -> Poll<Result<usize, io::Error>> {
91 match &mut *self {
92 Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
93 Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
94 }
95 }
96
97 fn is_write_vectored(&self) -> bool {
98 match self {
99 Self::Io(io) => io.is_write_vectored(),
100 Self::TlsIo(io) => io.is_write_vectored(),
101 }
102 }
103}