sui_http/
listener.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::time::Duration;
5
6/// Types that can listen for connections.
7pub trait Listener: Send + 'static {
8    /// The listener's IO type.
9    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;
10
11    /// The listener's address type.
12    // all these bounds are necessary to add this information in a request extension
13    type Addr: Clone + Send + Sync + 'static;
14
15    /// Accept a new incoming connection to this listener.
16    ///
17    /// If the underlying accept call can return an error, this function must
18    /// take care of logging and retrying.
19    fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;
20
21    /// Returns the local address that this listener is bound to.
22    fn local_addr(&self) -> std::io::Result<Self::Addr>;
23}
24
25/// Extensions to [`Listener`].
26pub trait ListenerExt: Listener + Sized {
27    /// Run a mutable closure on every accepted `Io`.
28    ///
29    /// # Example
30    ///
31    /// ```
32    /// use sui_http::ListenerExt;
33    /// use tracing::trace;
34    ///
35    /// # async {
36    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
37    ///     .await
38    ///     .unwrap()
39    ///     .tap_io(|tcp_stream| {
40    ///         if let Err(err) = tcp_stream.set_nodelay(true) {
41    ///             trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
42    ///         }
43    ///     });
44    /// # };
45    /// ```
46    fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
47    where
48        F: FnMut(&mut Self::Io) + Send + 'static,
49    {
50        TapIo {
51            listener: self,
52            tap_fn,
53        }
54    }
55}
56
57impl<L: Listener> ListenerExt for L {}
58
59impl Listener for tokio::net::TcpListener {
60    type Io = tokio::net::TcpStream;
61    type Addr = std::net::SocketAddr;
62
63    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
64        loop {
65            match Self::accept(self).await {
66                Ok(tup) => return tup,
67                Err(e) => handle_accept_error(e).await,
68            }
69        }
70    }
71
72    #[inline]
73    fn local_addr(&self) -> std::io::Result<Self::Addr> {
74        Self::local_addr(self)
75    }
76}
77
78#[derive(Debug)]
79pub struct TcpListenerWithOptions {
80    inner: tokio::net::TcpListener,
81    nodelay: bool,
82    keepalive: Option<Duration>,
83}
84
85impl TcpListenerWithOptions {
86    pub fn new<A: std::net::ToSocketAddrs>(
87        addr: A,
88        nodelay: bool,
89        keepalive: Option<Duration>,
90    ) -> Result<Self, crate::BoxError> {
91        let std_listener = std::net::TcpListener::bind(addr)?;
92        std_listener.set_nonblocking(true)?;
93        let listener = tokio::net::TcpListener::from_std(std_listener)?;
94
95        Ok(Self::from_listener(listener, nodelay, keepalive))
96    }
97
98    /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
99    pub fn from_listener(
100        listener: tokio::net::TcpListener,
101        nodelay: bool,
102        keepalive: Option<Duration>,
103    ) -> Self {
104        Self {
105            inner: listener,
106            nodelay,
107            keepalive,
108        }
109    }
110
111    // Consistent with hyper-0.14, this function does not return an error.
112    fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
113        if self.nodelay {
114            if let Err(e) = stream.set_nodelay(true) {
115                tracing::warn!("error trying to set TCP nodelay: {}", e);
116            }
117        }
118
119        if let Some(timeout) = self.keepalive {
120            let sock_ref = socket2::SockRef::from(&stream);
121            let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
122
123            if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
124                tracing::warn!("error trying to set TCP keepalive: {}", e);
125            }
126        }
127    }
128}
129
130impl Listener for TcpListenerWithOptions {
131    type Io = tokio::net::TcpStream;
132    type Addr = std::net::SocketAddr;
133
134    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
135        let (io, addr) = Listener::accept(&mut self.inner).await;
136        self.set_accepted_socket_options(&io);
137        (io, addr)
138    }
139
140    #[inline]
141    fn local_addr(&self) -> std::io::Result<Self::Addr> {
142        Listener::local_addr(&self.inner)
143    }
144}
145
146// Uncomment once we update tokio to >=1.41.0
147// #[cfg(unix)]
148// impl Listener for tokio::net::UnixListener {
149//     type Io = tokio::net::UnixStream;
150//     type Addr = std::os::unix::net::SocketAddr;
151
152//     async fn accept(&mut self) -> (Self::Io, Self::Addr) {
153//         loop {
154//             match Self::accept(self).await {
155//                 Ok((io, addr)) => return (io, addr.into()),
156//                 Err(e) => handle_accept_error(e).await,
157//             }
158//         }
159//     }
160
161//     #[inline]
162//     fn local_addr(&self) -> std::io::Result<Self::Addr> {
163//         Self::local_addr(self).map(Into::into)
164//     }
165// }
166
167/// Return type of [`ListenerExt::tap_io`].
168///
169/// See that method for details.
170pub struct TapIo<L, F> {
171    listener: L,
172    tap_fn: F,
173}
174
175impl<L, F> std::fmt::Debug for TapIo<L, F>
176where
177    L: Listener + std::fmt::Debug,
178{
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        f.debug_struct("TapIo")
181            .field("listener", &self.listener)
182            .finish_non_exhaustive()
183    }
184}
185
186impl<L, F> Listener for TapIo<L, F>
187where
188    L: Listener,
189    F: FnMut(&mut L::Io) + Send + 'static,
190{
191    type Io = L::Io;
192    type Addr = L::Addr;
193
194    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
195        let (mut io, addr) = self.listener.accept().await;
196        (self.tap_fn)(&mut io);
197        (io, addr)
198    }
199
200    fn local_addr(&self) -> std::io::Result<Self::Addr> {
201        self.listener.local_addr()
202    }
203}
204
205async fn handle_accept_error(e: std::io::Error) {
206    if is_connection_error(&e) {
207        return;
208    }
209
210    // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
211    //
212    // > A possible scenario is that the process has hit the max open files
213    // > allowed, and so trying to accept a new connection will fail with
214    // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
215    // > the application will likely close some files (or connections), and try
216    // > to accept the connection again. If this option is `true`, the error
217    // > will be logged at the `error` level, since it is still a big deal,
218    // > and then the listener will sleep for 1 second.
219    //
220    // hyper allowed customizing this but axum does not.
221    tracing::error!("accept error: {e}");
222    tokio::time::sleep(Duration::from_secs(1)).await;
223}
224
/// Returns `true` when `e` is one of the error kinds that `handle_accept_error`
/// treats as ignorable and retries without logging or sleeping.
fn is_connection_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    /// Error kinds skipped silently by the accept loop.
    const IGNORED_KINDS: [ErrorKind; 7] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
        ErrorKind::BrokenPipe,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
        ErrorKind::TimedOut,
    ];

    IGNORED_KINDS.contains(&e.kind())
}