1use connection_handler::OnConnectionClose;
5use http::{Request, Response};
6use hyper_util::service::TowerToHyperService;
7use io::ServerIo;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::task::JoinSet;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::body::BoxBody;
18use self::connection_info::ActiveConnections;
19
20pub use http;
21
22pub mod body;
23mod config;
24mod connection_handler;
25mod connection_info;
26mod fuse;
27mod io;
28mod listener;
29
30pub use config::Config;
31pub use listener::Listener;
32pub use listener::ListenerExt;
33
34pub use connection_info::ConnectInfo;
35pub use connection_info::ConnectionId;
36pub use connection_info::ConnectionInfo;
37pub use connection_info::PeerCertificates;
38
/// Boxed, thread-safe error type used throughout the crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// ALPN protocol identifier advertised for HTTP/2.
const ALPN_H2: &[u8] = "h2".as_bytes();

/// ALPN protocol identifier advertised for HTTP/1.1.
const ALPN_H1: &[u8] = "http/1.1".as_bytes();
44
45#[derive(Default)]
46pub struct Builder {
47 config: Config,
48 tls_config: Option<rustls::ServerConfig>,
49}
50
51impl Builder {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn config(mut self, config: Config) -> Self {
57 self.config = config;
58 self
59 }
60
61 pub fn tls_single_cert(
66 self,
67 cert_file: impl AsRef<std::path::Path>,
68 private_key_file: impl AsRef<std::path::Path>,
69 ) -> Result<Self, BoxError> {
70 use rustls::pki_types::pem::PemObject;
71 use rustls::pki_types::{CertificateDer, PrivateKeyDer};
72
73 let certs = CertificateDer::pem_file_iter(cert_file)?.collect::<Result<_, _>>()?;
74 let private_key = PrivateKeyDer::from_pem_file(private_key_file)?;
75 let tls_config = rustls::ServerConfig::builder()
76 .with_no_client_auth()
77 .with_single_cert(certs, private_key)?;
78
79 Ok(self.tls_config(tls_config))
80 }
81
82 pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
83 self.tls_config = Some(tls_config);
84 self
85 }
86
87 pub fn serve<A, S, ResponseBody>(
88 self,
89 addr: A,
90 service: S,
91 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
92 where
93 A: std::net::ToSocketAddrs,
94 S: Service<
95 Request<BoxBody>,
96 Response = Response<ResponseBody>,
97 Error: Into<BoxError>,
98 Future: Send,
99 > + Clone
100 + Send
101 + 'static,
102 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
103 {
104 let listener = listener::TcpListenerWithOptions::new(
105 addr,
106 true,
107 self.config.tcp_keepalive,
108 )?;
109
110 Self::serve_with_listener(self, listener, service)
111 }
112
113 fn serve_with_listener<L, S, ResponseBody>(
114 self,
115 listener: L,
116 service: S,
117 ) -> Result<ServerHandle<L::Addr>, BoxError>
118 where
119 L: Listener,
120 S: Service<
121 Request<BoxBody>,
122 Response = Response<ResponseBody>,
123 Error: Into<BoxError>,
124 Future: Send,
125 > + Clone
126 + Send
127 + 'static,
128 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
129 {
130 let local_addr = listener.local_addr()?;
131 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
132 let connections = ActiveConnections::default();
133
134 let tls_config = self.tls_config.map(|mut tls| {
135 tls.alpn_protocols.push(ALPN_H2.into());
136 if self.config.accept_http1 {
137 tls.alpn_protocols.push(ALPN_H1.into());
138 }
139 Arc::new(tls)
140 });
141
142 let (watch_sender, watch_reciever) = tokio::sync::watch::channel(());
143 let server = Server {
144 config: self.config,
145 tls_config,
146 listener,
147 local_addr: local_addr.clone(),
148 service: ServiceBuilder::new()
149 .layer(tower::util::BoxCloneService::layer())
150 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
151 .map_err(Into::into)
152 .service(service),
153 pending_connections: JoinSet::new(),
154 connection_handlers: JoinSet::new(),
155 connections: connections.clone(),
156 graceful_shutdown_token: graceful_shutdown_token.clone(),
157 _watch_reciever: watch_reciever,
158 };
159
160 let handle = ServerHandle(Arc::new(HandleInner {
161 local_addr,
162 connections,
163 graceful_shutdown_token,
164 watch_sender,
165 }));
166
167 tokio::spawn(server.serve());
168
169 Ok(handle)
170 }
171}
172
173#[derive(Debug)]
174pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
175
176#[derive(Debug)]
177struct HandleInner<A = std::net::SocketAddr> {
178 local_addr: A,
180 connections: ActiveConnections<A>,
181 graceful_shutdown_token: tokio_util::sync::CancellationToken,
182 watch_sender: tokio::sync::watch::Sender<()>,
183}
184
185impl<A> ServerHandle<A> {
186 pub fn local_addr(&self) -> &A {
188 &self.0.local_addr
189 }
190
191 pub fn trigger_shutdown(&self) {
194 self.0.graceful_shutdown_token.cancel();
195 }
196
197 pub async fn wait_for_shutdown(&self) {
202 self.0.watch_sender.closed().await
203 }
204
205 pub async fn shutdown(&self) {
207 self.trigger_shutdown();
208 self.wait_for_shutdown().await;
209 }
210
211 pub fn is_shutdown(&self) -> bool {
213 self.0.watch_sender.is_closed()
214 }
215
216 pub fn connections(
217 &self,
218 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
219 self.0.connections.read().unwrap()
220 }
221
222 pub fn number_of_connections(&self) -> usize {
224 self.connections().len()
225 }
226}
227
/// Output of a connection-setup task (optional first-byte sniffing plus the TLS
/// handshake): the wrapped I/O and the peer address, or the setup error.
type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;

/// State owned by the background task that accepts and serves connections.
///
/// Fields are dropped in declaration order. `_watch_reciever` is declared last
/// so it is released only after everything else, including the listener and
/// both task sets. `ServerHandle::wait_for_shutdown` waits for exactly that
/// release, so keep this field last.
struct Server<L: Listener> {
    config: Config,
    /// TLS settings. `None` means every connection is served as plaintext.
    tls_config: Option<Arc<rustls::ServerConfig>>,

    listener: L,
    local_addr: L::Addr,
    /// Type-erased user service with boxed request and response bodies.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// Connections still being set up. Only used when TLS is configured.
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// One task per established connection.
    connection_handlers: JoinSet<()>,
    /// Registry of live connections, shared with `ServerHandle`.
    connections: ActiveConnections<L::Addr>,
    /// Parent of every per-connection shutdown token.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Never read. It is held only so the watch channel stays open while the
    /// server is alive.
    _watch_reciever: tokio::sync::watch::Receiver<()>,
}
245
246impl<L> Server<L>
247where
248 L: Listener,
249{
250 async fn serve(mut self) -> Result<(), BoxError> {
251 loop {
252 tokio::select! {
253 _ = self.graceful_shutdown_token.cancelled() => {
254 trace!("signal received, shutting down");
255 break;
256 },
257 (io, remote_addr) = self.listener.accept() => {
258 self.handle_incomming(io, remote_addr);
259 },
260 Some(maybe_connection) = self.pending_connections.join_next() => {
261 let (io, remote_addr) = match maybe_connection.unwrap() {
263 Ok((io, remote_addr)) => {
264 (io, remote_addr)
265 }
266 Err(e) => {
267 tracing::debug!(error = %e, "error accepting connection");
268 continue;
269 }
270 };
271
272 trace!("connection accepted");
273 self.handle_connection(io, remote_addr);
274 },
275 Some(connection_handler_output) = self.connection_handlers.join_next() => {
276 let _: () = connection_handler_output.unwrap();
278 },
279 }
280 }
281
282 self.shutdown().await;
284
285 Ok(())
286 }
287
288 fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) {
289 if let Some(tls) = self.tls_config.clone() {
290 let tls_acceptor = TlsAcceptor::from(tls);
291 let allow_insecure = self.config.allow_insecure;
292 self.pending_connections.spawn(async move {
293 if allow_insecure {
294 if let Some(tcp) =
297 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
298 {
299 let mut buf = [0; 1];
301 tcp.peek(&mut buf).await?;
304 if buf != [0x16] {
307 tracing::trace!("accepting insecure connection");
308 return Ok((ServerIo::new_io(io), remote_addr));
309 }
310 } else {
311 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
312 }
313 }
314
315 tracing::trace!("accepting TLS connection");
316 let io = tls_acceptor.accept(io).await?;
317 Ok((ServerIo::new_tls_io(io), remote_addr))
318 });
319 } else {
320 self.handle_connection(ServerIo::new_io(io), remote_addr);
321 }
322 }
323
324 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
325 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
326 let connection_info = ConnectionInfo::new(
327 remote_addr,
328 io.peer_certs(),
329 connection_shutdown_token.clone(),
330 );
331 let connection_id = connection_info.id();
332 let connect_info = connection_info::ConnectInfo {
333 local_addr: self.local_addr.clone(),
334 remote_addr: connection_info.remote_address().clone(),
335 };
336 let peer_certificates = connection_info.peer_certificates().cloned();
337 let hyper_io = hyper_util::rt::TokioIo::new(io);
338
339 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
340 move |mut request: Request<hyper::body::Incoming>| {
341 request.extensions_mut().insert(connect_info.clone());
342 if let Some(peer_certificates) = peer_certificates.clone() {
343 request.extensions_mut().insert(peer_certificates);
344 }
345
346 request.map(body::boxed)
347 },
348 ));
349
350 self.connections
351 .write()
352 .unwrap()
353 .insert(connection_id, connection_info);
354 let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
355
356 self.connection_handlers
357 .spawn(connection_handler::serve_connection(
358 hyper_io,
359 hyper_svc,
360 self.config.connection_builder(),
361 connection_shutdown_token,
362 self.config.max_connection_age,
363 on_connection_close,
364 ));
365 }
366
367 async fn shutdown(mut self) {
368 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
371
372 self.graceful_shutdown_token.cancel();
374
375 self.pending_connections.shutdown().await;
377
378 trace!(
380 "waiting for {} connections to close",
381 self.connection_handlers.len()
382 );
383
384 let graceful_shutdown =
385 async { while self.connection_handlers.join_next().await.is_some() {} };
386
387 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
388 .await
389 .is_err()
390 {
391 tracing::warn!(
392 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
393 CONNECTION_SHUTDOWN_GRACE_PERIOD
394 );
395 self.connection_handlers.shutdown().await;
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;

    /// Body returned by the test server for `GET /`.
    const MESSAGE: &str = "Hello, World!";

    /// Starts a server on an ephemeral localhost port that answers `GET /`
    /// with [`MESSAGE`].
    fn spawn_hello_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// Sends `GET /` over plain HTTP to the server behind `handle` and returns
    /// the response body.
    async fn fetch_root(handle: &ServerHandle) -> Vec<u8> {
        let url = format!("http://{}", handle.local_addr());
        reqwest::get(url)
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap()
            .to_vec()
    }

    #[tokio::test]
    async fn simple() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(handle.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        assert_eq!(handle.connections().len(), 0);
    }
}