1use connection_handler::OnConnectionClose;
5use http::{Request, Response};
6use hyper_util::service::TowerToHyperService;
7use io::ServerIo;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::task::JoinSet;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::body::BoxBody;
18use self::connection_info::ActiveConnections;
19
20pub use http;
21
22pub mod body;
23mod config;
24mod connection_handler;
25mod connection_info;
26mod fuse;
27mod io;
28mod listener;
29
30pub use config::Config;
31pub use listener::Listener;
32pub use listener::ListenerExt;
33
34pub use connection_info::ConnectInfo;
35pub use connection_info::ConnectionId;
36pub use connection_info::ConnectionInfo;
37pub use connection_info::PeerCertificates;
38
/// Type-erased, thread-safe error used for fallible operations throughout this crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// ALPN protocol identifier for HTTP/2, advertised whenever TLS is configured.
const ALPN_H2: &[u8] = b"h2";
/// ALPN protocol identifier for HTTP/1.1, advertised only when `accept_http1` is enabled.
const ALPN_H1: &[u8] = b"http/1.1";
44
/// Builder used to configure and start an HTTP server.
///
/// The default builder uses `Config::default()` and has no TLS configuration.
#[derive(Default)]
pub struct Builder {
    /// Server configuration handed to the server task when it is started.
    config: Config,
    /// Optional TLS configuration; when present, accepted connections are wrapped in TLS.
    tls_config: Option<rustls::ServerConfig>,
}
50
51impl Builder {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn config(mut self, config: Config) -> Self {
57 self.config = config;
58 self
59 }
60
61 pub fn tls_single_cert(
66 self,
67 cert_file: impl AsRef<std::path::Path>,
68 private_key_file: impl AsRef<std::path::Path>,
69 ) -> Result<Self, BoxError> {
70 use rustls::pki_types::pem::PemObject;
71 use rustls::pki_types::{CertificateDer, PrivateKeyDer};
72
73 let certs = CertificateDer::pem_file_iter(cert_file)?.collect::<Result<_, _>>()?;
74 let private_key = PrivateKeyDer::from_pem_file(private_key_file)?;
75 let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
76 rustls::crypto::ring::default_provider(),
77 ))
78 .with_protocol_versions(rustls::DEFAULT_VERSIONS)?
79 .with_no_client_auth()
80 .with_single_cert(certs, private_key)?;
81
82 Ok(self.tls_config(tls_config))
83 }
84
85 pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
86 self.tls_config = Some(tls_config);
87 self
88 }
89
90 pub fn serve<A, S, ResponseBody>(
91 self,
92 addr: A,
93 service: S,
94 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
95 where
96 A: std::net::ToSocketAddrs,
97 S: Service<
98 Request<BoxBody>,
99 Response = Response<ResponseBody>,
100 Error: Into<BoxError>,
101 Future: Send,
102 > + Clone
103 + Send
104 + 'static,
105 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
106 {
107 let listener = listener::TcpListenerWithOptions::new(
108 addr,
109 true,
110 self.config.tcp_keepalive,
111 )?;
112
113 Self::serve_with_listener(self, listener, service)
114 }
115
116 fn serve_with_listener<L, S, ResponseBody>(
117 self,
118 listener: L,
119 service: S,
120 ) -> Result<ServerHandle<L::Addr>, BoxError>
121 where
122 L: Listener,
123 S: Service<
124 Request<BoxBody>,
125 Response = Response<ResponseBody>,
126 Error: Into<BoxError>,
127 Future: Send,
128 > + Clone
129 + Send
130 + 'static,
131 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
132 {
133 let local_addr = listener.local_addr()?;
134 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
135 let connections = ActiveConnections::default();
136
137 let tls_config = self.tls_config.map(|mut tls| {
138 tls.alpn_protocols.push(ALPN_H2.into());
139 if self.config.accept_http1 {
140 tls.alpn_protocols.push(ALPN_H1.into());
141 }
142 Arc::new(tls)
143 });
144
145 let (watch_sender, watch_reciever) = tokio::sync::watch::channel(());
146 let server = Server {
147 config: self.config,
148 tls_config,
149 listener,
150 local_addr: local_addr.clone(),
151 service: ServiceBuilder::new()
152 .layer(tower::util::BoxCloneService::layer())
153 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
154 .map_err(Into::into)
155 .service(service),
156 pending_connections: JoinSet::new(),
157 connection_handlers: JoinSet::new(),
158 connections: connections.clone(),
159 graceful_shutdown_token: graceful_shutdown_token.clone(),
160 _watch_reciever: watch_reciever,
161 };
162
163 let handle = ServerHandle(Arc::new(HandleInner {
164 local_addr,
165 connections,
166 graceful_shutdown_token,
167 watch_sender,
168 }));
169
170 tokio::spawn(server.serve());
171
172 Ok(handle)
173 }
174}
175
/// Handle to a running server.
///
/// Gives access to the server's local address and active connections, and allows triggering
/// and awaiting shutdown.
#[derive(Debug)]
pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
178
/// State behind a [`ServerHandle`].
#[derive(Debug)]
struct HandleInner<A = std::net::SocketAddr> {
    /// Address reported by the listener when the server was started.
    local_addr: A,
    /// Registry of active connections; the server task holds a clone of it.
    connections: ActiveConnections<A>,
    /// Token whose cancellation makes the server task leave its accept loop and shut down.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Sender half of a watch channel whose receiver is held by the server task; it is used
    /// only to observe when that receiver has been dropped.
    watch_sender: tokio::sync::watch::Sender<()>,
}
187
188impl<A> ServerHandle<A> {
189 pub fn local_addr(&self) -> &A {
191 &self.0.local_addr
192 }
193
194 pub fn trigger_shutdown(&self) {
197 self.0.graceful_shutdown_token.cancel();
198 }
199
200 pub async fn wait_for_shutdown(&self) {
205 self.0.watch_sender.closed().await
206 }
207
208 pub async fn shutdown(&self) {
210 self.trigger_shutdown();
211 self.wait_for_shutdown().await;
212 }
213
214 pub fn is_shutdown(&self) -> bool {
216 self.0.watch_sender.is_closed()
217 }
218
219 pub fn connections(
220 &self,
221 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
222 self.0.connections.read().unwrap()
223 }
224
225 pub fn number_of_connections(&self) -> usize {
227 self.connections().len()
228 }
229}
230
/// Outcome of a pending-connection task: the wrapped IO together with the peer address, or
/// the error that occurred while setting the connection up.
type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
232
/// State owned by the background task that accepts and serves connections.
struct Server<L: Listener> {
    /// Server configuration supplied through the builder.
    config: Config,
    /// TLS configuration; when `Some`, accepted connections are handed to a TLS acceptor.
    tls_config: Option<Arc<rustls::ServerConfig>>,

    /// Source of incoming connections.
    listener: L,
    /// Local address of the listener, attached to every request through `ConnectInfo`.
    local_addr: L::Addr,
    /// Type-erased request handler; cloned once per connection.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// Tasks for connections that are still being set up (plaintext sniffing / TLS handshake).
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// Tasks serving established connections.
    connection_handlers: JoinSet<()>,
    /// Registry of active connections; the `ServerHandle` holds a clone of it.
    connections: ActiveConnections<L::Addr>,
    /// Cancelled to request shutdown; every connection receives a child token of it.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Never read; kept alive so the handle's sender only observes the channel as closed
    /// once this struct has been dropped.
    _watch_reciever: tokio::sync::watch::Receiver<()>,
}
248
249impl<L> Server<L>
250where
251 L: Listener,
252{
253 async fn serve(mut self) -> Result<(), BoxError> {
254 loop {
255 tokio::select! {
256 _ = self.graceful_shutdown_token.cancelled() => {
257 trace!("signal received, shutting down");
258 break;
259 },
260 (io, remote_addr) = self.listener.accept() => {
261 self.handle_incomming(io, remote_addr);
262 },
263 Some(maybe_connection) = self.pending_connections.join_next() => {
264 let (io, remote_addr) = match maybe_connection.unwrap() {
266 Ok((io, remote_addr)) => {
267 (io, remote_addr)
268 }
269 Err(e) => {
270 tracing::debug!(error = %e, "error accepting connection");
271 continue;
272 }
273 };
274
275 trace!("connection accepted");
276 self.handle_connection(io, remote_addr);
277 },
278 Some(connection_handler_output) = self.connection_handlers.join_next() => {
279 let _: () = connection_handler_output.unwrap();
281 },
282 }
283 }
284
285 self.shutdown().await;
287
288 Ok(())
289 }
290
291 fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) {
292 if let Some(tls) = self.tls_config.clone() {
293 let tls_acceptor = TlsAcceptor::from(tls);
294 let allow_insecure = self.config.allow_insecure;
295 self.pending_connections.spawn(async move {
296 if allow_insecure {
297 if let Some(tcp) =
300 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
301 {
302 let mut buf = [0; 1];
304 tcp.peek(&mut buf).await?;
307 if buf != [0x16] {
310 tracing::trace!("accepting insecure connection");
311 return Ok((ServerIo::new_io(io), remote_addr));
312 }
313 } else {
314 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
315 }
316 }
317
318 tracing::trace!("accepting TLS connection");
319 let io = tls_acceptor.accept(io).await?;
320 Ok((ServerIo::new_tls_io(io), remote_addr))
321 });
322 } else {
323 self.handle_connection(ServerIo::new_io(io), remote_addr);
324 }
325 }
326
327 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
328 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
329 let connection_info = ConnectionInfo::new(
330 remote_addr,
331 io.peer_certs(),
332 connection_shutdown_token.clone(),
333 );
334 let connection_id = connection_info.id();
335 let connect_info = connection_info::ConnectInfo {
336 local_addr: self.local_addr.clone(),
337 remote_addr: connection_info.remote_address().clone(),
338 };
339 let peer_certificates = connection_info.peer_certificates().cloned();
340 let hyper_io = hyper_util::rt::TokioIo::new(io);
341
342 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
343 move |mut request: Request<hyper::body::Incoming>| {
344 request.extensions_mut().insert(connect_info.clone());
345 if let Some(peer_certificates) = peer_certificates.clone() {
346 request.extensions_mut().insert(peer_certificates);
347 }
348
349 request.map(body::boxed)
350 },
351 ));
352
353 self.connections
354 .write()
355 .unwrap()
356 .insert(connection_id, connection_info);
357 let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
358
359 self.connection_handlers
360 .spawn(connection_handler::serve_connection(
361 hyper_io,
362 hyper_svc,
363 self.config.connection_builder(),
364 connection_shutdown_token,
365 self.config.max_connection_age,
366 on_connection_close,
367 ));
368 }
369
370 async fn shutdown(mut self) {
371 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
374
375 self.graceful_shutdown_token.cancel();
377
378 self.pending_connections.shutdown().await;
380
381 trace!(
383 "waiting for {} connections to close",
384 self.connection_handlers.len()
385 );
386
387 let graceful_shutdown =
388 async { while self.connection_handlers.join_next().await.is_some() {} };
389
390 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
391 .await
392 .is_err()
393 {
394 tracing::warn!(
395 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
396 CONNECTION_SHUTDOWN_GRACE_PERIOD
397 );
398 self.connection_handlers.shutdown().await;
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;

    /// Body returned by the test route.
    const MESSAGE: &str = "Hello, World!";

    /// Starts a plaintext server on an OS-assigned localhost port that answers `GET /` with
    /// [`MESSAGE`].
    fn start_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// A request to a freshly started server returns the route's body.
    #[tokio::test]
    async fn simple() {
        let handle = start_server();
        let url = format!("http://{}", handle.local_addr());

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        assert_eq!(body, MESSAGE.as_bytes());
    }

    /// Shutting down through the handle completes and leaves no registered connections.
    #[tokio::test]
    async fn shutdown() {
        let handle = start_server();
        let url = format!("http://{}", handle.local_addr());

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        // The connection used for the request is still registered.
        assert_eq!(handle.connections().len(), 1);

        assert_eq!(body, MESSAGE.as_bytes());

        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());

        // Every connection has been removed from the registry after shutdown.
        assert_eq!(handle.connections().len(), 0);
    }
}