sui_crypto/zklogin/
mod.rs1use std::collections::HashMap;
2
3use crate::SignatureError;
4use poseidon::POSEIDON;
5use signature::Verifier;
6use sui_sdk_types::Jwk;
7use sui_sdk_types::JwkId;
8use sui_sdk_types::UserSignature;
9use sui_sdk_types::ZkLoginAuthenticator;
10use sui_sdk_types::ZkLoginInputs;
11
12mod poseidon;
13mod verify;
14
15#[cfg(test)]
16mod tests;
17
18#[derive(Debug, Clone, PartialEq, Default)]
19pub struct ZkloginVerifier {
20 proof_verifying_key: verify::VerifyingKey,
21 jwks: HashMap<JwkId, Jwk>,
22}
23
24impl ZkloginVerifier {
25 fn new(proof_verifying_key: verify::VerifyingKey) -> Self {
26 Self {
27 proof_verifying_key,
28 jwks: Default::default(),
29 }
30 }
31
32 pub fn new_mainnet() -> Self {
33 Self::new(verify::VerifyingKey::new_mainnet())
34 }
35
36 pub fn new_dev() -> Self {
37 Self::new(verify::VerifyingKey::new_dev())
38 }
39
40 pub fn jwks(&self) -> &HashMap<JwkId, Jwk> {
41 &self.jwks
42 }
43
44 pub fn jwks_mut(&mut self) -> &mut HashMap<JwkId, Jwk> {
45 &mut self.jwks
46 }
47}
48
49impl Verifier<ZkLoginAuthenticator> for ZkloginVerifier {
50 fn verify(
51 &self,
52 message: &[u8],
53 signature: &ZkLoginAuthenticator,
54 ) -> Result<(), SignatureError> {
55 let jwt_details = JwtDetails::from_zklogin_inputs(&signature.inputs)?;
57 let jwk = self.jwks.get(&jwt_details.id).ok_or_else(|| {
58 SignatureError::from_source(format!(
59 "unable to find corrisponding jwk with id '{:?}' for provided authenticator",
60 jwt_details.id
61 ))
62 })?;
63
64 crate::simple::SimpleVerifier.verify(message, &signature.signature)?;
66
67 self.proof_verifying_key.verify_zklogin(
69 jwk,
70 &signature.inputs,
71 &signature.signature,
72 signature.max_epoch,
73 )
74 }
75}
76
77impl Verifier<UserSignature> for ZkloginVerifier {
78 fn verify(&self, message: &[u8], signature: &UserSignature) -> Result<(), SignatureError> {
79 let UserSignature::ZkLogin(zklogin_authenticator) = signature else {
80 return Err(SignatureError::from_source("not a zklogin signature"));
81 };
82
83 self.verify(message, zklogin_authenticator.as_ref())
84 }
85}
86
87#[derive(Debug, Clone, PartialEq, Eq)]
89struct JwtDetails {
90 header: JwtHeader,
91 id: JwkId,
92}
93
94impl JwtDetails {
95 fn from_zklogin_inputs(inputs: &ZkLoginInputs) -> Result<Self, SignatureError> {
96 let header = JwtHeader::from_base64(&inputs.header_base64)?;
97 let id = JwkId {
98 iss: inputs.iss().map_err(SignatureError::from_source)?,
99 kid: header.kid.clone(),
100 };
101 Ok(JwtDetails { header, id })
102 }
103}
104
/// The fields of a JWT header that this module reads.
#[derive(Clone, Debug, Eq, PartialEq)]
struct JwtHeader {
    /// Signing algorithm named by the header.
    alg: String,
    /// Key id named by the header.
    kid: String,
    /// Optional `typ` field of the header.
    typ: Option<String>,
}
113
114impl JwtHeader {
115 fn from_base64(s: &str) -> Result<Self, SignatureError> {
116 use base64ct::Base64UrlUnpadded;
117 use base64ct::Encoding;
118
119 #[derive(serde_derive::Serialize, serde_derive::Deserialize)]
120 struct Header {
121 alg: String,
122 kid: String,
123 #[serde(skip_serializing_if = "Option::is_none")]
124 typ: Option<String>,
125 }
126
127 let header_bytes = Base64UrlUnpadded::decode_vec(s)
128 .map_err(|e| SignatureError::from_source(e.to_string()))?;
129 let Header { alg, kid, typ } =
130 serde_json::from_slice(&header_bytes).map_err(SignatureError::from_source)?;
131 if alg != "RS256" {
132 return Err(SignatureError::from_source("jwt alg must be RS256"));
133 }
134 Ok(Self { alg, kid, typ })
135 }
136}