sui_http/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use connection_handler::OnConnectionClose;
5use http::{Request, Response};
6use hyper_util::service::TowerToHyperService;
7use io::ServerIo;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::task::JoinSet;
12use tokio_rustls::TlsAcceptor;
13use tokio_rustls::rustls;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::body::BoxBody;
18use self::connection_info::ActiveConnections;
19
20pub use http;
21
22pub mod body;
23mod config;
24mod connection_handler;
25mod connection_info;
26mod fuse;
27mod io;
28mod listener;
29
30pub use config::Config;
31pub use listener::Listener;
32pub use listener::ListenerExt;
33
34pub use connection_info::ConnectInfo;
35pub use connection_info::ConnectionId;
36pub use connection_info::ConnectionInfo;
37pub use connection_info::PeerCertificates;
38
/// Boxed, thread-safe error type used throughout this crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// ALPN protocol identifier for HTTP/2, as the raw bytes rustls expects.
const ALPN_H2: &[u8] = "h2".as_bytes();
/// ALPN protocol identifier for HTTP/1.1, as the raw bytes rustls expects.
const ALPN_H1: &[u8] = "http/1.1".as_bytes();
44
45#[derive(Default)]
46pub struct Builder {
47    config: Config,
48    tls_config: Option<rustls::ServerConfig>,
49}
50
51impl Builder {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn config(mut self, config: Config) -> Self {
57        self.config = config;
58        self
59    }
60
61    // Convenience method for configuring TLS with a single server cert
62    //
63    // Attempts to load PEM formatted files for the certificate chain and private key material from
64    // the provided file system paths.
65    pub fn tls_single_cert(
66        self,
67        cert_file: impl AsRef<std::path::Path>,
68        private_key_file: impl AsRef<std::path::Path>,
69    ) -> Result<Self, BoxError> {
70        use rustls::pki_types::pem::PemObject;
71        use rustls::pki_types::{CertificateDer, PrivateKeyDer};
72
73        let certs = CertificateDer::pem_file_iter(cert_file)?.collect::<Result<_, _>>()?;
74        let private_key = PrivateKeyDer::from_pem_file(private_key_file)?;
75        let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
76            rustls::crypto::ring::default_provider(),
77        ))
78        .with_protocol_versions(rustls::DEFAULT_VERSIONS)?
79        .with_no_client_auth()
80        .with_single_cert(certs, private_key)?;
81
82        Ok(self.tls_config(tls_config))
83    }
84
85    pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
86        self.tls_config = Some(tls_config);
87        self
88    }
89
90    pub fn serve<A, S, ResponseBody>(
91        self,
92        addr: A,
93        service: S,
94    ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
95    where
96        A: std::net::ToSocketAddrs,
97        S: Service<
98                Request<BoxBody>,
99                Response = Response<ResponseBody>,
100                Error: Into<BoxError>,
101                Future: Send,
102            > + Clone
103            + Send
104            + 'static,
105        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
106    {
107        let listener = listener::TcpListenerWithOptions::new(
108            addr,
109            /* nodelay */ true,
110            self.config.tcp_keepalive,
111        )?;
112
113        Self::serve_with_listener(self, listener, service)
114    }
115
116    fn serve_with_listener<L, S, ResponseBody>(
117        self,
118        listener: L,
119        service: S,
120    ) -> Result<ServerHandle<L::Addr>, BoxError>
121    where
122        L: Listener,
123        S: Service<
124                Request<BoxBody>,
125                Response = Response<ResponseBody>,
126                Error: Into<BoxError>,
127                Future: Send,
128            > + Clone
129            + Send
130            + 'static,
131        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
132    {
133        let local_addr = listener.local_addr()?;
134        let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
135        let connections = ActiveConnections::default();
136
137        let tls_config = self.tls_config.map(|mut tls| {
138            tls.alpn_protocols.push(ALPN_H2.into());
139            if self.config.accept_http1 {
140                tls.alpn_protocols.push(ALPN_H1.into());
141            }
142            Arc::new(tls)
143        });
144
145        let (watch_sender, watch_reciever) = tokio::sync::watch::channel(());
146        let server = Server {
147            config: self.config,
148            tls_config,
149            listener,
150            local_addr: local_addr.clone(),
151            service: ServiceBuilder::new()
152                .layer(tower::util::BoxCloneService::layer())
153                .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
154                .map_err(Into::into)
155                .service(service),
156            pending_connections: JoinSet::new(),
157            connection_handlers: JoinSet::new(),
158            connections: connections.clone(),
159            graceful_shutdown_token: graceful_shutdown_token.clone(),
160            _watch_reciever: watch_reciever,
161        };
162
163        let handle = ServerHandle(Arc::new(HandleInner {
164            local_addr,
165            connections,
166            graceful_shutdown_token,
167            watch_sender,
168        }));
169
170        tokio::spawn(server.serve());
171
172        Ok(handle)
173    }
174}
175
176#[derive(Debug)]
177pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
178
179#[derive(Debug)]
180struct HandleInner<A = std::net::SocketAddr> {
181    /// The local address of the server.
182    local_addr: A,
183    connections: ActiveConnections<A>,
184    graceful_shutdown_token: tokio_util::sync::CancellationToken,
185    watch_sender: tokio::sync::watch::Sender<()>,
186}
187
188impl<A> ServerHandle<A> {
189    /// Returns the local address of the server
190    pub fn local_addr(&self) -> &A {
191        &self.0.local_addr
192    }
193
194    /// Trigger a graceful shutdown of the server, but don't wait till the server has completed
195    /// shutting down
196    pub fn trigger_shutdown(&self) {
197        self.0.graceful_shutdown_token.cancel();
198    }
199
200    /// Completes once the network has been shutdown.
201    ///
202    /// This explicitly *does not* trigger the network to shutdown, see `trigger_shutdown` or
203    /// `shutdown` if you want to trigger shutting down the server.
204    pub async fn wait_for_shutdown(&self) {
205        self.0.watch_sender.closed().await
206    }
207
208    /// Triggers a shutdown of the server and waits for it to complete shutting down.
209    pub async fn shutdown(&self) {
210        self.trigger_shutdown();
211        self.wait_for_shutdown().await;
212    }
213
214    /// Checks if the Server has been shutdown.
215    pub fn is_shutdown(&self) -> bool {
216        self.0.watch_sender.is_closed()
217    }
218
219    pub fn connections(
220        &self,
221    ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
222        self.0.connections.read().unwrap()
223    }
224
225    /// Returns the number of active connections the server is handling
226    pub fn number_of_connections(&self) -> usize {
227        self.connections().len()
228    }
229}
230
231type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
232
233struct Server<L: Listener> {
234    config: Config,
235    tls_config: Option<Arc<rustls::ServerConfig>>,
236
237    listener: L,
238    local_addr: L::Addr,
239    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,
240
241    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
242    connection_handlers: JoinSet<()>,
243    connections: ActiveConnections<L::Addr>,
244    graceful_shutdown_token: tokio_util::sync::CancellationToken,
245    // Used to signal to a ServerHandle when the server has completed shutting down
246    _watch_reciever: tokio::sync::watch::Receiver<()>,
247}
248
249impl<L> Server<L>
250where
251    L: Listener,
252{
253    async fn serve(mut self) -> Result<(), BoxError> {
254        loop {
255            tokio::select! {
256                _ = self.graceful_shutdown_token.cancelled() => {
257                    trace!("signal received, shutting down");
258                    break;
259                },
260                (io, remote_addr) = self.listener.accept() => {
261                    self.handle_incomming(io, remote_addr);
262                },
263                Some(maybe_connection) = self.pending_connections.join_next() => {
264                    // If a task panics, just propagate it
265                    let (io, remote_addr) = match maybe_connection.unwrap() {
266                        Ok((io, remote_addr)) => {
267                            (io, remote_addr)
268                        }
269                        Err(e) => {
270                            tracing::debug!(error = %e, "error accepting connection");
271                            continue;
272                        }
273                    };
274
275                    trace!("connection accepted");
276                    self.handle_connection(io, remote_addr);
277                },
278                Some(connection_handler_output) = self.connection_handlers.join_next() => {
279                    // If a task panics, just propagate it
280                    let _: () = connection_handler_output.unwrap();
281                },
282            }
283        }
284
285        // Shutting down, wait for all connection handlers to finish
286        self.shutdown().await;
287
288        Ok(())
289    }
290
291    fn handle_incomming(&mut self, io: L::Io, remote_addr: L::Addr) {
292        if let Some(tls) = self.tls_config.clone() {
293            let tls_acceptor = TlsAcceptor::from(tls);
294            let allow_insecure = self.config.allow_insecure;
295            self.pending_connections.spawn(async move {
296                if allow_insecure {
297                    // XXX: If we want to allow for supporting insecure traffic from other types of
298                    // io, we'll need to implement a generic peekable IO type
299                    if let Some(tcp) =
300                        <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
301                    {
302                        // Determine whether new connection is TLS.
303                        let mut buf = [0; 1];
304                        // `peek` blocks until at least some data is available, so if there is no error then
305                        // it must return the one byte we are requesting.
306                        tcp.peek(&mut buf).await?;
307                        // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its
308                        // insecure
309                        if buf != [0x16] {
310                            tracing::trace!("accepting insecure connection");
311                            return Ok((ServerIo::new_io(io), remote_addr));
312                        }
313                    } else {
314                        tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
315                    }
316                }
317
318                tracing::trace!("accepting TLS connection");
319                let io = tls_acceptor.accept(io).await?;
320                Ok((ServerIo::new_tls_io(io), remote_addr))
321            });
322        } else {
323            self.handle_connection(ServerIo::new_io(io), remote_addr);
324        }
325    }
326
327    fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
328        let connection_shutdown_token = self.graceful_shutdown_token.child_token();
329        let connection_info = ConnectionInfo::new(
330            remote_addr,
331            io.peer_certs(),
332            connection_shutdown_token.clone(),
333        );
334        let connection_id = connection_info.id();
335        let connect_info = connection_info::ConnectInfo {
336            local_addr: self.local_addr.clone(),
337            remote_addr: connection_info.remote_address().clone(),
338        };
339        let peer_certificates = connection_info.peer_certificates().cloned();
340        let hyper_io = hyper_util::rt::TokioIo::new(io);
341
342        let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
343            move |mut request: Request<hyper::body::Incoming>| {
344                request.extensions_mut().insert(connect_info.clone());
345                if let Some(peer_certificates) = peer_certificates.clone() {
346                    request.extensions_mut().insert(peer_certificates);
347                }
348
349                request.map(body::boxed)
350            },
351        ));
352
353        self.connections
354            .write()
355            .unwrap()
356            .insert(connection_id, connection_info);
357        let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
358
359        self.connection_handlers
360            .spawn(connection_handler::serve_connection(
361                hyper_io,
362                hyper_svc,
363                self.config.connection_builder(),
364                connection_shutdown_token,
365                self.config.max_connection_age,
366                on_connection_close,
367            ));
368    }
369
370    async fn shutdown(mut self) {
371        // The time we are willing to wait for a connection to get gracefully shutdown before we
372        // attempt to forcefully shutdown all active connections
373        const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
374
375        // Just to be careful make sure the token is canceled
376        self.graceful_shutdown_token.cancel();
377
378        // Terminate any in-progress pending connections
379        self.pending_connections.shutdown().await;
380
381        // Wait for all connection handlers to terminate
382        trace!(
383            "waiting for {} connections to close",
384            self.connection_handlers.len()
385        );
386
387        let graceful_shutdown =
388            async { while self.connection_handlers.join_next().await.is_some() {} };
389
390        if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
391            .await
392            .is_err()
393        {
394            tracing::warn!(
395                "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
396                CONNECTION_SHUTDOWN_GRACE_PERIOD
397            );
398            self.connection_handlers.shutdown().await;
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;

    const MESSAGE: &str = "Hello, World!";

    /// Starts a plaintext server on an ephemeral localhost port that answers `GET /` with
    /// `MESSAGE`.
    fn start_hello_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    #[tokio::test]
    async fn simple() {
        let handle = start_hello_server();

        let body = reqwest::get(format!("http://{}", handle.local_addr()))
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap();

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = start_hello_server();

        let body = reqwest::get(format!("http://{}", handle.local_addr()))
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap();

        // The request above should leave exactly one connection registered with the server.
        assert_eq!(handle.number_of_connections(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        // Once shutdown has completed, no connections should remain.
        assert_eq!(handle.number_of_connections(), 0);
    }
}