sui_crypto/zklogin/mod.rs

1use std::collections::HashMap;
2
3use crate::SignatureError;
4use poseidon::POSEIDON;
5use signature::Verifier;
6use sui_sdk_types::Jwk;
7use sui_sdk_types::JwkId;
8use sui_sdk_types::UserSignature;
9use sui_sdk_types::ZkLoginAuthenticator;
10use sui_sdk_types::ZkLoginInputs;
11
12mod poseidon;
13mod verify;
14
15#[cfg(test)]
16mod tests;
17
18pub struct ZkloginVerifier {
19    proof_verifying_key: verify::VerifyingKey,
20    jwks: HashMap<JwkId, Jwk>,
21}
22
23impl ZkloginVerifier {
24    fn new(proof_verifying_key: verify::VerifyingKey) -> Self {
25        Self {
26            proof_verifying_key,
27            jwks: Default::default(),
28        }
29    }
30
31    pub fn new_mainnet() -> Self {
32        Self::new(verify::VerifyingKey::new_mainnet())
33    }
34
35    pub fn new_dev() -> Self {
36        Self::new(verify::VerifyingKey::new_dev())
37    }
38
39    pub fn jwks(&self) -> &HashMap<JwkId, Jwk> {
40        &self.jwks
41    }
42
43    pub fn jwks_mut(&mut self) -> &mut HashMap<JwkId, Jwk> {
44        &mut self.jwks
45    }
46}
47
48impl Verifier<ZkLoginAuthenticator> for ZkloginVerifier {
49    fn verify(
50        &self,
51        message: &[u8],
52        signature: &ZkLoginAuthenticator,
53    ) -> Result<(), SignatureError> {
54        // 1. check that we have a valid corrisponding Jwk
55        let jwt_details = JwtDetails::from_zklogin_inputs(&signature.inputs)?;
56        let jwk = self.jwks.get(&jwt_details.id).ok_or_else(|| {
57            SignatureError::from_source(format!(
58                "unable to find corrisponding jwk with id '{:?}' for provided authenticator",
59                jwt_details.id
60            ))
61        })?;
62
63        // 2. verify that the provided SimpleSignature is valid
64        crate::simple::SimpleVerifier.verify(message, &signature.signature)?;
65
66        // 3. verify groth16 proof
67        self.proof_verifying_key.verify_zklogin(
68            jwk,
69            &signature.inputs,
70            &signature.signature,
71            signature.max_epoch,
72        )
73    }
74}
75
76impl Verifier<UserSignature> for ZkloginVerifier {
77    fn verify(&self, message: &[u8], signature: &UserSignature) -> Result<(), SignatureError> {
78        let UserSignature::ZkLogin(zklogin_authenticator) = signature else {
79            return Err(SignatureError::from_source("not a zklogin signature"));
80        };
81
82        self.verify(message, zklogin_authenticator.as_ref())
83    }
84}
85
86/// A structed of parsed JWT details, consists of kid, header, iss.
87#[derive(Debug, Clone, PartialEq, Eq)]
88struct JwtDetails {
89    header: JwtHeader,
90    id: JwkId,
91}
92
93impl JwtDetails {
94    fn from_zklogin_inputs(inputs: &ZkLoginInputs) -> Result<Self, SignatureError> {
95        let header = JwtHeader::from_base64(&inputs.header_base64)?;
96        let id = JwkId {
97            iss: inputs.iss().map_err(SignatureError::from_source)?,
98            kid: header.kid.clone(),
99        };
100        Ok(JwtDetails { header, id })
101    }
102}
103
/// A standard JWT header, as described by
/// <https://openid.net/specs/openid-connect-core-1_0.html>.
///
/// Only `alg`, `kid` and `typ` are retained; other header parameters are dropped.
#[derive(Debug, Clone, PartialEq, Eq)]
struct JwtHeader {
    /// Signing algorithm. `JwtHeader::from_base64` only accepts `RS256`.
    alg: String,
    /// Key id, used together with the issuer to look up the verifying JWK.
    kid: String,
    /// Optional token type.
    typ: Option<String>,
}
112
113impl JwtHeader {
114    fn from_base64(s: &str) -> Result<Self, SignatureError> {
115        use base64ct::Base64UrlUnpadded;
116        use base64ct::Encoding;
117
118        #[derive(serde_derive::Serialize, serde_derive::Deserialize)]
119        struct Header {
120            alg: String,
121            kid: String,
122            #[serde(skip_serializing_if = "Option::is_none")]
123            typ: Option<String>,
124        }
125
126        let header_bytes = Base64UrlUnpadded::decode_vec(s)
127            .map_err(|e| SignatureError::from_source(e.to_string()))?;
128        let Header { alg, kid, typ } =
129            serde_json::from_slice(&header_bytes).map_err(SignatureError::from_source)?;
130        if alg != "RS256" {
131            return Err(SignatureError::from_source("jwt alg must be RS256"));
132        }
133        Ok(Self { alg, kid, typ })
134    }
135}