sui_tls/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4mod acceptor;
5mod certgen;
6mod verifier;
7
8use std::sync::Arc;
9
10pub use acceptor::{TlsAcceptor, TlsConnectionInfo};
11pub use certgen::SelfSignedCertificate;
12use rustls::ClientConfig;
13pub use verifier::{
14    AllowAll, AllowPublicKeys, Allower, ClientCertVerifier, ServerCertVerifier,
15    public_key_from_certificate,
16};
17
18pub use rustls;
19
20use fastcrypto::ed25519::{Ed25519PrivateKey, Ed25519PublicKey};
21use tokio_rustls::rustls::ServerConfig;
22
/// Server name expected in the self-signed TLS certificates of Sui validators.
///
/// The certificate verifiers are constructed with an expected server name and reject
/// certificates generated for a different name (`CertificateError::NotValidForName`).
pub const SUI_VALIDATOR_SERVER_NAME: &str = "sui";
24
25pub fn create_rustls_server_config(
26    private_key: Ed25519PrivateKey,
27    server_name: String,
28) -> ServerConfig {
29    // TODO: refactor to use key bytes
30    let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
31    let tls_cert = self_signed_cert.rustls_certificate();
32    let tls_private_key = self_signed_cert.rustls_private_key();
33    let mut tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(
34        rustls::crypto::ring::default_provider(),
35    ))
36    .with_protocol_versions(&[&rustls::version::TLS13])
37    .unwrap_or_else(|e| panic!("Failed to create TLS server config: {:?}", e))
38    .with_no_client_auth()
39    .with_single_cert(vec![tls_cert], tls_private_key)
40    .unwrap_or_else(|e| panic!("Failed to create TLS server config: {:?}", e));
41    tls_config.alpn_protocols = vec![b"h2".to_vec()];
42    tls_config
43}
44
45/// Create a TLS server config which requires mTLS, eg the client to also provide a cert and be
46/// verified by the server based on the provided policy
47pub fn create_rustls_server_config_with_client_verifier<A: Allower + 'static>(
48    private_key: Ed25519PrivateKey,
49    server_name: String,
50    allower: A,
51) -> ServerConfig {
52    let verifier = ClientCertVerifier::new(allower, server_name.clone());
53    // TODO: refactor to use key bytes
54    let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
55    let tls_cert = self_signed_cert.rustls_certificate();
56    let tls_private_key = self_signed_cert.rustls_private_key();
57    let mut tls_config = verifier
58        .rustls_server_config(vec![tls_cert], tls_private_key)
59        .unwrap_or_else(|e| panic!("Failed to create TLS server config: {:?}", e));
60    tls_config.alpn_protocols = vec![b"h2".to_vec()];
61    tls_config
62}
63
64pub fn create_rustls_client_config(
65    target_public_key: Ed25519PublicKey,
66    server_name: String,
67    client_key: Option<Ed25519PrivateKey>, // optional self-signed cert for client verification
68) -> ClientConfig {
69    let tls_config = ServerCertVerifier::new(target_public_key, server_name.clone());
70    if let Some(private_key) = client_key {
71        let self_signed_cert = SelfSignedCertificate::new(private_key, server_name.as_str());
72        let tls_cert = self_signed_cert.rustls_certificate();
73        let tls_private_key = self_signed_cert.rustls_private_key();
74        tls_config.rustls_client_config_with_client_auth(vec![tls_cert], tls_private_key)
75    } else {
76        tls_config.rustls_client_config_with_no_client_auth()
77    }
78    .unwrap_or_else(|e| panic!("Failed to create TLS client config: {e:?}"))
79}
80
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use super::*;
    use fastcrypto::ed25519::Ed25519KeyPair;
    use fastcrypto::traits::KeyPair;
    use rustls::client::danger::ServerCertVerifier as _;
    use rustls::pki_types::ServerName;
    use rustls::pki_types::UnixTime;
    use rustls::server::danger::ClientCertVerifier as _;

    /// `AllowAll` accepts a client certificate from any key.
    #[test]
    fn verify_allowall() {
        let mut rng = rand::thread_rng();
        let bob = Ed25519KeyPair::generate(&mut rng);
        let alice = Ed25519KeyPair::generate(&mut rng);
        let bob_cert = SelfSignedCertificate::new(bob.private(), SUI_VALIDATOR_SERVER_NAME);
        let alice_cert = SelfSignedCertificate::new(alice.private(), SUI_VALIDATOR_SERVER_NAME);

        let verifier = ClientCertVerifier::new(AllowAll, SUI_VALIDATOR_SERVER_NAME.to_string());

        // Bob's certificate passes validation
        verifier
            .verify_client_cert(&bob_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap();

        // Alice's certificate passes validation too: `AllowAll` has no allowlist
        verifier
            .verify_client_cert(&alice_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap();
    }

    /// `ServerCertVerifier` only accepts a server certificate for the expected public key.
    #[test]
    fn verify_server_cert() {
        let mut rng = rand::thread_rng();
        let allowed = Ed25519KeyPair::generate(&mut rng);
        let disallowed = Ed25519KeyPair::generate(&mut rng);
        let allowed_public_key = allowed.public().to_owned();
        let allowed_cert = SelfSignedCertificate::new(allowed.private(), SUI_VALIDATOR_SERVER_NAME);
        let disallowed_cert =
            SelfSignedCertificate::new(disallowed.private(), SUI_VALIDATOR_SERVER_NAME);

        let verifier =
            ServerCertVerifier::new(allowed_public_key, SUI_VALIDATOR_SERVER_NAME.to_string());

        // The certificate for the expected key passes validation. Note that the `ServerName`
        // passed here is not the certificate's name; the verifier checks the name it was
        // constructed with.
        verifier
            .verify_server_cert(
                &allowed_cert.rustls_certificate(),
                &[],
                &ServerName::try_from("example.com").unwrap(),
                &[],
                UnixTime::now(),
            )
            .unwrap();

        // A certificate for any other key fails validation
        let err = verifier
            .verify_server_cert(
                &disallowed_cert.rustls_certificate(),
                &[],
                &ServerName::try_from("example.com").unwrap(),
                &[],
                UnixTime::now(),
            )
            .unwrap_err();
        assert!(
            matches!(err, rustls::Error::General(_)),
            "Actual error: {err:?}"
        );
    }

    /// `AllowPublicKeys` only accepts keys in its (updatable) allowlist.
    #[test]
    fn verify_hashset() {
        let mut rng = rand::thread_rng();
        let allowed = Ed25519KeyPair::generate(&mut rng);
        let disallowed = Ed25519KeyPair::generate(&mut rng);

        let allowed_public_keys = BTreeSet::from([allowed.public().to_owned()]);
        let allowed_cert = SelfSignedCertificate::new(allowed.private(), SUI_VALIDATOR_SERVER_NAME);

        let disallowed_cert =
            SelfSignedCertificate::new(disallowed.private(), SUI_VALIDATOR_SERVER_NAME);

        let allowlist = AllowPublicKeys::new(allowed_public_keys);
        // Keep a handle to the allowlist so it can be updated after the verifier is built
        let verifier =
            ClientCertVerifier::new(allowlist.clone(), SUI_VALIDATOR_SERVER_NAME.to_string());

        // The allowed cert passes validation
        verifier
            .verify_client_cert(&allowed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap();

        // The disallowed cert fails validation
        let err = verifier
            .verify_client_cert(&disallowed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert!(
            matches!(err, rustls::Error::General(_)),
            "Actual error: {err:?}"
        );

        // After removing the allowed public key from the set it now fails validation
        allowlist.update(BTreeSet::new());
        let err = verifier
            .verify_client_cert(&allowed_cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert!(
            matches!(err, rustls::Error::General(_)),
            "Actual error: {err:?}"
        );
    }

    /// Both verifiers reject a certificate whose server name is not the expected one, even
    /// when its public key is allowed.
    #[test]
    fn invalid_server_name() {
        let mut rng = rand::thread_rng();
        let keypair = Ed25519KeyPair::generate(&mut rng);
        let public_key = keypair.public().to_owned();
        let cert = SelfSignedCertificate::new(keypair.private(), "not-sui");

        // The allowlist is not updated later, so it can be moved into the verifier
        let allowlist = AllowPublicKeys::new(BTreeSet::from([public_key.clone()]));
        let client_verifier =
            ClientCertVerifier::new(allowlist, SUI_VALIDATOR_SERVER_NAME.to_string());

        // Allowed public key but the server-name in the cert is not the required "sui"
        let err = client_verifier
            .verify_client_cert(&cert.rustls_certificate(), &[], UnixTime::now())
            .unwrap_err();
        assert_eq!(
            err,
            rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName),
            "Actual error: {err:?}"
        );

        let server_verifier =
            ServerCertVerifier::new(public_key, SUI_VALIDATOR_SERVER_NAME.to_string());

        // Allowed public key but the server-name in the cert is not the required "sui"
        let err = server_verifier
            .verify_server_cert(
                &cert.rustls_certificate(),
                &[],
                &ServerName::try_from("example.com").unwrap(),
                &[],
                UnixTime::now(),
            )
            .unwrap_err();
        assert_eq!(
            err,
            rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName),
            "Actual error: {err:?}"
        );
    }

    /// End-to-end check of `TlsAcceptor` behind axum: the request is rejected until the
    /// client's key is added to the allowlist, after which the handler sees the client's key.
    #[tokio::test]
    async fn axum_acceptor() {
        let mut rng = rand::thread_rng();
        let client_keypair = Ed25519KeyPair::generate(&mut rng);
        let client_public_key = client_keypair.public().to_owned();
        let client_certificate =
            SelfSignedCertificate::new(client_keypair.private(), SUI_VALIDATOR_SERVER_NAME);
        let server_keypair = Ed25519KeyPair::generate(&mut rng);
        let server_certificate = SelfSignedCertificate::new(server_keypair.private(), "localhost");

        let client = reqwest::Client::builder()
            .add_root_certificate(server_certificate.reqwest_certificate())
            .identity(client_certificate.reqwest_identity())
            .https_only(true)
            .build()
            .unwrap();

        let allowlist = AllowPublicKeys::new(BTreeSet::new());
        let tls_config =
            ClientCertVerifier::new(allowlist.clone(), SUI_VALIDATOR_SERVER_NAME.to_string())
                .rustls_server_config(
                    vec![server_certificate.rustls_certificate()],
                    server_certificate.rustls_private_key(),
                )
                .unwrap();

        async fn handler(tls_info: axum::Extension<TlsConnectionInfo>) -> String {
            tls_info.public_key().unwrap().to_string()
        }

        let app = axum::Router::new().route("/", axum::routing::get(handler));
        let listener = std::net::TcpListener::bind("localhost:0").unwrap();
        let server_address = listener.local_addr().unwrap();
        let acceptor = TlsAcceptor::new(tls_config);
        let _server = tokio::spawn(async move {
            axum_server::Server::from_tcp(listener)
                .acceptor(acceptor)
                .serve(app.into_make_service())
                .await
                .unwrap()
        });

        let server_url = format!("https://localhost:{}", server_address.port());
        // Client request is rejected because it isn't in the allowlist
        client.get(&server_url).send().await.unwrap_err();

        // Insert the client's public key into the allowlist and verify the request is successful
        allowlist.update(BTreeSet::from([client_public_key.clone()]));

        let res = client.get(&server_url).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert_eq!(client_public_key.to_string(), body);
    }
}