1use std::time::Duration;
5
/// A source of incoming connections that can be served.
///
/// Implementors must be `Send + 'static` so they can be moved into a spawned
/// serving task.
pub trait Listener: Send + 'static {
    /// The connection type yielded by [`Listener::accept`]: a bidirectional
    /// async byte stream.
    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;

    /// The peer/local address type associated with a connection.
    type Addr: Clone + Send + Sync + 'static;

    /// Waits for and returns the next incoming connection together with its address.
    ///
    /// The output carries no error, so an implementation whose underlying accept
    /// operation can fail has to deal with that failure itself (for example by
    /// logging and retrying, as the `tokio::net::TcpListener` impl in this module does).
    fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;

    /// Returns the local address this listener is bound to.
    fn local_addr(&self) -> std::io::Result<Self::Addr>;
}
24
25pub trait ListenerExt: Listener + Sized {
27 fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
47 where
48 F: FnMut(&mut Self::Io) + Send + 'static,
49 {
50 TapIo {
51 listener: self,
52 tap_fn,
53 }
54 }
55}
56
57impl<L: Listener> ListenerExt for L {}
58
59impl Listener for tokio::net::TcpListener {
60 type Io = tokio::net::TcpStream;
61 type Addr = std::net::SocketAddr;
62
63 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
64 loop {
65 match Self::accept(self).await {
66 Ok(tup) => return tup,
67 Err(e) => handle_accept_error(e).await,
68 }
69 }
70 }
71
72 #[inline]
73 fn local_addr(&self) -> std::io::Result<Self::Addr> {
74 Self::local_addr(self)
75 }
76}
77
/// A [`tokio::net::TcpListener`] wrapper that applies socket options to every
/// connection it accepts.
#[derive(Debug)]
pub struct TcpListenerWithOptions {
    /// The Tokio listener connections are accepted from.
    inner: tokio::net::TcpListener,
    /// When `true`, `TCP_NODELAY` is enabled on each accepted stream; when
    /// `false` the stream's setting is left untouched.
    nodelay: bool,
    /// When `Some(time)`, TCP keepalive is configured on each accepted stream
    /// with `time` as the keepalive idle time; `None` leaves keepalive untouched.
    keepalive: Option<Duration>,
}
84
85impl TcpListenerWithOptions {
86 pub fn new<A: std::net::ToSocketAddrs>(
87 addr: A,
88 nodelay: bool,
89 keepalive: Option<Duration>,
90 ) -> Result<Self, crate::BoxError> {
91 let std_listener = std::net::TcpListener::bind(addr)?;
92 std_listener.set_nonblocking(true)?;
93 let listener = tokio::net::TcpListener::from_std(std_listener)?;
94
95 Ok(Self::from_listener(listener, nodelay, keepalive))
96 }
97
98 pub fn from_listener(
100 listener: tokio::net::TcpListener,
101 nodelay: bool,
102 keepalive: Option<Duration>,
103 ) -> Self {
104 Self {
105 inner: listener,
106 nodelay,
107 keepalive,
108 }
109 }
110
111 fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
113 if self.nodelay {
114 if let Err(e) = stream.set_nodelay(true) {
115 tracing::warn!("error trying to set TCP nodelay: {}", e);
116 }
117 }
118
119 if let Some(timeout) = self.keepalive {
120 let sock_ref = socket2::SockRef::from(&stream);
121 let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
122
123 if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
124 tracing::warn!("error trying to set TCP keepalive: {}", e);
125 }
126 }
127 }
128}
129
130impl Listener for TcpListenerWithOptions {
131 type Io = tokio::net::TcpStream;
132 type Addr = std::net::SocketAddr;
133
134 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
135 let (io, addr) = Listener::accept(&mut self.inner).await;
136 self.set_accepted_socket_options(&io);
137 (io, addr)
138 }
139
140 #[inline]
141 fn local_addr(&self) -> std::io::Result<Self::Addr> {
142 Listener::local_addr(&self.inner)
143 }
144}
145
/// A listener adapter returned by `ListenerExt::tap_io`.
///
/// Wraps `listener` and runs `tap_fn` on every connection it accepts before
/// handing the connection back.
pub struct TapIo<L, F> {
    /// The wrapped listener connections are accepted from.
    listener: L,
    /// Callback invoked with mutable access to each accepted IO object.
    tap_fn: F,
}

/// Formats only the wrapped listener.
///
/// `tap_fn` is typically a closure, which has no `Debug` impl, so it is elided
/// via `finish_non_exhaustive`. Formatting only needs `L: Debug`; requiring
/// `L: Listener` as well added nothing.
impl<L, F> std::fmt::Debug for TapIo<L, F>
where
    L: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TapIo")
            .field("listener", &self.listener)
            .finish_non_exhaustive()
    }
}
185
186impl<L, F> Listener for TapIo<L, F>
187where
188 L: Listener,
189 F: FnMut(&mut L::Io) + Send + 'static,
190{
191 type Io = L::Io;
192 type Addr = L::Addr;
193
194 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
195 let (mut io, addr) = self.listener.accept().await;
196 (self.tap_fn)(&mut io);
197 (io, addr)
198 }
199
200 fn local_addr(&self) -> std::io::Result<Self::Addr> {
201 self.listener.local_addr()
202 }
203}
204
205async fn handle_accept_error(e: std::io::Error) {
206 if is_connection_error(&e) {
207 return;
208 }
209
210 tracing::error!("accept error: {e}");
222 tokio::time::sleep(Duration::from_secs(1)).await;
223}
224
/// Returns `true` when `e` is one of the error kinds the accept loop treats as
/// transient or per-connection, and therefore retries without pausing.
fn is_connection_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    // Kinds that concern a single connection attempt or an interrupted call,
    // rather than the listening socket as a whole.
    const RETRYABLE_KINDS: [ErrorKind; 7] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
        ErrorKind::BrokenPipe,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
        ErrorKind::TimedOut,
    ];

    RETRYABLE_KINDS.contains(&e.kind())
}