sui_types/
zk_login_authenticator.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) Mysten Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4use crate::crypto::PublicKey;
5use crate::signature_verification::VerifiedDigestCache;
6use crate::{
7    base_types::{EpochId, SuiAddress},
8    crypto::{DefaultHash, Signature, SignatureScheme, SuiSignature},
9    digests::ZKLoginInputsDigest,
10    error::{SuiErrorKind, SuiResult},
11    signature::{AuthenticatorTrait, VerifyParams},
12};
13use fastcrypto::{error::FastCryptoError, traits::ToFromBytes};
14use fastcrypto_zkp::bn254::zk_login::JwkId;
15use fastcrypto_zkp::bn254::zk_login::{JWK, OIDCProvider};
16use fastcrypto_zkp::bn254::zk_login_api::ZkLoginEnv;
17use fastcrypto_zkp::bn254::{zk_login::ZkLoginInputs, zk_login_api::verify_zk_login};
18use once_cell::sync::OnceCell;
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use shared_crypto::intent::IntentMessage;
22use std::hash::Hash;
23use std::hash::Hasher;
24use std::sync::Arc;
25#[cfg(test)]
26#[path = "unit_tests/zk_login_authenticator_test.rs"]
27mod zk_login_authenticator_test;
28
/// An zk login authenticator with all the necessary fields.
// A zkLogin authenticator is made of three parts:
// - the zkLogin proof inputs;
// - the last epoch (inclusive) at which it is accepted;
// - the user's ephemeral signature over the intent message.
#[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkLoginAuthenticator {
    // zkLogin proof inputs; the sender address and public key are derived from these.
    pub inputs: ZkLoginInputs,
    // Last epoch (inclusive) at which this authenticator passes the epoch check.
    max_epoch: EpochId,
    // Ephemeral signature over the transaction intent message.
    pub user_signature: Signature,
    // Memoized `flag || bcs(self)` encoding backing `AsRef<[u8]>`; excluded from serde.
    #[serde(skip)]
    pub bytes: OnceCell<Vec<u8>>,
}
39
/// A helper struct that contains the necessary fields to calculate the caching key.
/// If the verify_zk_login() api changes, additional fields must be added here
/// so that a cached verification result is never reused for a signature that differs
/// in a value `verify_zk_login()` depends on.
///
/// The BCS encoding of this struct is exactly what `hash_inputs` hashes, so changing
/// its fields (or their order) changes every cache key.
#[derive(Serialize, Deserialize)]
struct ZkLoginCachingParams {
    /// zkLogin proof inputs passed to `verify_zk_login`.
    inputs: ZkLoginInputs,
    /// Max epoch passed to `verify_zk_login`.
    max_epoch: EpochId,
    /// Ephemeral public key encoded as `flag || pk_bytes`, passed to `verify_zk_login`.
    extended_pk_bytes: Vec<u8>,
}
49
50impl ZkLoginAuthenticator {
51    /// The caching key for zklogin signature, it is the hash of bcs bytes of
52    /// ZkLoginInputs || max_epoch || flagged_pk_bytes. If any of these fields
53    /// change, zklogin signature is re-verified without using the caching result.
54    fn get_caching_params(&self) -> ZkLoginCachingParams {
55        let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
56        extended_pk_bytes.extend(self.user_signature.public_key_bytes());
57        ZkLoginCachingParams {
58            inputs: self.inputs.clone(),
59            max_epoch: self.max_epoch,
60            extended_pk_bytes,
61        }
62    }
63
64    pub fn hash_inputs(&self) -> ZKLoginInputsDigest {
65        use fastcrypto::hash::HashFunction;
66        let mut hasher = DefaultHash::default();
67        bcs::serialize_into(&mut hasher, &self.get_caching_params())
68            .expect("serde should not fail");
69        ZKLoginInputsDigest::new(hasher.finalize().into())
70    }
71
72    /// Create a new [struct ZkLoginAuthenticator] with necessary fields.
73    pub fn new(inputs: ZkLoginInputs, max_epoch: EpochId, user_signature: Signature) -> Self {
74        Self {
75            inputs,
76            max_epoch,
77            user_signature,
78            bytes: OnceCell::new(),
79        }
80    }
81
82    pub fn get_pk(&self) -> SuiResult<PublicKey> {
83        PublicKey::from_zklogin_inputs(&self.inputs)
84    }
85
86    pub fn get_iss(&self) -> &str {
87        self.inputs.get_iss()
88    }
89
90    pub fn get_max_epoch(&self) -> EpochId {
91        self.max_epoch
92    }
93
94    pub fn user_signature_mut_for_testing(&mut self) -> &mut Signature {
95        &mut self.user_signature
96    }
97    pub fn max_epoch_mut_for_testing(&mut self) -> &mut EpochId {
98        &mut self.max_epoch
99    }
100    pub fn zk_login_inputs_mut_for_testing(&mut self) -> &mut ZkLoginInputs {
101        &mut self.inputs
102    }
103}
104
105/// Necessary trait for [struct SenderSignedData].
106impl PartialEq for ZkLoginAuthenticator {
107    fn eq(&self, other: &Self) -> bool {
108        self.as_ref() == other.as_ref()
109    }
110}
111
/// Necessary trait for [struct SenderSignedData].
// `PartialEq` compares the `AsRef<[u8]>` byte encodings, which is a full equivalence
// relation, so the `Eq` marker is sound.
impl Eq for ZkLoginAuthenticator {}
114
115/// Necessary trait for [struct SenderSignedData].
116impl Hash for ZkLoginAuthenticator {
117    fn hash<H: Hasher>(&self, state: &mut H) {
118        self.as_ref().hash(state);
119    }
120}
121
122impl AuthenticatorTrait for ZkLoginAuthenticator {
123    fn verify_user_authenticator_epoch(
124        &self,
125        epoch: EpochId,
126        max_epoch_upper_bound_delta: Option<u64>,
127    ) -> SuiResult {
128        // the checks here ensure that `current_epoch + max_epoch_upper_bound_delta >= self.max_epoch >= current_epoch`.
129        // 1. if the config for upper bound is set, ensure that the max epoch in signature is not larger than epoch + upper_bound.
130        if let Some(delta) = max_epoch_upper_bound_delta {
131            let max_epoch_upper_bound = epoch + delta;
132            if self.get_max_epoch() > max_epoch_upper_bound {
133                return Err(SuiErrorKind::InvalidSignature {
134                    error: format!(
135                        "ZKLogin max epoch too large {}, current epoch {}, max accepted: {}",
136                        self.get_max_epoch(),
137                        epoch,
138                        max_epoch_upper_bound
139                    ),
140                }
141                .into());
142            }
143        }
144        // 2. ensure that max epoch in signature is greater than the current epoch.
145        if epoch > self.get_max_epoch() {
146            return Err(SuiErrorKind::InvalidSignature {
147                error: format!(
148                    "ZKLogin expired at epoch {}, current epoch {}",
149                    self.get_max_epoch(),
150                    epoch
151                ),
152            }
153            .into());
154        }
155        Ok(())
156    }
157
158    /// Verify an intent message of a transaction with an zk login authenticator.
159    fn verify_claims<T>(
160        &self,
161        intent_msg: &IntentMessage<T>,
162        author: SuiAddress,
163        aux_verify_data: &VerifyParams,
164        zklogin_inputs_cache: Arc<VerifiedDigestCache<ZKLoginInputsDigest>>,
165    ) -> SuiResult
166    where
167        T: Serialize,
168    {
169        // Always evaluate the unpadded address derivation.
170        if author != SuiAddress::try_from_unpadded(&self.inputs)? {
171            // If the verify_legacy_zklogin_address flag is set, also evaluate the padded address derivation.
172            if !aux_verify_data.verify_legacy_zklogin_address
173                || author != SuiAddress::try_from_padded(&self.inputs)?
174            {
175                return Err(SuiErrorKind::InvalidAddress.into());
176            }
177        }
178
179        // Only when supported_providers list is not empty, we check if the provider is supported. Otherwise,
180        // we just use the JWK map to check if its supported.
181        if !aux_verify_data.supported_providers.is_empty()
182            && !aux_verify_data.supported_providers.contains(
183                &OIDCProvider::from_iss(self.inputs.get_iss()).map_err(|_| {
184                    SuiErrorKind::InvalidSignature {
185                        error: "Unknown provider".to_string(),
186                    }
187                })?,
188            )
189        {
190            return Err(SuiErrorKind::InvalidSignature {
191                error: format!("OIDC provider not supported: {}", self.inputs.get_iss()),
192            }
193            .into());
194        }
195
196        // Verify the ephemeral signature over the intent message of the transaction data.
197        self.user_signature.verify_secure(
198            intent_msg,
199            author,
200            SignatureScheme::ZkLoginAuthenticator,
201        )?;
202
203        if zklogin_inputs_cache.is_cached(&self.hash_inputs()) {
204            // If the zklogin inputs hits the cache, we don't need to verify the zklogin
205            // again that contains the heavy computation.
206            Ok(())
207        } else {
208            // if it is not cached, we verify the full zklogin inputs.
209            // build extended_pk_bytes as flag || pk_bytes.
210            let mut extended_pk_bytes = vec![self.user_signature.scheme().flag()];
211            extended_pk_bytes.extend(self.user_signature.public_key_bytes());
212            let res = verify_zklogin_inputs_wrapper(
213                self.get_caching_params(),
214                &aux_verify_data.oidc_provider_jwks,
215                &aux_verify_data.zk_login_env,
216            )
217            .map_err(|e| {
218                SuiErrorKind::InvalidSignature {
219                    error: e.to_string(),
220                }
221                .into()
222            });
223            match res {
224                Ok(_) => {
225                    // If it's verified ok, we cache the digest.
226                    zklogin_inputs_cache.cache_digest(self.hash_inputs());
227                    Ok(())
228                }
229                Err(e) => Err(e),
230            }
231        }
232    }
233}
234
235fn verify_zklogin_inputs_wrapper(
236    params: ZkLoginCachingParams,
237    all_jwk: &im::HashMap<JwkId, JWK>,
238    env: &ZkLoginEnv,
239) -> SuiResult<()> {
240    verify_zk_login(
241        &params.inputs,
242        params.max_epoch,
243        &params.extended_pk_bytes,
244        all_jwk,
245        env,
246    )
247    .map_err(|e| {
248        SuiErrorKind::InvalidSignature {
249            error: e.to_string(),
250        }
251        .into()
252    })
253}
254
255impl ToFromBytes for ZkLoginAuthenticator {
256    fn from_bytes(bytes: &[u8]) -> Result<Self, FastCryptoError> {
257        // The first byte matches the flag of MultiSig.
258        if bytes.first().ok_or(FastCryptoError::InvalidInput)?
259            != &SignatureScheme::ZkLoginAuthenticator.flag()
260        {
261            return Err(FastCryptoError::InvalidInput);
262        }
263        let mut zk_login: ZkLoginAuthenticator =
264            bcs::from_bytes(&bytes[1..]).map_err(|_| FastCryptoError::InvalidSignature)?;
265        zk_login.inputs.init()?;
266        Ok(zk_login)
267    }
268}
269
270impl AsRef<[u8]> for ZkLoginAuthenticator {
271    fn as_ref(&self) -> &[u8] {
272        self.bytes
273            .get_or_try_init::<_, eyre::Report>(|| {
274                let as_bytes = bcs::to_bytes(self).expect("BCS serialization should not fail");
275                let mut bytes = Vec::with_capacity(1 + as_bytes.len());
276                bytes.push(SignatureScheme::ZkLoginAuthenticator.flag());
277                bytes.extend_from_slice(as_bytes.as_slice());
278                Ok(bytes)
279            })
280            .expect("OnceCell invariant violated")
281    }
282}
283
/// A 32-byte address seed, stored as a big-endian unsigned integer.
#[derive(Debug, Clone)]
pub struct AddressSeed([u8; 32]);

impl AddressSeed {
    /// Returns the seed bytes with all leading zero bytes stripped.
    ///
    /// A zero seed yields its final byte (`[0]`), so the result is never empty.
    pub fn unpadded(&self) -> &[u8] {
        let last = self.0.len() - 1;
        let start = self.0.iter().position(|&byte| byte != 0).unwrap_or(last);
        &self.0[start..]
    }

    /// Returns all 32 bytes, leading zeros included.
    pub fn padded(&self) -> &[u8] {
        self.0.as_slice()
    }
}
303
304impl std::fmt::Display for AddressSeed {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        let big_int = num_bigint::BigUint::from_bytes_be(&self.0);
307        let radix10 = big_int.to_str_radix(10);
308        f.write_str(&radix10)
309    }
310}
311
312#[derive(thiserror::Error, Debug)]
313pub enum AddressSeedParseError {
314    #[error("unable to parse radix10 encoded value `{0}`")]
315    Parse(#[from] num_bigint::ParseBigIntError),
316    #[error("larger than 32 bytes")]
317    TooBig,
318}
319
320impl std::str::FromStr for AddressSeed {
321    type Err = AddressSeedParseError;
322
323    fn from_str(s: &str) -> Result<Self, Self::Err> {
324        let big_int = <num_bigint::BigUint as num_traits::Num>::from_str_radix(s, 10)?;
325        let be_bytes = big_int.to_bytes_be();
326        let len = be_bytes.len();
327        let mut buf = [0; 32];
328
329        if len > 32 {
330            return Err(AddressSeedParseError::TooBig);
331        }
332
333        buf[32 - len..].copy_from_slice(&be_bytes);
334        Ok(Self(buf))
335    }
336}
337
338// AddressSeed's serialized format is as a radix10 encoded string
339impl Serialize for AddressSeed {
340    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
341    where
342        S: serde::Serializer,
343    {
344        self.to_string().serialize(serializer)
345    }
346}
347
348impl<'de> Deserialize<'de> for AddressSeed {
349    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
350    where
351        D: serde::Deserializer<'de>,
352    {
353        let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
354        std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
355    }
356}
357
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::AddressSeed;
    use num_bigint::BigUint;
    use proptest::prelude::*;

    /// `unpadded` drops leading zero bytes but keeps one byte for the zero seed.
    #[test]
    fn unpadded_slice() {
        let zero_seed = AddressSeed([0; 32]);
        assert_eq!(zero_seed.unpadded(), [0u8].as_slice());

        let mut raw = [1u8; 32];
        raw[0] = 0;
        let seed = AddressSeed(raw);
        assert_eq!(seed.unpadded(), [1u8; 31].as_slice());
    }

    proptest! {
        // Values wider than 32 bytes may be rejected, but parsing must never panic.
        #[test]
        fn dont_crash_on_large_inputs(
            raw in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let decimal = BigUint::from_bytes_be(&raw).to_str_radix(10);

            // Only the absence of a panic matters here.
            let _ = AddressSeed::from_str(&decimal);
        }

        // Any value of at most 32 bytes round-trips through `FromStr` and `Display`.
        #[test]
        fn valid_address_seeds(
            raw in proptest::collection::vec(any::<u8>(), 1..=32)
        ) {
            let decimal = BigUint::from_bytes_be(&raw).to_str_radix(10);

            let seed = AddressSeed::from_str(&decimal).unwrap();
            assert_eq!(decimal, seed.to_string());
            // `unpadded` must not panic for any valid seed.
            let _ = seed.unpadded();
        }
    }
}