1use http::Request;
5use http::Response;
6use hyper_util::service::TowerToHyperService;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::task::JoinSet;
11use tokio_rustls::TlsAcceptor;
12use tower::Service;
13use tower::ServiceBuilder;
14use tower::ServiceExt;
15use tracing::trace;
16
17use self::body::BoxBody;
18use self::connection_info::ActiveConnections;
19use self::io::ServerIo;
20
21pub use bytes;
22pub use http;
23pub use tokio_rustls::rustls;
24
25pub mod body;
26mod config;
27mod connection_handler;
28mod connection_info;
29mod fuse;
30mod io;
31mod listener;
32pub mod middleware;
33
34pub use config::Config;
35pub use listener::Listener;
36pub use listener::ListenerExt;
37
38pub use connection_info::ConnectInfo;
39pub use connection_info::ConnectionId;
40pub use connection_info::ConnectionInfo;
41pub use connection_info::PeerCertificates;
42
/// Type-erased, thread-safe error used by every fallible API in this crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// ALPN identifier advertised for HTTP/2 over TLS.
const ALPN_H2: &[u8] = b"h2";
/// ALPN identifier advertised for HTTP/1.1 over TLS (only when HTTP/1 is accepted).
const ALPN_H1: &[u8] = b"http/1.1";
48
49#[derive(Default)]
50pub struct Builder {
51 config: Config,
52 tls_config: Option<rustls::ServerConfig>,
53}
54
55impl Builder {
56 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn config(mut self, config: Config) -> Self {
61 self.config = config;
62 self
63 }
64
65 pub fn tls_single_cert(
70 self,
71 cert_file: impl AsRef<std::path::Path>,
72 private_key_file: impl AsRef<std::path::Path>,
73 ) -> Result<Self, BoxError> {
74 use rustls::pki_types::CertificateDer;
75 use rustls::pki_types::PrivateKeyDer;
76 use rustls::pki_types::pem::PemObject;
77
78 let certs = CertificateDer::pem_file_iter(cert_file)?.collect::<Result<_, _>>()?;
79 let private_key = PrivateKeyDer::from_pem_file(private_key_file)?;
80 let tls_config = rustls::ServerConfig::builder()
81 .with_no_client_auth()
82 .with_single_cert(certs, private_key)?;
83
84 Ok(self.tls_config(tls_config))
85 }
86
87 pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
88 self.tls_config = Some(tls_config);
89 self
90 }
91
92 pub fn serve<A, S, ResponseBody>(
93 self,
94 addr: A,
95 service: S,
96 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
97 where
98 A: std::net::ToSocketAddrs,
99 S: Service<
100 Request<BoxBody>,
101 Response = Response<ResponseBody>,
102 Error: Into<BoxError>,
103 Future: Send,
104 > + Clone
105 + Send
106 + 'static,
107 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
108 {
109 let listener = listener::TcpListenerWithOptions::new(
110 addr,
111 self.config.tcp_nodelay,
112 self.config.tcp_keepalive,
113 )?;
114
115 Self::serve_with_listener(self, listener, service)
116 }
117
118 fn serve_with_listener<L, S, ResponseBody>(
119 self,
120 listener: L,
121 service: S,
122 ) -> Result<ServerHandle<L::Addr>, BoxError>
123 where
124 L: Listener,
125 S: Service<
126 Request<BoxBody>,
127 Response = Response<ResponseBody>,
128 Error: Into<BoxError>,
129 Future: Send,
130 > + Clone
131 + Send
132 + 'static,
133 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
134 {
135 let local_addr = listener.local_addr()?;
136 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
137 let connections = ActiveConnections::default();
138
139 let tls_config = self.tls_config.map(|mut tls| {
140 tls.alpn_protocols.push(ALPN_H2.into());
141 if self.config.accept_http1 {
142 tls.alpn_protocols.push(ALPN_H1.into());
143 }
144 Arc::new(tls)
145 });
146
147 let (watch_sender, watch_reciever) = tokio::sync::watch::channel(());
148 let server = Server {
149 config: self.config,
150 tls_config,
151 listener,
152 local_addr: local_addr.clone(),
153 service: ServiceBuilder::new()
154 .layer(tower::util::BoxCloneService::layer())
155 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
156 .map_err(Into::into)
157 .service(service),
158 pending_connections: JoinSet::new(),
159 connection_handlers: JoinSet::new(),
160 connections: connections.clone(),
161 graceful_shutdown_token: graceful_shutdown_token.clone(),
162 _watch_reciever: watch_reciever,
163 };
164
165 let handle = ServerHandle(Arc::new(HandleInner {
166 local_addr,
167 connections,
168 graceful_shutdown_token,
169 watch_sender,
170 }));
171
172 tokio::spawn(server.serve());
173
174 Ok(handle)
175 }
176}
177
178#[derive(Debug)]
179pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
180
181#[derive(Debug)]
182struct HandleInner<A = std::net::SocketAddr> {
183 local_addr: A,
185 connections: ActiveConnections<A>,
186 graceful_shutdown_token: tokio_util::sync::CancellationToken,
187 watch_sender: tokio::sync::watch::Sender<()>,
188}
189
190impl<A> ServerHandle<A> {
191 pub fn local_addr(&self) -> &A {
193 &self.0.local_addr
194 }
195
196 pub fn trigger_shutdown(&self) {
199 self.0.graceful_shutdown_token.cancel();
200 }
201
202 pub async fn wait_for_shutdown(&self) {
207 self.0.watch_sender.closed().await
208 }
209
210 pub async fn shutdown(&self) {
212 self.trigger_shutdown();
213 self.wait_for_shutdown().await;
214 }
215
216 pub fn is_shutdown(&self) -> bool {
218 self.0.watch_sender.is_closed()
219 }
220
221 pub fn connections(
222 &self,
223 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
224 self.0.connections.read().unwrap()
225 }
226
227 pub fn number_of_connections(&self) -> usize {
229 self.connections().len()
230 }
231}
232
233type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
234
235struct Server<L: Listener> {
236 config: Config,
237 tls_config: Option<Arc<rustls::ServerConfig>>,
238
239 listener: L,
240 local_addr: L::Addr,
241 service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,
242
243 pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
244 connection_handlers: JoinSet<()>,
245 connections: ActiveConnections<L::Addr>,
246 graceful_shutdown_token: tokio_util::sync::CancellationToken,
247 _watch_reciever: tokio::sync::watch::Receiver<()>,
249}
250
251impl<L> Server<L>
252where
253 L: Listener,
254{
255 async fn serve(mut self) -> Result<(), BoxError> {
256 loop {
257 tokio::select! {
258 _ = self.graceful_shutdown_token.cancelled() => {
259 trace!("signal received, shutting down");
260 break;
261 },
262 (io, remote_addr) = self.listener.accept() => {
263 self.handle_incomming(io, remote_addr);
264 },
265 Some(maybe_connection) = self.pending_connections.join_next() => {
266 let (io, remote_addr) = match maybe_connection.unwrap() {
268 Ok((io, remote_addr)) => {
269 (io, remote_addr)
270 }
271 Err(e) => {
272 tracing::debug!(error = %e, "error accepting connection");
273 continue;
274 }
275 };
276
277 trace!("connection accepted");
278 self.handle_connection(io, remote_addr);
279 },
280 Some(connection_handler_output) = self.connection_handlers.join_next() => {
281 let _: () = connection_handler_output.unwrap();
283 },
284 }
285 }
286
287 self.shutdown().await;
289
290 Ok(())
291 }
292
293 fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) {
294 if let Some(tls) = self.tls_config.clone() {
295 let tls_acceptor = TlsAcceptor::from(tls);
296 let allow_insecure = self.config.allow_insecure;
297 self.pending_connections.spawn(async move {
298 if allow_insecure {
299 if let Some(tcp) =
302 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
303 {
304 let mut buf = [0; 1];
306 tcp.peek(&mut buf).await?;
309 if buf != [0x16] {
312 tracing::trace!("accepting insecure connection");
313 return Ok((ServerIo::new_io(io), remote_addr));
314 }
315 } else {
316 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
317 }
318 }
319
320 tracing::trace!("accepting TLS connection");
321 let io = tls_acceptor.accept(io).await?;
322 Ok((ServerIo::new_tls_io(io), remote_addr))
323 });
324 } else {
325 self.handle_connection(ServerIo::new_io(io), remote_addr);
326 }
327 }
328
329 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
330 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
331 let connection_info = ConnectionInfo::new(
332 remote_addr,
333 io.peer_certs(),
334 connection_shutdown_token.clone(),
335 );
336 let connection_id = connection_info.id();
337 let connect_info = connection_info::ConnectInfo {
338 local_addr: self.local_addr.clone(),
339 remote_addr: connection_info.remote_address().clone(),
340 };
341 let peer_certificates = connection_info.peer_certificates().cloned();
342 let hyper_io = hyper_util::rt::TokioIo::new(io);
343
344 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
345 move |mut request: Request<hyper::body::Incoming>| {
346 request.extensions_mut().insert(connect_info.clone());
347 if let Some(peer_certificates) = peer_certificates.clone() {
348 request.extensions_mut().insert(peer_certificates);
349 }
350
351 request.map(body::boxed)
352 },
353 ));
354
355 self.connections
356 .write()
357 .unwrap()
358 .insert(connection_id, connection_info);
359 let on_connection_close =
360 connection_handler::OnConnectionClose::new(connection_id, self.connections.clone());
361
362 self.connection_handlers
363 .spawn(connection_handler::serve_connection(
364 hyper_io,
365 hyper_svc,
366 self.config.connection_builder(),
367 connection_shutdown_token,
368 self.config.max_connection_age,
369 on_connection_close,
370 ));
371 }
372
373 async fn shutdown(mut self) {
374 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
377
378 self.graceful_shutdown_token.cancel();
380
381 self.pending_connections.shutdown().await;
383
384 trace!(
386 "waiting for {} connections to close",
387 self.connection_handlers.len()
388 );
389
390 let graceful_shutdown =
391 async { while self.connection_handlers.join_next().await.is_some() {} };
392
393 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
394 .await
395 .is_err()
396 {
397 tracing::warn!(
398 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
399 CONNECTION_SHUTDOWN_GRACE_PERIOD
400 );
401 self.connection_handlers.shutdown().await;
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;

    const MESSAGE: &str = "Hello, World!";

    /// Starts a server on an ephemeral localhost port whose `/` route replies with `MESSAGE`.
    fn spawn_hello_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// Issues `GET /` against the server behind `handle` and returns the response body.
    async fn fetch_root(handle: &ServerHandle) -> Vec<u8> {
        let url = format!("http://{}", handle.local_addr());
        let response = reqwest::get(url).await.unwrap();
        response.bytes().await.unwrap().to_vec()
    }

    #[tokio::test]
    async fn simple() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(handle.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        assert_eq!(handle.connections().len(), 0);
    }
}