1use crate::crypto::PublicKey;
5use crate::signature_verification::VerifiedDigestCache;
6use crate::{
7 base_types::{EpochId, SuiAddress},
8 crypto::{DefaultHash, Signature, SignatureScheme, SuiSignature},
9 digests::ZKLoginInputsDigest,
10 error::{SuiErrorKind, SuiResult},
11 signature::{AuthenticatorTrait, VerifyParams},
12};
13use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
14use fastcrypto_zkp::bn254::zk_login::JwkId;
15use fastcrypto_zkp::bn254::zk_login::{JWK, OIDCProvider};
16use fastcrypto_zkp::bn254::zk_login_api::ZkLoginEnv;
17use fastcrypto_zkp::bn254::{zk_login::ZkLoginInputs, zk_login_api::verify_zk_login};
18use once_cell::sync::OnceCell;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use shared_crypto::intent::IntentMessage;
22use std::hash::Hash;
23use std::hash::Hasher;
24use std::sync::Arc;
25#[cfg(test)]
26#[path = "unit_tests/zk_login_authenticator_test.rs"]
27mod zk_login_authenticator_test;
28
/// Authenticator carrying a zkLogin proof, the last epoch it is accepted in, and the user
/// signature over the intent message.
///
/// NOTE: `AsRef<[u8]>` encodes this struct with BCS and `ToFromBytes` decodes it with BCS, so the
/// field order below is part of the wire format — do not reorder fields.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    /// zkLogin proof inputs; the signer address must be derivable from them (see `verify_claims`).
    pub inputs: ZkLoginInputs,
    /// Last epoch (inclusive) in which this authenticator is accepted.
    max_epoch: EpochId,
    /// Signature over the intent message; its scheme flag and public key are bound into the
    /// proof verification via the extended public key bytes.
    pub user_signature: Signature,
    /// Lazily computed encoding (scheme flag followed by the BCS bytes of this struct).
    /// Skipped by serde, so it never feeds back into its own encoding.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
39
/// Everything the zkLogin proof verification depends on.
///
/// Its BCS encoding is hashed by `ZkLoginAuthenticator::hash_inputs` to key the cache of
/// already-verified inputs, so field order and contents define cache identity.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    /// zkLogin proof inputs.
    inputs: ZkLoginInputs,
    /// Max epoch the proof is verified against.
    max_epoch: EpochId,
    /// User signature scheme flag followed by the user public key bytes.
    extended_pk_bytes: Vec<u8>,
}
49
50impl ZkLoginAuthenticator {
51 fn get_caching_params(&self) -> ZkLoginCachingParams {
55 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
56 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
57 ZkLoginCachingParams {
58 inputs: self.inputs.clone(),
59 max_epoch: self.max_epoch,
60 extended_pk_bytes,
61 }
62 }
63
64 pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
65 use fastcrypto::hash::HashFunction;
66 let mut hasher = DefaultHash::default();
67 hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
68 ZKLoginInputsDigest::new(hasher.finalize().into())
69 }
70
71 pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
73 Self {
74 inputs,
75 max_epoch,
76 user_signature,
77 bytes: OnceCell::new(),
78 }
79 }
80
81 pub fn get_pk(&self) -> SuiResult<PublicKey> {
82 PublicKey::from_zklogin_inputs(&self.inputs)
83 }
84
85 pub fn get_iss(&self) -> &str {
86 self.inputs.get_iss()
87 }
88
89 pub fn get_max_epoch(&self) -> EpochId {
90 self.max_epoch
91 }
92
93 pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
94 &mut self.user_signature
95 }
96 pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
97 &mut self.max_epoch
98 }
99 pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
100 &mut self.inputs
101 }
102}
103
104impl PartialEq for ZkLoginAuthenticator {
106 fn eq(&self, other: &Self) -> bool {
107 self.as_ref() == other.as_ref()
108 }
109}
110
// `PartialEq` compares byte encodings, and byte equality is reflexive, so the `Eq` contract holds.
impl Eq for ZkLoginAuthenticator {}
113
114impl Hash for ZkLoginAuthenticator {
116 fn hash<H: Hasher>(&self, state: &mut H) {
117 self.as_ref().hash(state);
118 }
119}
120
121impl AuthenticatorTrait for ZkLoginAuthenticator {
122 fn verify_user_authenticator_epoch(
123 &self,
124 epoch: EpochId,
125 max_epoch_upper_bound_delta: Option<u64>,
126 ) -> SuiResult {
127 if let Some(delta) = max_epoch_upper_bound_delta {
130 let max_epoch_upper_bound = epoch + delta;
131 if self.get_max_epoch() > max_epoch_upper_bound {
132 return Err(SuiErrorKind::InvalidSignature {
133 error: format!(
134 "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
135 self.get_max_epoch(),
136 epoch,
137 max_epoch_upper_bound
138 ),
139 }
140 .into());
141 }
142 }
143 if epoch > self.get_max_epoch() {
145 return Err(SuiErrorKind::InvalidSignature {
146 error: format!(
147 "ZKLogin expired at epoch {}, current epoch {}",
148 self.get_max_epoch(),
149 epoch
150 ),
151 }
152 .into());
153 }
154 Ok(())
155 }
156
157 fn verify_claims<T>(
159 &self,
160 intent_msg: &IntentMessage<T>,
161 author: SuiAddress,
162 aux_verify_data: &VerifyParams,
163 zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
164 ) -> SuiResult
165 where
166 T: Serialize,
167 {
168 if author != SuiAddress::try_from_unpadded(&self.inputs)? {
170 if !aux_verify_data.verify_legacy_zklogin_address
172 || author != SuiAddress::try_from_padded(&self.inputs)?
173 {
174 return Err(SuiErrorKind::InvalidAddress.into());
175 }
176 }
177
178 if !aux_verify_data.supported_providers.is_empty()
181 && !aux_verify_data.supported_providers.contains(
182 &OIDCProvider::from_iss(self.inputs.get_iss()).map_err(|_| {
183 SuiErrorKind::InvalidSignature {
184 error: "Unknown provider".to_string(),
185 }
186 })?,
187 )
188 {
189 return Err(SuiErrorKind::InvalidSignature {
190 error: format!("OIDC provider not supported: {}", self.inputs.get_iss()),
191 }
192 .into());
193 }
194
195 self.user_signature.verify_secure(
197 intent_msg,
198 author,
199 SignatureScheme::ZkLoginAuthenticator,
200 )?;
201
202 if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
203 Ok(())
206 } else {
207 let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
210 extended_pk_bytes.extend(self.user_signature.public_key_bytes());
211 let res = verify_zklogin_inputs_wrapper(
212 self.get_caching_params(),
213 &aux_verify_data.oidc_provider_jwks,
214 &aux_verify_data.zk_login_env,
215 )
216 .map_err(|e| {
217 SuiErrorKind::InvalidSignature {
218 error: e.to_string(),
219 }
220 .into()
221 });
222 match res {
223 Ok(_) => {
224 zklogin_inputs_cache.cache_digest(self.hash_inputs());
226 Ok(())
227 }
228 Err(e) => Err(e),
229 }
230 }
231 }
232}
233
234fn verify_zklogin_inputs_wrapper(
235 params: ZkLoginCachingParams,
236 all_jwk: &im::HashMap<JwkId, JWK>,
237 env: &ZkLoginEnv,
238) -> SuiResult<()> {
239 verify_zk_login(
240 ¶ms.inputs,
241 params.max_epoch,
242 ¶ms.extended_pk_bytes,
243 all_jwk,
244 env,
245 )
246 .map_err(|e| {
247 SuiErrorKind::InvalidSignature {
248 error: e.to_string(),
249 }
250 .into()
251 })
252}
253
254impl ToFromBytes for ZkLoginAuthenticator {
255 fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
256 if bytes.first().ok_or(FastCryptoError::InvalidInput)?
258 != &SignatureScheme::ZkLoginAuthenticator.flag()
259 {
260 return Err(FastCryptoError::InvalidInput);
261 }
262 let mut zk_login: ZkLoginAuthenticator =
263 bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
264 zk_login.inputs.init()?;
265 Ok(zk_login)
266 }
267}
268
269impl AsRef<[u8]> for ZkLoginAuthenticator {
270 fn as_ref(&self) -> &[u8] {
271 self.bytes
272 .get_or_try_init::<_, eyre::Report>(|| {
273 let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
274 let mut bytes = Vec::with_capacity(1 + as_bytes.len());
275 bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
276 bytes.extend_from_slice(as_bytes.as_slice());
277 Ok(bytes)
278 })
279 .expect("OnceCell invariant violated")
280 }
281}
282
/// Address seed stored as 32 big-endian bytes, left-padded with zero bytes.
///
/// Its textual form (`Display` / `FromStr` / serde) is the base-10 integer value.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);
285
286impl AddressSeed {
287 pub fn unpadded(&self) -> &[u8] {
288 let mut buf = self.0.as_slice();
289
290 while !buf.is_empty() && buf[0] == 0 {
291 buf = &buf[1..];
292 }
293
294 if buf.is_empty() { &self.0[31..] } else { buf }
296 }
297
298 pub fn padded(&self) -> &[u8] {
299 &self.0
300 }
301}
302
303impl std::fmt::Display for AddressSeed {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
306 let radix10 = big_int.to_str_radix(10);
307 f.write_str(&radix10)
308 }
309}
310
/// Errors produced when parsing an `AddressSeed` from its base-10 string form.
#[derive(thiserror::Error, Debug)]
pub enum AddressSeedParseError {
    /// The input is not a valid base-10 unsigned integer.
    #[error("unable to parse radix10 encoded value `{0}`")]
    Parse(#[from] num_bigint::ParseBigIntError),
    /// The parsed value's big-endian encoding is longer than 32 bytes.
    #[error("larger than 32 bytes")]
    TooBig,
}
318
319impl std::str::FromStr for AddressSeed {
320 type Err = AddressSeedParseError;
321
322 fn from_str(s: &str) -> Result<Self, Self::Err> {
323 let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
324 let be_bytes = big_int.to_bytes_be();
325 let len = be_bytes.len();
326 let mut buf = [0; 32];
327
328 if len > 32 {
329 return Err(AddressSeedParseError::TooBig);
330 }
331
332 buf[32 - len..].copy_from_slice(&be_bytes);
333 Ok(Self(buf))
334 }
335}
336
337impl Serialize for AddressSeed {
339 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
340 where
341 S: serde::Serializer,
342 {
343 self.to_string().serialize(serializer)
344 }
345}
346
347impl<'de> Deserialize<'de> for AddressSeed {
348 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
349 where
350 D: serde::Deserializer<'de>,
351 {
352 let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
353 std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
354 }
355}
356
// Unit and property tests for `AddressSeed` parsing, formatting and unpadding.
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::AddressSeed;
    use num_bigint::BigUint;
    use proptest::prelude::*;

    // `unpadded` keeps a single zero byte for an all-zero seed and otherwise strips only the
    // leading zero bytes.
    #[test]
    fn unpadded_slice() {
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    proptest! {
        // Parsing the decimal form of inputs built from 33+ bytes must not panic. The result is
        // ignored: it may be Ok when leading input bytes are zero, otherwise TooBig.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let _ = AddressSeed::from_str(&radix10);
        }

        // Any value of at most 32 bytes parses, round-trips through `Display`, and can be
        // unpadded without panicking.
        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            assert_eq!(radix10, seed.to_string());
            seed.unpadded();
        }
    }
}