1use axum::{middleware::AddExtension, Extension};
5use axum_server::{
6 accept::Accept,
7 tls_rustls::{RustlsAcceptor, RustlsConfig},
8};
9use fastcrypto::ed25519::Ed25519PublicKey;
10use rustls::pki_types::CertificateDer;
11use std::{io, sync::Arc};
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio_rustls::server::TlsStream;
14use tower_layer::Layer;
15
16#[derive(Debug, Clone)]
17pub struct TlsConnectionInfo {
18 sni_hostname: Option<Arc<str>>,
19 peer_certificates: Option<Arc<[CertificateDer<'static>]>>,
20 public_key: Option<Ed25519PublicKey>,
21}
22
23impl TlsConnectionInfo {
24 pub fn sni_hostname(&self) -> Option<&str> {
25 self.sni_hostname.as_deref()
26 }
27
28 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
29 self.peer_certificates.as_deref()
30 }
31
32 pub fn public_key(&self) -> Option<&Ed25519PublicKey> {
33 self.public_key.as_ref()
34 }
35}
36
37#[derive(Debug, Clone)]
39pub struct TlsAcceptor {
40 inner: RustlsAcceptor,
41}
42
43impl TlsAcceptor {
44 pub fn new(config: rustls::ServerConfig) -> Self {
45 Self {
46 inner: RustlsAcceptor::new(RustlsConfig::from_config(Arc::new(config))),
47 }
48 }
49}
50
/// Pinned, heap-allocated, type-erased future resolving to `T`; it is `Send`
/// and may borrow data for `'a`.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a + Send>>;
52
53impl<I, S> Accept<I, S> for TlsAcceptor
54where
55 I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
56 S: Send + 'static,
57{
58 type Stream = TlsStream<I>;
59 type Service = AddExtension<S, TlsConnectionInfo>;
60 type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
61
62 fn accept(&self, stream: I, service: S) -> Self::Future {
63 let acceptor = self.inner.clone();
64
65 Box::pin(async move {
66 let (stream, service) = acceptor.accept(stream, service).await?;
67 let server_conn = stream.get_ref().1;
68
69 let public_key = if let Some([peer_certificate, ..]) = server_conn.peer_certificates() {
70 crate::certgen::public_key_from_certificate(peer_certificate).ok()
71 } else {
72 None
73 };
74
75 let tls_connect_info = TlsConnectionInfo {
76 peer_certificates: server_conn.peer_certificates().map(From::from),
77 sni_hostname: server_conn.server_name().map(From::from),
78 public_key,
79 };
80 let service = Extension(tls_connect_info).layer(service);
81
82 Ok((stream, service))
83 })
84 }
85}