1use crate::crypto::PublicKey;
5use crate::signature_verification::VerifiedDigestCache;
6use crate::{
7 base_types::{EpochId, SuiAddress},
8 crypto::{DefaultHash, Signature, SignatureScheme, SuiSignature},
9 digests::ZKLoginInputsDigest,
10 error::{SuiErrorKind, SuiResult},
11 signature::{AuthenticatorTrait, VerifyParams},
12};
13use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
14use fastcrypto_zkp::bn254::zk_login::JwkId;
15use fastcrypto_zkp::bn254::zk_login::{JWK, OIDCProvider};
16use fastcrypto_zkp::bn254::zk_login_api::ZkLoginEnv;
17use fastcrypto_zkp::bn254::{zk_login::ZkLoginInputs, zk_login_api::verify_zk_login};
18use once_cell::sync::OnceCell;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use shared_crypto::intent::IntentMessage;
22use std::hash::Hash;
23use std::hash::Hasher;
24use std::sync::Arc;
25#[cfg(test)]
26#[path = "unit_tests/zk_login_authenticator_test.rs"]
27mod zk_login_authenticator_test;
28
29#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
31#[serde(rename_all = "camelCase")]
32pub struct ZkLoginAuthenticator {
33 pub inputs: ZkLoginInputs,
34 max_epoch: EpochId,
35 pub user_signature: Signature,
36 #[serde(skip)]
37 pub bytes: OnceCell<Vec<u8>>,
38}
39
40#[derive(Serialize, Deserialize)]
44struct ZkLoginCachingParams {
45 inputs: ZkLoginInputs,
46 max_epoch: EpochId,
47 extended_pk_bytes: Vec<u8>,
48}
49
50impl ZkLoginAuthenticator {
51 fn get_caching_params(&self) -> ZkLoginCachingParams {
55 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
56 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
57 ZkLoginCachingParams {
58 inputs: self.inputs.clone(),
59 max_epoch: self.max_epoch,
60 extended_pk_bytes,
61 }
62 }
63
64 pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
65 use fastcrypto::hash::HashFunction;
66 let mut hasher = DefaultHash::default();
67 bcs::serialize_into(&mut hasher, &self.get_caching_params())
68 .expect("serde should not fail");
69 ZKLoginInputsDigest::new(hasher.finalize().into())
70 }
71
72 pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
74 Self {
75 inputs,
76 max_epoch,
77 user_signature,
78 bytes: OnceCell::new(),
79 }
80 }
81
82 pub fn get_pk(&self) -> SuiResult<PublicKey> {
83 PublicKey::from_zklogin_inputs(&self.inputs)
84 }
85
86 pub fn get_iss(&self) -> &str {
87 self.inputs.get_iss()
88 }
89
90 pub fn get_max_epoch(&self) -> EpochId {
91 self.max_epoch
92 }
93
94 pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
95 &mut self.user_signature
96 }
97 pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
98 &mut self.max_epoch
99 }
100 pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
101 &mut self.inputs
102 }
103}
104
105impl PartialEq for ZkLoginAuthenticator {
107 fn eq(&self, other: &Self) -> bool {
108 self.as_ref() == other.as_ref()
109 }
110}
111
112impl Eq for ZkLoginAuthenticator {}
114
115impl Hash for ZkLoginAuthenticator {
117 fn hash<H: Hasher>(&self, state: &mut H) {
118 self.as_ref().hash(state);
119 }
120}
121
122impl AuthenticatorTrait for ZkLoginAuthenticator {
123 fn verify_user_authenticator_epoch(
124 &self,
125 epoch: EpochId,
126 max_epoch_upper_bound_delta: Option<u64>,
127 ) -> SuiResult {
128 if let Some(delta) = max_epoch_upper_bound_delta {
131 let max_epoch_upper_bound = epoch + delta;
132 if self.get_max_epoch() > max_epoch_upper_bound {
133 return Err(SuiErrorKind::InvalidSignature {
134 error: format!(
135 "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
136 self.get_max_epoch(),
137 epoch,
138 max_epoch_upper_bound
139 ),
140 }
141 .into());
142 }
143 }
144 if epoch > self.get_max_epoch() {
146 return Err(SuiErrorKind::InvalidSignature {
147 error: format!(
148 "ZKLogin expired at epoch {}, current epoch {}",
149 self.get_max_epoch(),
150 epoch
151 ),
152 }
153 .into());
154 }
155 Ok(())
156 }
157
158 fn verify_claims<T>(
160 &self,
161 intent_msg: &IntentMessage<T>,
162 author: SuiAddress,
163 aux_verify_data: &VerifyParams,
164 zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
165 ) -> SuiResult
166 where
167 T: Serialize,
168 {
169 if author != SuiAddress::try_from_unpadded(&self.inputs)? {
171 if !aux_verify_data.verify_legacy_zklogin_address
173 || author != SuiAddress::try_from_padded(&self.inputs)?
174 {
175 return Err(SuiErrorKind::InvalidAddress.into());
176 }
177 }
178
179 if !aux_verify_data.supported_providers.is_empty()
182 && !aux_verify_data.supported_providers.contains(
183 &OIDCProvider::from_iss(self.inputs.get_iss()).map_err(|_| {
184 SuiErrorKind::InvalidSignature {
185 error: "Unknown provider".to_string(),
186 }
187 })?,
188 )
189 {
190 return Err(SuiErrorKind::InvalidSignature {
191 error: format!("OIDC provider not supported: {}", self.inputs.get_iss()),
192 }
193 .into());
194 }
195
196 self.user_signature.verify_secure(
198 intent_msg,
199 author,
200 SignatureScheme::ZkLoginAuthenticator,
201 )?;
202
203 if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
204 Ok(())
207 } else {
208 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
211 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
212 let res = verify_zklogin_inputs_wrapper(
213 self.get_caching_params(),
214 &aux_verify_data.oidc_provider_jwks,
215 &aux_verify_data.zk_login_env,
216 )
217 .map_err(|e| {
218 SuiErrorKind::InvalidSignature {
219 error: e.to_string(),
220 }
221 .into()
222 });
223 match res {
224 Ok(_) => {
225 zklogin_inputs_cache.cache_digest(self.hash_inputs());
227 Ok(())
228 }
229 Err(e) => Err(e),
230 }
231 }
232 }
233}
234
235fn verify_zklogin_inputs_wrapper(
236 params: ZkLoginCachingParams,
237 all_jwk: &im::HashMap<JwkId, JWK>,
238 env: &ZkLoginEnv,
239) -> SuiResult<()> {
240 verify_zk_login(
241 ¶ms.inputs,
242 params.max_epoch,
243 ¶ms.extended_pk_bytes,
244 all_jwk,
245 env,
246 )
247 .map_err(|e| {
248 SuiErrorKind::InvalidSignature {
249 error: e.to_string(),
250 }
251 .into()
252 })
253}
254
255impl ToFromBytes for ZkLoginAuthenticator {
256 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
257 if bytes.first().ok_or(FastCryptoError::InvalidInput)?
259 != &SignatureScheme::ZkLoginAuthenticator.flag()
260 {
261 return Err(FastCryptoError::InvalidInput);
262 }
263 let mut zk_login: ZkLoginAuthenticator =
264 bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
265 zk_login.inputs.init()?;
266 Ok(zk_login)
267 }
268}
269
270impl AsRef<[u8]> for ZkLoginAuthenticator {
271 fn as_ref(&self) -> &[u8] {
272 self.bytes
273 .get_or_try_init::<_, eyre::Report>(|| {
274 let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
275 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
276 bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
277 bytes.extend_from_slice(as_bytes.as_slice());
278 Ok(bytes)
279 })
280 .expect("OnceCell invariant violated")
281 }
282}
283
/// A 32-byte address seed holding an unsigned integer in big-endian order, left-padded
/// with zero bytes.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);
286
287impl AddressSeed {
288 pub fn unpadded(&self) -> &[u8] {
289 let mut buf = self.0.as_slice();
290
291 while !buf.is_empty() && buf[0] == 0 {
292 buf = &buf[1..];
293 }
294
295 if buf.is_empty() { &self.0[31..] } else { buf }
297 }
298
299 pub fn padded(&self) -> &[u8] {
300 &self.0
301 }
302}
303
304impl std::fmt::Display for AddressSeed {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
307 let radix10 = big_int.to_str_radix(10);
308 f.write_str(&radix10)
309 }
310}
311
312#[derive(thiserror::Error, Debug)]
313pub enum AddressSeedParseError {
314 #[error("unable to parse radix10 encoded value `{0}`")]
315 Parse(#[from] num_bigint::ParseBigIntError),
316 #[error("larger than 32 bytes")]
317 TooBig,
318}
319
320impl std::str::FromStr for AddressSeed {
321 type Err = AddressSeedParseError;
322
323 fn from_str(s: &str) -> Result<Self, Self::Err> {
324 let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
325 let be_bytes = big_int.to_bytes_be();
326 let len = be_bytes.len();
327 let mut buf = [0; 32];
328
329 if len > 32 {
330 return Err(AddressSeedParseError::TooBig);
331 }
332
333 buf[32 - len..].copy_from_slice(&be_bytes);
334 Ok(Self(buf))
335 }
336}
337
338impl Serialize for AddressSeed {
340 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
341 where
342 S: serde::Serializer,
343 {
344 self.to_string().serialize(serializer)
345 }
346}
347
348impl<'de> Deserialize<'de> for AddressSeed {
349 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
350 where
351 D: serde::Deserializer<'de>,
352 {
353 let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
354 std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
355 }
356}
357
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::{AddressSeed, AddressSeedParseError};
    use num_bigint::BigUint;
    use proptest::prelude::*;

    #[test]
    fn unpadded_slice() {
        // An all-zero seed keeps its final byte instead of collapsing to an empty slice.
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        // Only leading zero bytes are stripped; `padded` always returns all 32 bytes.
        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1u8; 31].as_slice());
        assert_eq!(seed.padded().len(), 32);
    }

    #[test]
    fn max_width_value_round_trips() {
        // 2^256 - 1 occupies exactly 32 bytes: the largest accepted seed.
        let max = BigUint::from_bytes_be(&[0xffu8; 32]);
        let seed = AddressSeed::from_str(&max.to_str_radix(10)).unwrap();
        assert_eq!(seed.padded(), [0xffu8; 32].as_slice());
        assert_eq!(seed.to_string(), max.to_str_radix(10));
    }

    #[test]
    fn rejects_values_wider_than_32_bytes() {
        let too_big = BigUint::from_bytes_be(&[1u8; 33]);
        let result = AddressSeed::from_str(&too_big.to_str_radix(10));
        assert!(matches!(result, Err(AddressSeedParseError::TooBig)));
    }

    #[test]
    fn rejects_non_decimal_input() {
        let result = AddressSeed::from_str("not a number");
        assert!(matches!(result, Err(AddressSeedParseError::Parse(_))));
    }

    proptest! {
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);
            // Leading zero bytes can shrink the value enough to fit, so the expected outcome
            // depends on the minimal encoding length, not on `bytes.len()`.
            let fits = big_int.to_bytes_be().len() <= 32;

            let result = AddressSeed::from_str(&radix10);
            if fits {
                prop_assert!(result.is_ok());
            } else {
                prop_assert!(matches!(result, Err(AddressSeedParseError::TooBig)));
            }
        }

        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            let rendered = seed.to_string();
            prop_assert_eq!(&radix10, &rendered);

            // `unpadded` must be the minimal big-endian encoding of the same value, and
            // `padded` must decode back to that value.
            let minimal = big_int.to_bytes_be();
            prop_assert_eq!(seed.unpadded(), minimal.as_slice());
            prop_assert_eq!(BigUint::from_bytes_be(seed.padded()), big_int);
        }
    }
}