use arc_swap::ArcSwap;
use fastcrypto::ed25519::Ed25519PublicKey;
use fastcrypto::traits::ToFromBytes;
use rustls::crypto::WebPkiSupportedAlgorithms;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls::pki_types::ServerName;
use rustls::pki_types::SignatureVerificationAlgorithm;
use rustls::pki_types::TrustAnchor;
use rustls::pki_types::UnixTime;
use std::collections::BTreeSet;
use std::sync::Arc;
/// Signature verification algorithms accepted anywhere in this module: Ed25519 only.
static SUPPORTED_SIG_ALGS: &[&dyn SignatureVerificationAlgorithm] = &[webpki::ring::ED25519];

/// The same Ed25519-only algorithm set, in the shape rustls' handshake-signature helpers
/// (`verify_tls12_signature` / `verify_tls13_signature`) expect: the TLS
/// `SignatureScheme::ED25519` is the only scheme, and it maps to [`SUPPORTED_SIG_ALGS`].
static SUPPORTED_ALGORITHMS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms {
    all: SUPPORTED_SIG_ALGS,
    mapping: &[(
        rustls::SignatureScheme::ED25519,
        SUPPORTED_SIG_ALGS,
    )],
};
/// Policy hook that decides whether a peer's Ed25519 public key may complete a handshake.
///
/// `ClientCertVerifier` consults it after extracting the key from the client certificate.
/// Implementations must be shareable across threads (`Send + Sync`) and printable (`Debug`).
pub trait Allower: Send + Sync + std::fmt::Debug {
    /// Returns `true` when `key` is permitted, `false` to reject the peer.
    fn allowed(&self, key: &Ed25519PublicKey) -> bool;
}
/// An [`Allower`] with no restrictions: every public key is accepted.
#[derive(Debug, Clone, Default)]
pub struct AllowAll;

impl Allower for AllowAll {
    fn allowed(&self, _key: &Ed25519PublicKey) -> bool {
        // Unconditional acceptance; the key is intentionally ignored.
        true
    }
}
/// An [`Allower`] that accepts only keys contained in a replaceable set.
///
/// The set sits behind `Arc<ArcSwap<..>>`, so every clone of an `AllowPublicKeys`
/// shares it. A call to [`AllowPublicKeys::update`] on any clone is observed by all of them.
#[derive(Debug, Clone, Default)]
pub struct AllowPublicKeys {
    inner: Arc<ArcSwap<BTreeSet<Ed25519PublicKey>>>,
}

impl AllowPublicKeys {
    /// Creates an allow-list initialised with `allowed`.
    pub fn new(allowed: BTreeSet<Ed25519PublicKey>) -> Self {
        let swappable = ArcSwap::new(Arc::new(allowed));
        Self {
            inner: Arc::new(swappable),
        }
    }

    /// Atomically replaces the whole allow-list with `new_allowed`.
    pub fn update(&self, new_allowed: BTreeSet<Ed25519PublicKey>) {
        let replacement = Arc::new(new_allowed);
        self.inner.store(replacement);
    }
}

impl Allower for AllowPublicKeys {
    fn allowed(&self, key: &Ed25519PublicKey) -> bool {
        let snapshot = self.inner.load();
        snapshot.contains(key)
    }
}
/// Server-side verifier for client certificates.
///
/// A client passes when its certificate's Ed25519 key is accepted by `allower` and the
/// certificate is a valid self-signed client-auth certificate for `name`.
#[derive(Clone, Debug)]
pub struct ClientCertVerifier<A> {
    allower: A,
    name: String,
}

impl<A> ClientCertVerifier<A> {
    /// Wraps an allow-list policy and the subject name client certificates must carry.
    pub fn new(allower: A, name: String) -> Self {
        ClientCertVerifier { allower, name }
    }
}
impl<A: Allower + 'static> ClientCertVerifier<A> {
    /// Consumes the verifier and builds a rustls `ServerConfig`.
    ///
    /// The config uses the `ring` crypto provider and allows TLS 1.3 only. It installs
    /// `self` as the client-certificate verifier and presents `certificates` /
    /// `private_key` as the server identity. It advertises ALPN `h2` then `http/1.1`.
    ///
    /// # Errors
    /// Propagates any `rustls::Error` from protocol-version selection or from
    /// loading the server certificate and key.
    pub fn rustls_server_config(
        self,
        certificates: Vec<CertificateDer<'static>>,
        private_key: PrivateKeyDer<'static>,
    ) -> Result<rustls::ServerConfig, rustls::Error> {
        let provider = Arc::new(rustls::crypto::ring::default_provider());
        let verifier = Arc::new(self);

        let mut config = rustls::ServerConfig::builder_with_provider(provider)
            .with_protocol_versions(&[&rustls::version::TLS13])?
            .with_client_cert_verifier(verifier)
            .with_single_cert(certificates, private_key)?;

        config.alpn_protocols = ["h2", "http/1.1"]
            .iter()
            .map(|proto| proto.as_bytes().to_vec())
            .collect();

        Ok(config)
    }
}
impl<A: Allower> rustls::server::danger::ClientCertVerifier for ClientCertVerifier<A> {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> bool {
true
}
fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
now: UnixTime,
) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
let public_key = public_key_from_certificate(end_entity)?;
if !self.allower.allowed(&public_key) {
return Err(rustls::Error::General(format!(
"invalid certificate: {:?} is not in the validator set",
public_key,
)));
}
verify_self_signed_cert(
end_entity,
intermediates,
webpki::KeyUsage::client_auth(),
&self.name,
now,
)
.map(|_| rustls::server::danger::ClientCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(message, cert, dss, &SUPPORTED_ALGORITHMS)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(message, cert, dss, &SUPPORTED_ALGORITHMS)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
SUPPORTED_ALGORITHMS.supported_schemes()
}
}
#[derive(Clone, Debug)]
pub struct ServerCertVerifier {
public_key: Ed25519PublicKey,
name: String,
}
impl ServerCertVerifier {
pub fn new(public_key: Ed25519PublicKey, name: String) -> Self {
Self { public_key, name }
}
pub fn rustls_client_config_with_client_auth(
self,
certificates: Vec<CertificateDer<'static>>,
private_key: PrivateKeyDer<'static>,
) -> Result<rustls::ClientConfig, rustls::Error> {
rustls::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_protocol_versions(&[&rustls::version::TLS13])?
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(self))
.with_client_auth_cert(certificates, private_key)
}
pub fn rustls_client_config_with_no_client_auth(
self,
) -> Result<rustls::ClientConfig, rustls::Error> {
Ok(rustls::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_protocol_versions(&[&rustls::version::TLS13])?
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(self))
.with_no_client_auth())
}
}
impl rustls::client::danger::ServerCertVerifier for ServerCertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
_server_name: &ServerName,
_ocsp_response: &[u8],
now: UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let public_key = public_key_from_certificate(end_entity)?;
if public_key != self.public_key {
return Err(rustls::Error::General(format!(
"invalid certificate: {:?} is not the expected server public key",
public_key,
)));
}
verify_self_signed_cert(
end_entity,
intermediates,
webpki::KeyUsage::server_auth(),
&self.name,
now,
)
.map(|_| rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(message, cert, dss, &SUPPORTED_ALGORITHMS)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(message, cert, dss, &SUPPORTED_ALGORITHMS)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
SUPPORTED_ALGORITHMS.supported_schemes()
}
}
fn verify_self_signed_cert(
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
usage: webpki::KeyUsage,
name: &str,
now: UnixTime,
) -> Result<(), rustls::Error> {
let (cert, chain, trustroots) = prepare_for_self_signed(end_entity, intermediates)?;
let verified_cert = cert
.verify_for_usage(
SUPPORTED_SIG_ALGS,
&trustroots,
chain,
now,
usage,
None,
None,
)
.map_err(pki_error)?;
let subject_name =
ServerName::try_from(name).map_err(|_| rustls::Error::UnsupportedNameType)?;
verified_cert
.end_entity()
.verify_is_valid_for_subject_name(&subject_name)
.map_err(pki_error)
}
/// A parsed end-entity certificate, the peer-supplied intermediates, and the trust
/// anchors to verify against.
type CertChainAndRoots<'a> = (
    webpki::EndEntityCert<'a>,
    &'a [CertificateDer<'a>],
    Vec<TrustAnchor<'a>>,
);

/// Parses `end_entity` and derives a trust anchor from that same certificate.
///
/// The result allows the certificate to be verified as its own root.
/// `intermediates` are passed through unchanged.
///
/// # Errors
/// Returns a `pki_error`-translated error if parsing the certificate fails, or if
/// deriving the anchor fails. Parsing is attempted first.
fn prepare_for_self_signed<'a>(
    end_entity: &'a CertificateDer,
    intermediates: &'a [CertificateDer],
) -> Result<CertChainAndRoots<'a>, rustls::Error> {
    let parsed = webpki::EndEntityCert::try_from(end_entity).map_err(pki_error)?;
    let self_anchor = webpki::anchor_from_trusted_cert(end_entity).map_err(pki_error)?;
    let anchors = vec![self_anchor];
    Ok((parsed, intermediates, anchors))
}
fn pki_error(error: webpki::Error) -> rustls::Error {
use webpki::Error::*;
match error {
BadDer | BadDerTime => {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding)
}
InvalidSignatureForPublicKey
| UnsupportedSignatureAlgorithm
| UnsupportedSignatureAlgorithmForPublicKey => {
rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature)
}
CertNotValidForName(_) => {
rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)
}
e => rustls::Error::General(format!("invalid peer certificate: {e}")),
}
}
pub fn public_key_from_certificate(
certificate: &CertificateDer,
) -> Result<Ed25519PublicKey, rustls::Error> {
use x509_parser::{certificate::X509Certificate, prelude::FromDer};
let cert = X509Certificate::from_der(certificate.as_ref())
.map_err(|e| rustls::Error::General(e.to_string()))?;
let spki = cert.1.public_key();
let public_key_bytes =
<ed25519::pkcs8::PublicKeyBytes as pkcs8::DecodePublicKey>::from_public_key_der(spki.raw)
.map_err(|e| rustls::Error::General(format!("invalid ed25519 public key: {e}")))?;
let public_key = Ed25519PublicKey::from_bytes(public_key_bytes.as_ref())
.map_err(|e| rustls::Error::General(format!("invalid ed25519 public key: {e}")))?;
Ok(public_key)
}