sui_tls/
acceptor.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use axum::{middleware::AddExtension, Extension};
5use axum_server::{
6    accept::Accept,
7    tls_rustls::{RustlsAcceptor, RustlsConfig},
8};
9use fastcrypto::ed25519::Ed25519PublicKey;
10use rustls::pki_types::CertificateDer;
11use std::{io, sync::Arc};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::server::TlsStream;
14use tower_layer::Layer;
15
16#[derive(Debug, Clone)]
17pub struct TlsConnectionInfo {
18    sni_hostname: Option<Arc<str>>,
19    peer_certificates: Option<Arc<[CertificateDer<'static>]>>,
20    public_key: Option<Ed25519PublicKey>,
21}
22
23impl TlsConnectionInfo {
24    pub fn sni_hostname(&self) -> Option<&str> {
25        self.sni_hostname.as_deref()
26    }
27
28    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
29        self.peer_certificates.as_deref()
30    }
31
32    pub fn public_key(&self) -> Option<&Ed25519PublicKey> {
33        self.public_key.as_ref()
34    }
35}
36
37/// An `Acceptor` that will provide `TlsConnectionInfo` as an axum `Extension` for use in handlers.
38#[derive(Debug, Clone)]
39pub struct TlsAcceptor {
40    inner: RustlsAcceptor,
41}
42
43impl TlsAcceptor {
44    pub fn new(config: rustls::ServerConfig) -> Self {
45        Self {
46            inner: RustlsAcceptor::new(RustlsConfig::from_config(Arc::new(config))),
47        }
48    }
49}
50
/// A pinned, heap-allocated, `Send` future that yields `Out` and is valid for `'fut`.
type BoxFuture<'fut, Out> =
    std::pin::Pin<Box<dyn std::future::Future<Output = Out> + Send + 'fut>>;
52
53impl<I, S> Accept<I, S> for TlsAcceptor
54where
55    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
56    S: Send + 'static,
57{
58    type Stream = TlsStream<I>;
59    type Service = AddExtension<S, TlsConnectionInfo>;
60    type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
61
62    fn accept(&self, stream: I, service: S) -> Self::Future {
63        let acceptor = self.inner.clone();
64
65        Box::pin(async move {
66            let (stream, service) = acceptor.accept(stream, service).await?;
67            let server_conn = stream.get_ref().1;
68
69            let public_key = if let Some([peer_certificate, ..]) = server_conn.peer_certificates() {
70                crate::certgen::public_key_from_certificate(peer_certificate).ok()
71            } else {
72                None
73            };
74
75            let tls_connect_info = TlsConnectionInfo {
76                peer_certificates: server_conn.peer_certificates().map(From::from),
77                sni_hostname: server_conn.server_name().map(From::from),
78                public_key,
79            };
80            let service = Extension(tls_connect_info).layer(service);
81
82            Ok((stream, service))
83        })
84    }
85}