sui_crypto/zklogin/
mod.rs

1use std::collections::HashMap;
2
3use crate::SignatureError;
4use poseidon::POSEIDON;
5use signature::Verifier;
6use sui_sdk_types::Jwk;
7use sui_sdk_types::JwkId;
8use sui_sdk_types::UserSignature;
9use sui_sdk_types::ZkLoginAuthenticator;
10use sui_sdk_types::ZkLoginInputs;
11
12mod poseidon;
13mod verify;
14
15#[cfg(test)]
16mod tests;
17
18#[derive(Debug, Clone, PartialEq, Default)]
19pub struct ZkloginVerifier {
20    proof_verifying_key: verify::VerifyingKey,
21    jwks: HashMap<JwkId, Jwk>,
22}
23
24impl ZkloginVerifier {
25    fn new(proof_verifying_key: verify::VerifyingKey) -> Self {
26        Self {
27            proof_verifying_key,
28            jwks: Default::default(),
29        }
30    }
31
32    pub fn new_mainnet() -> Self {
33        Self::new(verify::VerifyingKey::new_mainnet())
34    }
35
36    pub fn new_dev() -> Self {
37        Self::new(verify::VerifyingKey::new_dev())
38    }
39
40    pub fn jwks(&self) -> &HashMap<JwkId, Jwk> {
41        &self.jwks
42    }
43
44    pub fn jwks_mut(&mut self) -> &mut HashMap<JwkId, Jwk> {
45        &mut self.jwks
46    }
47}
48
49impl Verifier<ZkLoginAuthenticator> for ZkloginVerifier {
50    fn verify(
51        &self,
52        message: &[u8],
53        signature: &ZkLoginAuthenticator,
54    ) -> Result<(), SignatureError> {
55        // 1. check that we have a valid corrisponding Jwk
56        let jwt_details = JwtDetails::from_zklogin_inputs(&signature.inputs)?;
57        let jwk = self.jwks.get(&jwt_details.id).ok_or_else(|| {
58            SignatureError::from_source(format!(
59                "unable to find corrisponding jwk with id '{:?}' for provided authenticator",
60                jwt_details.id
61            ))
62        })?;
63
64        // 2. verify that the provided SimpleSignature is valid
65        crate::simple::SimpleVerifier.verify(message, &signature.signature)?;
66
67        // 3. verify groth16 proof
68        self.proof_verifying_key.verify_zklogin(
69            jwk,
70            &signature.inputs,
71            &signature.signature,
72            signature.max_epoch,
73        )
74    }
75}
76
77impl Verifier<UserSignature> for ZkloginVerifier {
78    fn verify(&self, message: &[u8], signature: &UserSignature) -> Result<(), SignatureError> {
79        let UserSignature::ZkLogin(zklogin_authenticator) = signature else {
80            return Err(SignatureError::from_source("not a zklogin signature"));
81        };
82
83        self.verify(message, zklogin_authenticator.as_ref())
84    }
85}
86
87/// A structed of parsed JWT details, consists of kid, header, iss.
88#[derive(Debug, Clone, PartialEq, Eq)]
89struct JwtDetails {
90    header: JwtHeader,
91    id: JwkId,
92}
93
94impl JwtDetails {
95    fn from_zklogin_inputs(inputs: &ZkLoginInputs) -> Result<Self, SignatureError> {
96        let header = JwtHeader::from_base64(&inputs.header_base64)?;
97        let id = JwkId {
98            iss: inputs.iss().map_err(SignatureError::from_source)?,
99            kid: header.kid.clone(),
100        };
101        Ok(JwtDetails { header, id })
102    }
103}
104
/// A standard JWT header, as described by
/// https://openid.net/specs/openid-connect-core-1_0.html
#[derive(Debug, Clone, PartialEq, Eq)]
struct JwtHeader {
    /// Signing algorithm named by the header (`alg`).
    alg: String,
    /// Key id named by the header (`kid`).
    kid: String,
    /// Optional `typ` entry of the header.
    typ: Option<String>,
}
113
114impl JwtHeader {
115    fn from_base64(s: &str) -> Result<Self, SignatureError> {
116        use base64ct::Base64UrlUnpadded;
117        use base64ct::Encoding;
118
119        #[derive(serde_derive::Serialize, serde_derive::Deserialize)]
120        struct Header {
121            alg: String,
122            kid: String,
123            #[serde(skip_serializing_if = "Option::is_none")]
124            typ: Option<String>,
125        }
126
127        let header_bytes = Base64UrlUnpadded::decode_vec(s)
128            .map_err(|e| SignatureError::from_source(e.to_string()))?;
129        let Header { alg, kid, typ } =
130            serde_json::from_slice(&header_bytes).map_err(SignatureError::from_source)?;
131        if alg != "RS256" {
132            return Err(SignatureError::from_source("jwt alg must be RS256"));
133        }
134        Ok(Self { alg, kid, typ })
135    }
136}