1use crate::crypto::PublicKey;
5use crate::signature_verification::VerifiedDigestCache;
6use crate::{
7    base_types::{EpochId, SuiAddress},
8    crypto::{DefaultHash, Signature, SignatureScheme, SuiSignature},
9    digests::ZKLoginInputsDigest,
10    error::{SuiErrorKind, SuiResult},
11    signature::{AuthenticatorTrait, VerifyParams},
12};
13use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
14use fastcrypto_zkp::bn254::zk_login::JwkId;
15use fastcrypto_zkp::bn254::zk_login::{JWK, OIDCProvider};
16use fastcrypto_zkp::bn254::zk_login_api::ZkLoginEnv;
17use fastcrypto_zkp::bn254::{zk_login::ZkLoginInputs, zk_login_api::verify_zk_login};
18use once_cell::sync::OnceCell;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use shared_crypto::intent::IntentMessage;
22use std::hash::Hash;
23use std::hash::Hasher;
24use std::sync::Arc;
// Unit tests kept in `unit_tests/zk_login_authenticator_test.rs`; compiled only for test builds.
#[cfg(test)]
#[path = "unit_tests/zk_login_authenticator_test.rs"]
mod zk_login_authenticator_test;
28
// A zkLogin authenticator: the zkLogin proof inputs, the last epoch in which it is
// accepted, and the user's signature over the intent message.
//
// NOTE: field order determines the BCS encoding produced by the `AsRef<[u8]>` impl, and
// therefore equality and hashing; do not reorder fields. Plain `//` comments are used here
// because `JsonSchema` would turn `///` doc comments into schema descriptions.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    // Inputs to the zkLogin proof check; the author address is also derived from them.
    pub inputs: ZkLoginInputs,
    // Last epoch (inclusive) in which this authenticator is accepted.
    max_epoch: EpochId,
    // Signature over the intent message; its scheme flag and public key also feed the
    // zkLogin proof check.
    pub user_signature: Signature,
    // Lazily computed `flag || bcs(self)` bytes, filled by the `AsRef<[u8]>` impl.
    // Skipped by serde, so it never participates in its own serialization.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
39
/// The data a zkLogin proof is verified against.
///
/// Passed to `verify_zklogin_inputs_wrapper`. Its BCS encoding is also hashed by
/// `ZkLoginAuthenticator::hash_inputs` to key the verified-inputs cache, so field order
/// and types must stay stable.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    inputs: ZkLoginInputs,
    max_epoch: EpochId,
    /// The user signature's scheme flag byte followed by its public key bytes.
    extended_pk_bytes: Vec<u8>,
}
49
50impl ZkLoginAuthenticator {
51    fn get_caching_params(&self) -> ZkLoginCachingParams {
55        let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
56        extended_pk_bytes.extend(self.user_signature.public_key_bytes());
57        ZkLoginCachingParams {
58            inputs: self.inputs.clone(),
59            max_epoch: self.max_epoch,
60            extended_pk_bytes,
61        }
62    }
63
64    pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
65        use fastcrypto::hash::HashFunction;
66        let mut hasher = DefaultHash::default();
67        hasher.update(bcs::to_bytes(&self.get_caching_params()).expect("serde should not fail"));
68        ZKLoginInputsDigest::new(hasher.finalize().into())
69    }
70
71    pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
73        Self {
74            inputs,
75            max_epoch,
76            user_signature,
77            bytes: OnceCell::new(),
78        }
79    }
80
81    pub fn get_pk(&self) -> SuiResult<PublicKey> {
82        PublicKey::from_zklogin_inputs(&self.inputs)
83    }
84
85    pub fn get_iss(&self) -> &str {
86        self.inputs.get_iss()
87    }
88
89    pub fn get_max_epoch(&self) -> EpochId {
90        self.max_epoch
91    }
92
93    pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
94        &mut self.user_signature
95    }
96    pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
97        &mut self.max_epoch
98    }
99    pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
100        &mut self.inputs
101    }
102}
103
104impl PartialEq for ZkLoginAuthenticator {
106    fn eq(&self, other: &Self) -> bool {
107        self.as_ref() == other.as_ref()
108    }
109}
110
111impl Eq for ZkLoginAuthenticator {}
113
114impl Hash for ZkLoginAuthenticator {
116    fn hash<H: Hasher>(&self, state: &mut H) {
117        self.as_ref().hash(state);
118    }
119}
120
121impl AuthenticatorTrait for ZkLoginAuthenticator {
122    fn verify_user_authenticator_epoch(
123        &self,
124        epoch: EpochId,
125        max_epoch_upper_bound_delta: Option<u64>,
126    ) -> SuiResult {
127        if let Some(delta) = max_epoch_upper_bound_delta {
130            let max_epoch_upper_bound = epoch + delta;
131            if self.get_max_epoch() > max_epoch_upper_bound {
132                return Err(SuiErrorKind::InvalidSignature {
133                    error: format!(
134                        "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
135                        self.get_max_epoch(),
136                        epoch,
137                        max_epoch_upper_bound
138                    ),
139                }
140                .into());
141            }
142        }
143        if epoch > self.get_max_epoch() {
145            return Err(SuiErrorKind::InvalidSignature {
146                error: format!(
147                    "ZKLogin expired at epoch {}, current epoch {}",
148                    self.get_max_epoch(),
149                    epoch
150                ),
151            }
152            .into());
153        }
154        Ok(())
155    }
156
157    fn verify_claims<T>(
159        &self,
160        intent_msg: &IntentMessage<T>,
161        author: SuiAddress,
162        aux_verify_data: &VerifyParams,
163        zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
164    ) -> SuiResult
165    where
166        T: Serialize,
167    {
168        if author != SuiAddress::try_from_unpadded(&self.inputs)? {
170            if !aux_verify_data.verify_legacy_zklogin_address
172                || author != SuiAddress::try_from_padded(&self.inputs)?
173            {
174                return Err(SuiErrorKind::InvalidAddress.into());
175            }
176        }
177
178        if !aux_verify_data.supported_providers.is_empty()
181            && !aux_verify_data.supported_providers.contains(
182                &OIDCProvider::from_iss(self.inputs.get_iss()).map_err(|_| {
183                    SuiErrorKind::InvalidSignature {
184                        error: "Unknown provider".to_string(),
185                    }
186                })?,
187            )
188        {
189            return Err(SuiErrorKind::InvalidSignature {
190                error: format!("OIDC provider not supported: {}", self.inputs.get_iss()),
191            }
192            .into());
193        }
194
195        self.user_signature.verify_secure(
197            intent_msg,
198            author,
199            SignatureScheme::ZkLoginAuthenticator,
200        )?;
201
202        if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
203            Ok(())
206        } else {
207            let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
210            extended_pk_bytes.extend(self.user_signature.public_key_bytes());
211            let res = verify_zklogin_inputs_wrapper(
212                self.get_caching_params(),
213                &aux_verify_data.oidc_provider_jwks,
214                &aux_verify_data.zk_login_env,
215            )
216            .map_err(|e| {
217                SuiErrorKind::InvalidSignature {
218                    error: e.to_string(),
219                }
220                .into()
221            });
222            match res {
223                Ok(_) => {
224                    zklogin_inputs_cache.cache_digest(self.hash_inputs());
226                    Ok(())
227                }
228                Err(e) => Err(e),
229            }
230        }
231    }
232}
233
234fn verify_zklogin_inputs_wrapper(
235    params: ZkLoginCachingParams,
236    all_jwk: &im::HashMap<JwkId, JWK>,
237    env: &ZkLoginEnv,
238) -> SuiResult<()> {
239    verify_zk_login(
240        ¶ms.inputs,
241        params.max_epoch,
242        ¶ms.extended_pk_bytes,
243        all_jwk,
244        env,
245    )
246    .map_err(|e| {
247        SuiErrorKind::InvalidSignature {
248            error: e.to_string(),
249        }
250        .into()
251    })
252}
253
254impl ToFromBytes for ZkLoginAuthenticator {
255    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
256        if bytes.first().ok_or(FastCryptoError::InvalidInput)?
258            != &SignatureScheme::ZkLoginAuthenticator.flag()
259        {
260            return Err(FastCryptoError::InvalidInput);
261        }
262        let mut zk_login: ZkLoginAuthenticator =
263            bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
264        zk_login.inputs.init()?;
265        Ok(zk_login)
266    }
267}
268
269impl AsRef<[u8]> for ZkLoginAuthenticator {
270    fn as_ref(&self) -> &[u8] {
271        self.bytes
272            .get_or_try_init::<_, eyre::Report>(|| {
273                let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
274                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
275                bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
276                bytes.extend_from_slice(as_bytes.as_slice());
277                Ok(bytes)
278            })
279            .expect("OnceCell invariant violated")
280    }
281}
282
/// A 32-byte address seed, interpreted as a big-endian unsigned integer.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Returns the seed bytes with leading zero bytes removed.
    ///
    /// An all-zero seed yields a single `0` byte rather than an empty slice.
    pub fn unpadded(&self) -> &[u8] {
        match self.0.iter().position(|&byte| byte != 0) {
            Some(first_significant) => &self.0[first_significant..],
            None => &self.0[self.0.len() - 1..],
        }
    }

    /// Returns all 32 seed bytes, including any leading zeros.
    pub fn padded(&self) -> &[u8] {
        self.0.as_slice()
    }
}
302
303impl std::fmt::Display for AddressSeed {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
306        let radix10 = big_int.to_str_radix(10);
307        f.write_str(&radix10)
308    }
309}
310
/// Errors returned when parsing an [`AddressSeed`] from a decimal string.
#[derive(thiserror::Error, Debug)]
pub enum AddressSeedParseError {
    /// The input is not a valid base-10 unsigned integer.
    #[error("unable to parse radix10 encoded value `{0}`")]
    Parse(#[from] num_bigint::ParseBigIntError),
    /// The parsed value needs more than 32 bytes to represent.
    #[error("larger than 32 bytes")]
    TooBig,
}
318
319impl std::str::FromStr for AddressSeed {
320    type Err = AddressSeedParseError;
321
322    fn from_str(s: &str) -> Result<Self, Self::Err> {
323        let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
324        let be_bytes = big_int.to_bytes_be();
325        let len = be_bytes.len();
326        let mut buf = [0; 32];
327
328        if len > 32 {
329            return Err(AddressSeedParseError::TooBig);
330        }
331
332        buf[32 - len..].copy_from_slice(&be_bytes);
333        Ok(Self(buf))
334    }
335}
336
337impl Serialize for AddressSeed {
339    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
340    where
341        S: serde::Serializer,
342    {
343        self.to_string().serialize(serializer)
344    }
345}
346
347impl<'de> Deserialize<'de> for AddressSeed {
348    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
349    where
350        D: serde::Deserializer<'de>,
351    {
352        let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
353        std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
354    }
355}
356
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::AddressSeed;
    use num_bigint::BigUint;
    use proptest::prelude::*;

    /// Leading zero bytes are stripped, but an all-zero seed keeps a single zero byte.
    #[test]
    fn unpadded_slice() {
        let seed = AddressSeed([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = AddressSeed([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    proptest! {
        // Parsing arbitrarily large values must never panic. Values whose significant
        // bytes (ignoring leading zeros) exceed 32 must be rejected; all others accepted.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let significant_len = bytes.iter().skip_while(|&&b| b == 0).count();
            let parsed = AddressSeed::from_str(&radix10);
            assert_eq!(parsed.is_ok(), significant_len <= 32);
        }

        // Every value that fits in 32 bytes round-trips through the decimal form. Its
        // unpadded bytes equal the input with leading zeros stripped, or `[0]` for zero.
        #[test]
        fn valid_address_seeds(
            bytes in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let big_int = BigUint::from_bytes_be(&bytes);
            let radix10 = big_int.to_str_radix(10);

            let seed = AddressSeed::from_str(&radix10).unwrap();
            assert_eq!(radix10, seed.to_string());

            let stripped: Vec<u8> = bytes.iter().copied().skip_while(|&b| b == 0).collect();
            let expected = if stripped.is_empty() { vec![0u8] } else { stripped };
            assert_eq!(seed.unpadded(), expected.as_slice());
            assert_eq!(seed.padded().len(), 32);
        }
    }
}