Skip to main content

sui_sdk_types/crypto/
zklogin.rs

1use super::SimpleSignature;
2use crate::checkpoint::EpochId;
3use crate::u256::U256;
4
5/// A zklogin authenticator
6///
7/// # BCS
8///
9/// The BCS serialized form for this type is defined by the following ABNF:
10///
11/// ```text
12/// zklogin-bcs = bytes             ; contents are defined by <zklogin-authenticator>
13/// zklogin     = zklogin-flag
14///               zklogin-inputs
15///               u64               ; max epoch
16///               simple-signature    
17/// ```
18///
19/// Note: Due to historical reasons, signatures are serialized slightly different from the majority
20/// of the types in Sui. In particular if a signature is ever embedded in another structure it
21/// generally is serialized as `bytes` meaning it has a length prefix that defines the length of
22/// the completely serialized signature.
23#[derive(Debug, Clone, PartialEq, Eq)]
24#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
25pub struct ZkLoginAuthenticator {
26    /// Zklogin proof and inputs required to perform proof verification.
27    pub inputs: ZkLoginInputs,
28
29    /// Maximum epoch for which the proof is valid.
30    pub max_epoch: EpochId,
31
32    /// User signature with the pubkey attested to by the provided proof.
33    pub signature: SimpleSignature,
34}
35
36/// A zklogin groth16 proof and the required inputs to perform proof verification.
37///
38/// # BCS
39///
40/// The BCS serialized form for this type is defined by the following ABNF:
41///
42/// ```text
43/// zklogin-inputs = zklogin-proof
44///                  zklogin-claim
45///                  string              ; base64url-unpadded encoded JwtHeader
46///                  bn254-field-element ; address_seed
47/// ```
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct ZkLoginInputs {
50    proof_points: ZkLoginProof,
51    iss_base64_details: ZkLoginClaim,
52    header_base64: String,
53
54    jwt_header: JwtHeader,
55    jwk_id: JwkId,
56    public_identifier: ZkLoginPublicIdentifier,
57}
58
59impl ZkLoginInputs {
60    #[cfg(feature = "serde")]
61    #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
62    pub fn new(
63        proof_points: ZkLoginProof,
64        iss_base64_details: ZkLoginClaim,
65        header_base64: String,
66        address_seed: Bn254FieldElement,
67    ) -> Result<Self, InvalidZkLoginAuthenticatorError> {
68        let iss = {
69            const ISS: &str = "iss";
70
71            let iss = iss_base64_details.verify_extended_claim(ISS)?;
72
73            if iss.len() > 255 {
74                return Err(InvalidZkLoginAuthenticatorError::new(
75                    "invalid iss: too long",
76                ));
77            }
78            iss
79        };
80
81        let jwt_header = JwtHeader::from_base64(&header_base64)?;
82        let jwk_id = JwkId {
83            iss: iss.clone(),
84            kid: jwt_header.kid.clone(),
85        };
86
87        let public_identifier = ZkLoginPublicIdentifier { iss, address_seed };
88
89        Ok(Self {
90            proof_points,
91            iss_base64_details,
92            header_base64,
93            jwt_header,
94            jwk_id,
95            public_identifier,
96        })
97    }
98
99    pub fn proof_points(&self) -> &ZkLoginProof {
100        &self.proof_points
101    }
102
103    pub fn iss_base64_details(&self) -> &ZkLoginClaim {
104        &self.iss_base64_details
105    }
106
107    pub fn header_base64(&self) -> &str {
108        &self.header_base64
109    }
110
111    pub fn address_seed(&self) -> &Bn254FieldElement {
112        &self.public_identifier.address_seed
113    }
114
115    pub fn jwk_id(&self) -> &JwkId {
116        &self.jwk_id
117    }
118
119    pub fn iss(&self) -> &str {
120        &self.public_identifier.iss
121    }
122
123    pub fn public_identifier(&self) -> &ZkLoginPublicIdentifier {
124        &self.public_identifier
125    }
126}
127
// NOTE(review): this relies on `ZkLoginInputs::new`, which is only compiled with the `serde`
// feature; presumably the `proptest` feature enables `serde` — confirm in Cargo.toml.
#[cfg(feature = "proptest")]
impl proptest::arbitrary::Arbitrary for ZkLoginInputs {
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // TODO implement Arbitrary for real for ZkLoginClaim and header_base64 values.
        // Until then a fixed, known-valid iss claim and JWT header are used.
        const ISS_CLAIM_VALUE: &str = "wiaXNzIjoiaHR0cHM6Ly9pZC50d2l0Y2gudHYvb2F1dGgyIiw";
        const ISS_CLAIM_INDEX_MOD_4: u8 = 2;
        const HEADER_BASE64: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ";

        let build = |(proof_points, address_seed)| {
            let claim = ZkLoginClaim {
                value: ISS_CLAIM_VALUE.to_owned(),
                index_mod_4: ISS_CLAIM_INDEX_MOD_4,
            };
            Self::new(proof_points, claim, HEADER_BASE64.to_owned(), address_seed).unwrap()
        };

        (any::<ZkLoginProof>(), any::<Bn254FieldElement>())
            .prop_map(build)
            .boxed()
    }
}
155
/// A claim from a JWT payload (used here for the `iss` claim), in partially encoded form.
///
/// `value` is a base64url fragment cut from the JWT payload. `index_mod_4` is the fragment's
/// starting position modulo 4 in the full encoding; it is used to realign the fragment to byte
/// boundaries when decoding.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-claim = string u8 ; value, index_mod_4
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginClaim {
    /// Base64url fragment of the JWT payload.
    pub value: String,
    /// Starting offset of `value` within its 4-character base64 group.
    pub index_mod_4: u8,
}
175
/// Error returned when zklogin inputs fail validation.
///
/// This covers a malformed iss claim, an unusable JWT header, or an over-long issuer. The
/// wrapped string is a human-readable description of the failure.
// NOTE(review): the `Display` prefix says "claim" even for JWT-header failures; kept as-is
// because callers may match on it.
#[derive(Debug)]
pub struct InvalidZkLoginAuthenticatorError(String);

#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl InvalidZkLoginAuthenticatorError {
    /// Wraps any string-like description of the failure.
    fn new<T: Into<String>>(err: T) -> Self {
        let message: String = err.into();
        Self(message)
    }
}

impl std::fmt::Display for InvalidZkLoginAuthenticatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        write!(f, "invalid zklogin claim: {message}")
    }
}

impl std::error::Error for InvalidZkLoginAuthenticatorError {}
194
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl ZkLoginClaim {
    /// Decodes this claim fragment and returns the string stored under `expected_key`.
    ///
    /// Decoding proceeds in four steps:
    /// 1. `self.value` is treated as a base64url fragment that started at position
    ///    `self.index_mod_4` (mod 4) of the full encoding, and is realigned to whole bytes.
    /// 2. The bytes are decoded as UTF-8.
    /// 3. The text must end with `}` or `,`. That final character is dropped and the rest is
    ///    wrapped in `{...}`.
    /// 4. The result is parsed as a JSON object, and the string value under `expected_key` is
    ///    returned.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidZkLoginAuthenticatorError`] in any of these cases:
    /// - the fragment is shorter than two characters;
    /// - it contains a character outside the base64url alphabet;
    /// - `index_mod_4` is not 0, 1 or 2, or the offset/length combination is invalid;
    /// - the decoded bytes are not UTF-8;
    /// - the text does not end with `}` or `,`;
    /// - the wrapped text is not valid JSON;
    /// - there is no string value under `expected_key`.
    fn verify_extended_claim(
        &self,
        expected_key: &str,
    ) -> Result<String, InvalidZkLoginAuthenticatorError> {
        /// Map a base64url string to a bit array.
        ///
        /// Each character's index in the alphabet is emitted as 6 bits, most significant
        /// first, one bit per `u8` element. Returns an [`InvalidZkLoginAuthenticatorError`] if
        /// a character is not in the base64url charset.
        fn base64_to_bitarray(input: &str) -> Result<Vec<u8>, InvalidZkLoginAuthenticatorError> {
            use itertools::Itertools;

            const BASE64_URL_CHARSET: &str =
                "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

            input
                .chars()
                .map(|c| {
                    BASE64_URL_CHARSET
                        // The charset is ASCII, so the byte offset returned by `find` is the
                        // character's 6-bit value (< 64, so the `as u8` cannot truncate).
                        .find(c)
                        .map(|index| index as u8)
                        .map(|index| (0..6).rev().map(move |i| (index >> i) & 1))
                        .ok_or_else(|| {
                            InvalidZkLoginAuthenticatorError::new(
                                "base64_to_bitarray invalid input",
                            )
                        })
                })
                .flatten_ok()
                .collect()
        }

        /// Convert a bit array (one bit per `u8`) to bytes.
        ///
        /// Each group of 8 bits becomes one byte in big-endian order. Fails if the number of
        /// bits is not a multiple of 8.
        fn bitarray_to_bytearray(bits: &[u8]) -> Result<Vec<u8>, InvalidZkLoginAuthenticatorError> {
            #[expect(clippy::manual_is_multiple_of)]
            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginAuthenticatorError::new(
                    "bitarray_to_bytearray invalid input",
                ));
            }
            Ok(bits
                .chunks(8)
                .map(|chunk| {
                    let mut byte = 0u8;
                    for (i, bit) in chunk.iter().rev().enumerate() {
                        byte |= bit << i;
                    }
                    byte
                })
                .collect())
        }

        /// Decode a base64url fragment that started at position `index_mod_4` (mod 4) of the
        /// full encoding.
        ///
        /// The bits that belong to neighbouring bytes are dropped at both ends, and the
        /// remaining bytes are returned as a UTF-8 string.
        fn decode_base64_url(
            s: &str,
            index_mod_4: u8,
        ) -> Result<String, InvalidZkLoginAuthenticatorError> {
            if s.len() < 2 {
                return Err(InvalidZkLoginAuthenticatorError::new(
                    "Base64 string smaller than 2",
                ));
            }
            let mut bits = base64_to_bitarray(s)?;
            match index_mod_4 {
                0 => {}
                1 => {
                    bits.drain(..2);
                }
                2 => {
                    bits.drain(..4);
                }
                _ => {
                    return Err(InvalidZkLoginAuthenticatorError::new(
                        "Invalid first_char_offset",
                    ));
                }
            }

            // Compute the offset in `usize` so that an `s.len()` past `u8::MAX` cannot wrap to
            // a small value (or underflow when combined with the `- 1`). `s.len() >= 2` was
            // checked above, so the subtraction never underflows. At least 12 bits were
            // decoded and at most 4 dropped so far, so the drains below cannot underflow either.
            let last_char_offset = (usize::from(index_mod_4) + s.len() - 1) % 4;
            match last_char_offset {
                3 => {}
                2 => {
                    bits.drain(bits.len() - 2..);
                }
                1 => {
                    bits.drain(bits.len() - 4..);
                }
                _ => {
                    return Err(InvalidZkLoginAuthenticatorError::new(
                        "Invalid last_char_offset",
                    ));
                }
            }

            // Defensive check with a dedicated message; `bitarray_to_bytearray` enforces the
            // same invariant.
            #[expect(clippy::manual_is_multiple_of)]
            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginAuthenticatorError::new("Invalid bits length"));
            }

            // `String::from_utf8` validates in place instead of copying a borrowed `&str`.
            String::from_utf8(bitarray_to_bytearray(&bits)?)
                .map_err(|_| InvalidZkLoginAuthenticatorError::new("Invalid UTF8 string"))
        }

        let extended_claim = decode_base64_url(&self.value, self.index_mod_4)?;

        // Last character of each extracted_claim must be '}' or ','; strip it to keep just the
        // single `"key":value` member.
        let Some(member) = extended_claim
            .strip_suffix('}')
            .or_else(|| extended_claim.strip_suffix(','))
        else {
            return Err(InvalidZkLoginAuthenticatorError::new(
                "Invalid extended claim",
            ));
        };

        let json_str = format!("{{{member}}}");

        serde_json::from_str::<serde_json::Value>(&json_str)
            .map_err(|e| InvalidZkLoginAuthenticatorError::new(e.to_string()))?
            .as_object_mut()
            .and_then(|o| o.get_mut(expected_key))
            .map(serde_json::Value::take)
            .and_then(|v| match v {
                serde_json::Value::String(s) => Some(s),
                _ => None,
            })
            .ok_or_else(|| InvalidZkLoginAuthenticatorError::new("invalid extended claim"))
    }
}
325
/// A standard JWT header, as described by
/// <https://openid.net/specs/openid-connect-core-1_0.html>.
///
/// Only the `alg`, `kid` and `typ` members are retained.
#[derive(Debug, Clone, PartialEq, Eq)]
struct JwtHeader {
    /// Signing algorithm; `from_base64` only accepts `RS256`.
    alg: String,
    /// Key id of the signing key.
    kid: String,
    /// Optional token type.
    typ: Option<String>,
}
334
335impl JwtHeader {
336    #[cfg(feature = "serde")]
337    fn from_base64(s: &str) -> Result<Self, InvalidZkLoginAuthenticatorError> {
338        use base64ct::Base64UrlUnpadded;
339        use base64ct::Encoding;
340
341        #[derive(serde_derive::Serialize, serde_derive::Deserialize)]
342        struct Header {
343            alg: String,
344            kid: String,
345            #[serde(skip_serializing_if = "Option::is_none")]
346            typ: Option<String>,
347        }
348
349        let header_bytes = Base64UrlUnpadded::decode_vec(s)
350            .map_err(|e| InvalidZkLoginAuthenticatorError::new(format!("invalid base64: {e}")))?;
351        let Header { alg, kid, typ } = serde_json::from_slice(&header_bytes)
352            .map_err(|e| InvalidZkLoginAuthenticatorError::new(format!("invalid json: {e}")))?;
353        if alg != "RS256" {
354            return Err(InvalidZkLoginAuthenticatorError::new(
355                "jwt alg must be RS256",
356            ));
357        }
358        Ok(Self { alg, kid, typ })
359    }
360}
361
362/// A zklogin groth16 proof
363///
364/// # BCS
365///
366/// The BCS serialized form for this type is defined by the following ABNF:
367///
368/// ```text
369/// zklogin-proof = circom-g1 circom-g2 circom-g1
370/// ```
371#[derive(Debug, Clone, PartialEq, Eq)]
372#[cfg_attr(
373    feature = "serde",
374    derive(serde_derive::Serialize, serde_derive::Deserialize)
375)]
376#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
377pub struct ZkLoginProof {
378    pub a: CircomG1,
379    pub b: CircomG2,
380    pub c: CircomG1,
381}
382
383/// A G1 point
384///
385/// This represents the canonical decimal representation of the projective coordinates in Fq.
386///
387/// # BCS
388///
389/// The BCS serialized form for this type is defined by the following ABNF:
390///
391/// ```text
392/// circom-g1 = %x03 3(bn254-field-element)
393/// ```
394#[derive(Clone, Debug, PartialEq, Eq)]
395#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
396pub struct CircomG1(pub [Bn254FieldElement; 3]);
397
398/// A G2 point
399///
400/// This represents the canonical decimal representation of the coefficients of the projective
401/// coordinates in Fq2.
402///
403/// # BCS
404///
405/// The BCS serialized form for this type is defined by the following ABNF:
406///
407/// ```text
408/// circom-g2 = %x03 3(%x02 2(bn254-field-element))
409/// ```
410#[derive(Clone, Debug, PartialEq, Eq)]
411#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
412pub struct CircomG2(pub [[Bn254FieldElement; 2]; 3]);
413
414/// Public Key equivalent for Zklogin authenticators
415///
416/// A `ZkLoginPublicIdentifier` is the equivalent of a public key for other account authenticators,
417/// and contains the information required to derive the onchain account [`Address`] for a Zklogin
418/// authenticator.
419///
420/// ## Note
421///
422/// Due to a historical bug that was introduced in the Sui Typescript SDK when the zklogin
423/// authenticator was first introduced, there are now possibly two "valid" addresses for each
424/// zklogin authenticator depending on the bit-pattern of the `address_seed` value.
425///
426/// The original bug incorrectly derived a zklogin's address by stripping any leading
427/// zero-bytes that could have been present in the 32-byte length `address_seed` value prior to
428/// hashing, leading to a different derived address. This incorrectly derived address was
429/// presented to users of various wallets, leading them to sending funds to these addresses
430/// that they couldn't access. Instead of letting these users lose any assets that were sent to
431/// these addresses, the Sui network decided to change the protocol to allow for a zklogin
432/// authenticator who's `address_seed` value had leading zero-bytes be authorized to sign for
433/// both the addresses derived from both the unpadded and padded `address_seed` value.
434///
435/// # BCS
436///
437/// The BCS serialized form for this type is defined by the following ABNF:
438///
439/// ```text
440/// zklogin-public-identifier-bcs = bytes ; where the contents are defined by
441///                                       ; <zklogin-public-identifier>
442///
443/// zklogin-public-identifier = zklogin-public-identifier-iss
444///                             address-seed
445///
446/// zklogin-public-identifier-unpadded = zklogin-public-identifier-iss
447///                                      address-seed-unpadded
448///
449/// ; The iss, or issuer, is a utf8 string that is less than 255 bytes long
450/// ; and is serialized with the iss's length in bytes as a u8 followed by
451/// ; the bytes of the iss
452/// zklogin-public-identifier-iss = u8 *255(OCTET)
453///
454/// ; A Bn254FieldElement serialized as a 32-byte big-endian value
455/// address-seed = 32(OCTET)
456///
457/// ; A Bn254FieldElement serialized as a 32-byte big-endian value
458/// ; with any leading zero bytes stripped
459/// address-seed-unpadded = %x00 / %x01-ff *31(OCTET)
460/// ```
461///
462/// [`Address`]: crate::Address
463#[derive(Clone, Debug, PartialEq, Eq)]
464#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
465pub struct ZkLoginPublicIdentifier {
466    iss: String,
467    address_seed: Bn254FieldElement,
468}
469
470impl ZkLoginPublicIdentifier {
471    pub fn new(iss: String, address_seed: Bn254FieldElement) -> Option<Self> {
472        if iss.len() > 255 {
473            None
474        } else {
475            Some(Self { iss, address_seed })
476        }
477    }
478
479    pub fn iss(&self) -> &str {
480        &self.iss
481    }
482
483    pub fn address_seed(&self) -> &Bn254FieldElement {
484        &self.address_seed
485    }
486}
487
/// A JSON Web Key (JWK).
///
/// Holds the RSA public-key material for one key of an OIDC provider. Providers publish a
/// list of these, one per key id, at their JWK endpoint (e.g.
/// <https://www.googleapis.com/oauth2/v3/certs>). A JWK is what verifies a JWT.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// jwk = string string string string ; kty, e, n, alg
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Jwk {
    /// Key type, see <https://datatracker.ietf.org/doc/html/rfc7517#section-4.1>.
    pub kty: String,

    /// RSA public exponent, see <https://datatracker.ietf.org/doc/html/rfc7517#section-9.3>.
    pub e: String,

    /// RSA modulus, see <https://datatracker.ietf.org/doc/html/rfc7517#section-9.3>.
    pub n: String,

    /// Algorithm, see <https://datatracker.ietf.org/doc/html/rfc7517#section-4.4>.
    pub alg: String,
}
520
/// Key that uniquely identifies a [`Jwk`]: the issuing provider plus the key id.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// jwk-id = string string ; iss, kid
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct JwkId {
    /// The issuer or identity of the OIDC provider.
    pub iss: String,

    /// A key id used to uniquely identify a key from an OIDC provider.
    pub kid: String,
}
543
/// An element of a BN254 field, stored in memory as a 32-byte big-endian integer.
///
/// This is a single 256-bit value, not a curve point. It is used for the projective
/// coordinates of [`CircomG1`] / [`CircomG2`] points and for the zklogin `address_seed`.
/// It is generally represented in radix10 when a human-readable display format is needed.
///
/// `new`, `from_str_radix_10` and `FromStr` do not reduce the value or check it against a
/// field modulus.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// bn254-field-element = string ; the canonical radix10 encoding (no leading zeros)
///                              ; of the 32-byte big-endian value
/// ```
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Bn254FieldElement([u8; 32]);
560
561impl Bn254FieldElement {
562    pub const fn new(bytes: [u8; 32]) -> Self {
563        Self(bytes)
564    }
565
566    pub const fn from_str_radix_10(s: &str) -> Result<Self, Bn254FieldElementParseError> {
567        let u256 = match U256::from_str_radix(s, 10) {
568            Ok(u256) => u256,
569            Err(e) => return Err(Bn254FieldElementParseError(e)),
570        };
571        let be = u256.to_be();
572        Ok(Self(*be.digits()))
573    }
574
575    pub fn unpadded(&self) -> &[u8] {
576        let mut buf = self.0.as_slice();
577
578        while !buf.is_empty() && buf[0] == 0 {
579            buf = &buf[1..];
580        }
581
582        // If the value is '0' then just return a slice of length 1 of the final byte
583        if buf.is_empty() { &self.0[31..] } else { buf }
584    }
585
586    pub fn padded(&self) -> &[u8] {
587        &self.0
588    }
589}
590
591impl std::fmt::Display for Bn254FieldElement {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        let u256 = U256::from_be(U256::from_digits(self.0));
594        let radix10 = u256.to_str_radix(10);
595        f.write_str(&radix10)
596    }
597}
598
599#[derive(Debug)]
600pub struct Bn254FieldElementParseError(crate::U256ParseError);
601
602impl std::fmt::Display for Bn254FieldElementParseError {
603    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
604        write!(f, "unable to parse radix10 encoded value {}", self.0)
605    }
606}
607
608impl std::error::Error for Bn254FieldElementParseError {}
609
610impl std::str::FromStr for Bn254FieldElement {
611    type Err = Bn254FieldElementParseError;
612
613    fn from_str(s: &str) -> Result<Self, Self::Err> {
614        let u256 = U256::from_str_radix(s, 10).map_err(Bn254FieldElementParseError)?;
615        let be = u256.to_be();
616        Ok(Self(*be.digits()))
617    }
618}
619
#[cfg(test)]
mod test {
    use super::Bn254FieldElement;
    use num_bigint::BigUint;
    use proptest::prelude::*;
    use std::str::FromStr;
    use test_strategy::proptest;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn unpadded_slice() {
        let seed = Bn254FieldElement([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = Bn254FieldElement([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    #[proptest]
    fn dont_crash_on_large_inputs(
        #[strategy(proptest::collection::vec(any::<u8>(), 33..1024))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        // doesn't crash
        let _ = Bn254FieldElement::from_str(&radix10);
    }

    #[proptest]
    fn valid_address_seeds(
        #[strategy(proptest::collection::vec(any::<u8>(), 1..=32))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        let seed = Bn254FieldElement::from_str(&radix10).unwrap();
        assert_eq!(radix10, seed.to_string());
        // Ensure unpadded doesn't crash
        seed.unpadded();
    }

    // Regression test: BCS deserialization used to call
    // `DisplayFromStr::deserialize_as` directly, which accepts any
    // radix10 string `bnum::U256::from_str_radix` would parse —
    // including encodings with leading zeros like `"007"`. Two
    // distinct BCS byte strings (`0x01 0x37` and `0x03 0x30 0x30 0x37`)
    // therefore decoded to the same `Bn254FieldElement`, breaking the
    // canonicality invariant that downstream signature deduplication
    // and digesting rely on. The deserializer must now reject any
    // encoding that does not round-trip through `Display`.
    #[cfg(feature = "serde")]
    #[test]
    fn bcs_rejects_non_canonical_radix10_encoding() {
        let canonical = bcs::to_bytes("7").unwrap();
        let leading_zero = bcs::to_bytes("007").unwrap();
        assert_ne!(canonical, leading_zero);

        let parsed: Bn254FieldElement = bcs::from_bytes(&canonical).unwrap();
        assert_eq!(parsed.to_string(), "7");

        let err = bcs::from_bytes::<Bn254FieldElement>(&leading_zero).unwrap_err();
        assert!(
            err.to_string().contains("non-canonical"),
            "unexpected error: {err}"
        );
    }

    // Regression test: `decode_base64_url` used to compute
    // `(index_mod_4 + s.len() as u8 - 1) % 4` in `u8`. With an
    // attacker-supplied JWT claim value of length 256 the cast
    // truncates `s.len()` to `0` and the subtraction underflows,
    // panicking under `debug_assertions`. The arithmetic must now use
    // `usize` so that the function returns a structured error rather
    // than crashing the caller.
    #[cfg(feature = "serde")]
    #[test]
    fn long_claim_value_does_not_panic_on_u8_overflow() {
        use super::ZkLoginClaim;

        let claim = ZkLoginClaim {
            value: "A".repeat(256),
            index_mod_4: 0,
        };
        assert!(claim.verify_extended_claim("iss").is_err());
    }

    // Positive path for `verify_extended_claim`: the fragment below (also used by the
    // proptest `Arbitrary` impl for `ZkLoginInputs`) starts at offset 2 of a base64 group and
    // decodes to `"iss":"https://id.twitch.tv/oauth2",`.
    #[cfg(feature = "serde")]
    #[test]
    fn verify_extended_claim_decodes_iss() {
        use super::ZkLoginClaim;

        let claim = ZkLoginClaim {
            value: "wiaXNzIjoiaHR0cHM6Ly9pZC50d2l0Y2gudHYvb2F1dGgyIiw".to_owned(),
            index_mod_4: 2,
        };
        assert_eq!(
            claim.verify_extended_claim("iss").unwrap(),
            "https://id.twitch.tv/oauth2"
        );
        // A key that is not present in the decoded member is rejected.
        assert!(claim.verify_extended_claim("aud").is_err());
    }

    // Only offsets 0, 1 and 2 are meaningful for a claim fragment.
    #[cfg(feature = "serde")]
    #[test]
    fn verify_extended_claim_rejects_invalid_offset() {
        use super::ZkLoginClaim;

        let claim = ZkLoginClaim {
            value: "wiaXNzIjoiaHR0cHM6Ly9pZC50d2l0Y2gudHYvb2F1dGgyIiw".to_owned(),
            index_mod_4: 3,
        };
        assert!(claim.verify_extended_claim("iss").is_err());
    }
}
711
712#[cfg(feature = "serde")]
713#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
714mod serialization {
715    use crate::SignatureScheme;
716
717    use super::*;
718    use serde::Deserialize;
719    use serde::Deserializer;
720    use serde::Serialize;
721    use serde::Serializer;
722    use serde_with::Bytes;
723    use serde_with::DeserializeAs;
724    use serde_with::SerializeAs;
725    use std::borrow::Cow;
726
727    // Serialized format is: iss_bytes_len || iss_bytes || padded_32_byte_address_seed.
728    impl Serialize for ZkLoginPublicIdentifier {
729        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
730        where
731            S: Serializer,
732        {
733            if serializer.is_human_readable() {
734                #[derive(serde_derive::Serialize)]
735                struct Readable<'a> {
736                    iss: &'a str,
737                    address_seed: &'a Bn254FieldElement,
738                }
739                let readable = Readable {
740                    iss: &self.iss,
741                    address_seed: &self.address_seed,
742                };
743                readable.serialize(serializer)
744            } else {
745                let mut buf = Vec::new();
746                let iss_bytes = self.iss.as_bytes();
747                buf.push(iss_bytes.len() as u8);
748                buf.extend(iss_bytes);
749
750                buf.extend(&self.address_seed.0);
751
752                serializer.serialize_bytes(&buf)
753            }
754        }
755    }
756
757    impl<'de> Deserialize<'de> for ZkLoginPublicIdentifier {
758        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
759        where
760            D: Deserializer<'de>,
761        {
762            if deserializer.is_human_readable() {
763                #[derive(serde_derive::Deserialize)]
764                struct Readable {
765                    iss: String,
766                    address_seed: Bn254FieldElement,
767                }
768
769                let Readable { iss, address_seed } = Deserialize::deserialize(deserializer)?;
770                Self::new(iss, address_seed)
771                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
772            } else {
773                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
774                let iss_len = *bytes
775                    .first()
776                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
777                let iss_bytes = bytes
778                    .get(1..(1 + iss_len as usize))
779                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
780                let iss = std::str::from_utf8(iss_bytes).map_err(serde::de::Error::custom)?;
781                let address_seed_bytes = bytes
782                    .get((1 + iss_len as usize)..)
783                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
784
785                let address_seed = <[u8; 32]>::try_from(address_seed_bytes)
786                    .map_err(serde::de::Error::custom)
787                    .map(Bn254FieldElement)?;
788
789                Self::new(iss.into(), address_seed)
790                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
791            }
792        }
793    }
794
795    #[derive(serde_derive::Serialize)]
796    struct AuthenticatorRef<'a> {
797        inputs: &'a ZkLoginInputs,
798        max_epoch: EpochId,
799        signature: &'a SimpleSignature,
800    }
801
802    #[derive(serde_derive::Deserialize)]
803    struct Authenticator {
804        inputs: ZkLoginInputs,
805        max_epoch: EpochId,
806        signature: SimpleSignature,
807    }
808
809    impl Serialize for ZkLoginAuthenticator {
810        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
811        where
812            S: Serializer,
813        {
814            if serializer.is_human_readable() {
815                let authenticator_ref = AuthenticatorRef {
816                    inputs: &self.inputs,
817                    max_epoch: self.max_epoch,
818                    signature: &self.signature,
819                };
820
821                authenticator_ref.serialize(serializer)
822            } else {
823                let bytes = self.to_bytes();
824                serializer.serialize_bytes(&bytes)
825            }
826        }
827    }
828
829    impl<'de> Deserialize<'de> for ZkLoginAuthenticator {
830        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
831        where
832            D: Deserializer<'de>,
833        {
834            if deserializer.is_human_readable() {
835                let Authenticator {
836                    inputs,
837                    max_epoch,
838                    signature,
839                } = Authenticator::deserialize(deserializer)?;
840                Ok(Self {
841                    inputs,
842                    max_epoch,
843                    signature,
844                })
845            } else {
846                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
847                Self::from_serialized_bytes(bytes)
848            }
849        }
850    }
851
852    impl ZkLoginAuthenticator {
853        pub(crate) fn to_bytes(&self) -> Vec<u8> {
854            let authenticator_ref = AuthenticatorRef {
855                inputs: &self.inputs,
856                max_epoch: self.max_epoch,
857                signature: &self.signature,
858            };
859
860            let mut buf = Vec::new();
861            buf.push(SignatureScheme::ZkLogin as u8);
862
863            bcs::serialize_into(&mut buf, &authenticator_ref).expect("serialization cannot fail");
864            buf
865        }
866
867        pub(crate) fn from_serialized_bytes<T: AsRef<[u8]>, E: serde::de::Error>(
868            bytes: T,
869        ) -> Result<Self, E> {
870            let bytes = bytes.as_ref();
871            let flag = SignatureScheme::from_byte(
872                *bytes
873                    .first()
874                    .ok_or_else(|| serde::de::Error::custom("missing signature scheme flag"))?,
875            )
876            .map_err(serde::de::Error::custom)?;
877            if flag != SignatureScheme::ZkLogin {
878                return Err(serde::de::Error::custom("invalid zklogin flag"));
879            }
880            let bcs_bytes = &bytes[1..];
881
882            let Authenticator {
883                inputs,
884                max_epoch,
885                signature,
886            } = bcs::from_bytes(bcs_bytes).map_err(serde::de::Error::custom)?;
887            Ok(Self {
888                inputs,
889                max_epoch,
890                signature,
891            })
892        }
893    }
894
895    impl Serialize for ZkLoginInputs {
896        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
897        where
898            S: Serializer,
899        {
900            #[derive(serde_derive::Serialize)]
901            struct Inputs<'a> {
902                proof_points: &'a ZkLoginProof,
903                iss_base64_details: &'a ZkLoginClaim,
904                header_base64: &'a str,
905                address_seed: &'a Bn254FieldElement,
906            }
907
908            Inputs {
909                proof_points: self.proof_points(),
910                iss_base64_details: self.iss_base64_details(),
911                header_base64: self.header_base64(),
912                address_seed: self.address_seed(),
913            }
914            .serialize(serializer)
915        }
916    }
917
918    impl<'de> Deserialize<'de> for ZkLoginInputs {
919        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
920        where
921            D: Deserializer<'de>,
922        {
923            #[derive(serde_derive::Deserialize)]
924            struct Inputs {
925                proof_points: ZkLoginProof,
926                iss_base64_details: ZkLoginClaim,
927                header_base64: String,
928                address_seed: Bn254FieldElement,
929            }
930
931            let Inputs {
932                proof_points,
933                iss_base64_details,
934                header_base64,
935                address_seed,
936            } = Inputs::deserialize(deserializer)?;
937            Self::new(
938                proof_points,
939                iss_base64_details,
940                header_base64,
941                address_seed,
942            )
943            .map_err(serde::de::Error::custom)
944        }
945    }
946
    // A `Bn254FieldElement` (e.g. the address seed or a proof-point coordinate) is serialized as its radix10 (decimal) string
948    impl Serialize for Bn254FieldElement {
949        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
950        where
951            S: serde::Serializer,
952        {
953            serde_with::DisplayFromStr::serialize_as(self, serializer)
954        }
955    }
956
957    impl<'de> Deserialize<'de> for Bn254FieldElement {
958        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
959        where
960            D: Deserializer<'de>,
961        {
962            // `Display` strips any leading zeros from the radix10
963            // encoding while `FromStr` (via `bnum::U256::from_str_radix`)
964            // accepts them, so the bare `DisplayFromStr` round-trip is
965            // not canonical: the BCS encodings of e.g. `"7"` and
966            // `"007"` both parse to the same value but differ at the
967            // byte level. Any consumer that keys signature dedup,
968            // replay protection, or content digesting on the BCS
969            // bytes of a `ZkLoginAuthenticator` (which embeds ten
970            // `Bn254FieldElement`s) would treat such pairs as
971            // distinct, so reject any encoding whose `Display`
972            // round-trip differs from the input.
973            let s: std::borrow::Cow<'de, str> = Deserialize::deserialize(deserializer)?;
974            let value = s
975                .parse::<Bn254FieldElement>()
976                .map_err(serde::de::Error::custom)?;
977            if value.to_string() != *s {
978                return Err(serde::de::Error::custom(
979                    "non-canonical Bn254FieldElement encoding",
980                ));
981            }
982            Ok(value)
983        }
984    }
985
986    impl Serialize for CircomG1 {
987        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
988        where
989            S: serde::Serializer,
990        {
991            use serde::ser::SerializeSeq;
992            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
993            for element in &self.0 {
994                seq.serialize_element(element)?;
995            }
996            seq.end()
997        }
998    }
999
1000    impl<'de> Deserialize<'de> for CircomG1 {
1001        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1002        where
1003            D: Deserializer<'de>,
1004        {
1005            let inner = <Vec<_>>::deserialize(deserializer)?;
1006            Ok(Self(inner.try_into().map_err(|_| {
1007                serde::de::Error::custom("expected array of length 3")
1008            })?))
1009        }
1010    }
1011
1012    impl Serialize for CircomG2 {
1013        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1014        where
1015            S: serde::Serializer,
1016        {
1017            use serde::ser::SerializeSeq;
1018
1019            struct Inner<'a>(&'a [Bn254FieldElement; 2]);
1020
1021            impl Serialize for Inner<'_> {
1022                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1023                where
1024                    S: serde::Serializer,
1025                {
1026                    let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
1027                    for element in self.0 {
1028                        seq.serialize_element(element)?;
1029                    }
1030                    seq.end()
1031                }
1032            }
1033
1034            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
1035            for element in &self.0 {
1036                seq.serialize_element(&Inner(element))?;
1037            }
1038            seq.end()
1039        }
1040    }
1041
1042    impl<'de> Deserialize<'de> for CircomG2 {
1043        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1044        where
1045            D: Deserializer<'de>,
1046        {
1047            let vecs = <Vec<Vec<Bn254FieldElement>>>::deserialize(deserializer)?;
1048            let mut inner: [[Bn254FieldElement; 2]; 3] = Default::default();
1049
1050            if vecs.len() != 3 {
1051                return Err(serde::de::Error::custom(
1052                    "vector of three vectors each being a vector of two strings",
1053                ));
1054            }
1055
1056            for (i, v) in vecs.into_iter().enumerate() {
1057                if v.len() != 2 {
1058                    return Err(serde::de::Error::custom(
1059                        "vector of three vectors each being a vector of two strings",
1060                    ));
1061                }
1062
1063                for (j, point) in v.into_iter().enumerate() {
1064                    inner[i][j] = point;
1065                }
1066            }
1067
1068            Ok(Self(inner))
1069        }
1070    }
1071}