sui_http/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use http::Request;
5use http::Response;
6use hyper_util::service::TowerToHyperService;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::task::JoinSet;
11use tokio_rustls::TlsAcceptor;
12use tower::Service;
13use tower::ServiceBuilder;
14use tower::ServiceExt;
15use tracing::trace;
16
17use self::body::BoxBody;
18use self::connection_info::ActiveConnections;
19use self::io::ServerIo;
20
21pub use bytes;
22pub use http;
23pub use tokio_rustls::rustls;
24
25pub mod body;
26mod config;
27mod connection_handler;
28mod connection_info;
29mod fuse;
30mod io;
31mod listener;
32pub mod middleware;
33
34pub use config::Config;
35pub use listener::Listener;
36pub use listener::ListenerExt;
37
38pub use connection_info::ConnectInfo;
39pub use connection_info::ConnectionId;
40pub use connection_info::ConnectionInfo;
41pub use connection_info::PeerCertificates;
42
/// ALPN protocol identifier for HTTP/2, as the raw bytes rustls expects in
/// `ServerConfig::alpn_protocols`.
const ALPN_H2: &[u8] = b"h2";
/// ALPN protocol identifier for HTTP/1.1, as the raw bytes rustls expects in
/// `ServerConfig::alpn_protocols`.
const ALPN_H1: &[u8] = b"http/1.1";

/// Type-erased, thread-safe error type used throughout the crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
48
49#[derive(Default)]
50pub struct Builder {
51    config: Config,
52    tls_config: Option<rustls::ServerConfig>,
53}
54
55impl Builder {
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    pub fn config(mut self, config: Config) -> Self {
61        self.config = config;
62        self
63    }
64
65    // Convenience method for configuring TLS with a single server cert
66    //
67    // Attempts to load PEM formatted files for the certificate chain and private key material from
68    // the provided file system paths.
69    pub fn tls_single_cert(
70        self,
71        cert_file: impl AsRef<std::path::Path>,
72        private_key_file: impl AsRef<std::path::Path>,
73    ) -> Result<Self, BoxError> {
74        use rustls::pki_types::CertificateDer;
75        use rustls::pki_types::PrivateKeyDer;
76        use rustls::pki_types::pem::PemObject;
77
78        let certs = CertificateDer::pem_file_iter(cert_file)?.collect::<Result<_, _>>()?;
79        let private_key = PrivateKeyDer::from_pem_file(private_key_file)?;
80        let tls_config = rustls::ServerConfig::builder()
81            .with_no_client_auth()
82            .with_single_cert(certs, private_key)?;
83
84        Ok(self.tls_config(tls_config))
85    }
86
87    pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
88        self.tls_config = Some(tls_config);
89        self
90    }
91
92    pub fn serve<A, S, ResponseBody>(
93        self,
94        addr: A,
95        service: S,
96    ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
97    where
98        A: std::net::ToSocketAddrs,
99        S: Service<
100                Request<BoxBody>,
101                Response = Response<ResponseBody>,
102                Error: Into<BoxError>,
103                Future: Send,
104            > + Clone
105            + Send
106            + 'static,
107        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
108    {
109        let listener = listener::TcpListenerWithOptions::new(
110            addr,
111            self.config.tcp_nodelay,
112            self.config.tcp_keepalive,
113        )?;
114
115        Self::serve_with_listener(self, listener, service)
116    }
117
118    fn serve_with_listener<L, S, ResponseBody>(
119        self,
120        listener: L,
121        service: S,
122    ) -> Result<ServerHandle<L::Addr>, BoxError>
123    where
124        L: Listener,
125        S: Service<
126                Request<BoxBody>,
127                Response = Response<ResponseBody>,
128                Error: Into<BoxError>,
129                Future: Send,
130            > + Clone
131            + Send
132            + 'static,
133        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
134    {
135        let local_addr = listener.local_addr()?;
136        let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
137        let connections = ActiveConnections::default();
138
139        let tls_config = self.tls_config.map(|mut tls| {
140            tls.alpn_protocols.push(ALPN_H2.into());
141            if self.config.accept_http1 {
142                tls.alpn_protocols.push(ALPN_H1.into());
143            }
144            Arc::new(tls)
145        });
146
147        let (watch_sender, watch_reciever) = tokio::sync::watch::channel(());
148        let server = Server {
149            config: self.config,
150            tls_config,
151            listener,
152            local_addr: local_addr.clone(),
153            service: ServiceBuilder::new()
154                .layer(tower::util::BoxCloneService::layer())
155                .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
156                .map_err(Into::into)
157                .service(service),
158            pending_connections: JoinSet::new(),
159            connection_handlers: JoinSet::new(),
160            connections: connections.clone(),
161            graceful_shutdown_token: graceful_shutdown_token.clone(),
162            _watch_reciever: watch_reciever,
163        };
164
165        let handle = ServerHandle(Arc::new(HandleInner {
166            local_addr,
167            connections,
168            graceful_shutdown_token,
169            watch_sender,
170        }));
171
172        tokio::spawn(server.serve());
173
174        Ok(handle)
175    }
176}
177
178#[derive(Debug)]
179pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
180
181#[derive(Debug)]
182struct HandleInner<A = std::net::SocketAddr> {
183    /// The local address of the server.
184    local_addr: A,
185    connections: ActiveConnections<A>,
186    graceful_shutdown_token: tokio_util::sync::CancellationToken,
187    watch_sender: tokio::sync::watch::Sender<()>,
188}
189
190impl<A> ServerHandle<A> {
191    /// Returns the local address of the server
192    pub fn local_addr(&self) -> &A {
193        &self.0.local_addr
194    }
195
196    /// Trigger a graceful shutdown of the server, but don't wait till the server has completed
197    /// shutting down
198    pub fn trigger_shutdown(&self) {
199        self.0.graceful_shutdown_token.cancel();
200    }
201
202    /// Completes once the network has been shutdown.
203    ///
204    /// This explicitly *does not* trigger the network to shutdown, see `trigger_shutdown` or
205    /// `shutdown` if you want to trigger shutting down the server.
206    pub async fn wait_for_shutdown(&self) {
207        self.0.watch_sender.closed().await
208    }
209
210    /// Triggers a shutdown of the server and waits for it to complete shutting down.
211    pub async fn shutdown(&self) {
212        self.trigger_shutdown();
213        self.wait_for_shutdown().await;
214    }
215
216    /// Checks if the Server has been shutdown.
217    pub fn is_shutdown(&self) -> bool {
218        self.0.watch_sender.is_closed()
219    }
220
221    pub fn connections(
222        &self,
223    ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
224        self.0.connections.read().unwrap()
225    }
226
227    /// Returns the number of active connections the server is handling
228    pub fn number_of_connections(&self) -> usize {
229        self.connections().len()
230    }
231}
232
233type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
234
235struct Server<L: Listener> {
236    config: Config,
237    tls_config: Option<Arc<rustls::ServerConfig>>,
238
239    listener: L,
240    local_addr: L::Addr,
241    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,
242
243    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
244    connection_handlers: JoinSet<()>,
245    connections: ActiveConnections<L::Addr>,
246    graceful_shutdown_token: tokio_util::sync::CancellationToken,
247    // Used to signal to a ServerHandle when the server has completed shutting down
248    _watch_reciever: tokio::sync::watch::Receiver<()>,
249}
250
251impl<L> Server<L>
252where
253    L: Listener,
254{
255    async fn serve(mut self) -> Result<(), BoxError> {
256        loop {
257            tokio::select! {
258                _ = self.graceful_shutdown_token.cancelled() => {
259                    trace!("signal received, shutting down");
260                    break;
261                },
262                (io, remote_addr) = self.listener.accept() => {
263                    self.handle_incomming(io, remote_addr);
264                },
265                Some(maybe_connection) = self.pending_connections.join_next() => {
266                    // If a task panics, just propagate it
267                    let (io, remote_addr) = match maybe_connection.unwrap() {
268                        Ok((io, remote_addr)) => {
269                            (io, remote_addr)
270                        }
271                        Err(e) => {
272                            tracing::debug!(error = %e, "error accepting connection");
273                            continue;
274                        }
275                    };
276
277                    trace!("connection accepted");
278                    self.handle_connection(io, remote_addr);
279                },
280                Some(connection_handler_output) = self.connection_handlers.join_next() => {
281                    // If a task panics, just propagate it
282                    let _: () = connection_handler_output.unwrap();
283                },
284            }
285        }
286
287        // Shutting down, wait for all connection handlers to finish
288        self.shutdown().await;
289
290        Ok(())
291    }
292
293    fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) {
294        if let Some(tls) = self.tls_config.clone() {
295            let tls_acceptor = TlsAcceptor::from(tls);
296            let allow_insecure = self.config.allow_insecure;
297            self.pending_connections.spawn(async move {
298                if allow_insecure {
299                    // XXX: If we want to allow for supporting insecure traffic from other types of
300                    // io, we'll need to implement a generic peekable IO type
301                    if let Some(tcp) =
302                        <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
303                    {
304                        // Determine whether new connection is TLS.
305                        let mut buf = [0; 1];
306                        // `peek` blocks until at least some data is available, so if there is no error then
307                        // it must return the one byte we are requesting.
308                        tcp.peek(&mut buf).await?;
309                        // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its
310                        // insecure
311                        if buf != [0x16] {
312                            tracing::trace!("accepting insecure connection");
313                            return Ok((ServerIo::new_io(io), remote_addr));
314                        }
315                    } else {
316                        tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
317                    }
318                }
319
320                tracing::trace!("accepting TLS connection");
321                let io = tls_acceptor.accept(io).await?;
322                Ok((ServerIo::new_tls_io(io), remote_addr))
323            });
324        } else {
325            self.handle_connection(ServerIo::new_io(io), remote_addr);
326        }
327    }
328
329    fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
330        let connection_shutdown_token = self.graceful_shutdown_token.child_token();
331        let connection_info = ConnectionInfo::new(
332            remote_addr,
333            io.peer_certs(),
334            connection_shutdown_token.clone(),
335        );
336        let connection_id = connection_info.id();
337        let connect_info = connection_info::ConnectInfo {
338            local_addr: self.local_addr.clone(),
339            remote_addr: connection_info.remote_address().clone(),
340        };
341        let peer_certificates = connection_info.peer_certificates().cloned();
342        let hyper_io = hyper_util::rt::TokioIo::new(io);
343
344        let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
345            move |mut request: Request<hyper::body::Incoming>| {
346                request.extensions_mut().insert(connect_info.clone());
347                if let Some(peer_certificates) = peer_certificates.clone() {
348                    request.extensions_mut().insert(peer_certificates);
349                }
350
351                request.map(body::boxed)
352            },
353        ));
354
355        self.connections
356            .write()
357            .unwrap()
358            .insert(connection_id, connection_info);
359        let on_connection_close =
360            connection_handler::OnConnectionClose::new(connection_id, self.connections.clone());
361
362        self.connection_handlers
363            .spawn(connection_handler::serve_connection(
364                hyper_io,
365                hyper_svc,
366                self.config.connection_builder(),
367                connection_shutdown_token,
368                self.config.max_connection_age,
369                on_connection_close,
370            ));
371    }
372
373    async fn shutdown(mut self) {
374        // The time we are willing to wait for a connection to get gracefully shutdown before we
375        // attempt to forcefully shutdown all active connections
376        const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
377
378        // Just to be careful make sure the token is canceled
379        self.graceful_shutdown_token.cancel();
380
381        // Terminate any in-progress pending connections
382        self.pending_connections.shutdown().await;
383
384        // Wait for all connection handlers to terminate
385        trace!(
386            "waiting for {} connections to close",
387            self.connection_handlers.len()
388        );
389
390        let graceful_shutdown =
391            async { while self.connection_handlers.join_next().await.is_some() {} };
392
393        if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
394            .await
395            .is_err()
396        {
397            tracing::warn!(
398                "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
399                CONNECTION_SHUTDOWN_GRACE_PERIOD
400            );
401            self.connection_handlers.shutdown().await;
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;

    const MESSAGE: &str = "Hello, World!";

    /// Minimal app that answers `GET /` with [`MESSAGE`].
    fn app() -> Router {
        Router::new().route("/", axum::routing::get(|| async { MESSAGE }))
    }

    #[tokio::test]
    async fn simple() {
        let handle = Builder::new().serve(("localhost", 0), app()).unwrap();

        let url = format!("http://{}", handle.local_addr());

        let response = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        assert_eq!(response, MESSAGE.as_bytes());

        // Stop the background server task instead of leaking it past the end of the test.
        handle.shutdown().await;
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = Builder::new().serve(("localhost", 0), app()).unwrap();

        let url = format!("http://{}", handle.local_addr());

        // Keep the client, and with it the pooled keep-alive connection, alive for the rest of
        // the test. `reqwest::get` builds a throwaway client per call, so its connection may be
        // closed as soon as the body is read. That would race with the connection-count
        // assertion below.
        let client = reqwest::Client::new();
        let response = client.get(url).send().await.unwrap().bytes().await.unwrap();

        // a request was just made so we should have 1 active connection
        assert_eq!(handle.connections().len(), 1);

        assert_eq!(response, MESSAGE.as_bytes());

        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());

        // Now that the server has been shut down there should be zero connections
        assert_eq!(handle.connections().len(), 0);
    }
}
452}