1use super::SimpleSignature;
2use crate::checkpoint::EpochId;
3use crate::u256::U256;
4
/// A zkLogin authenticator: the zkLogin proof inputs, a `max_epoch`, and a signature.
///
/// Serde support is hand-written in the `serialization` module of this file. Human-readable
/// formats use a struct with the fields `inputs`, `max_epoch` and `signature`. Binary formats
/// use a `SignatureScheme::ZkLogin` flag byte followed by the BCS encoding of those three
/// fields in that order.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginAuthenticator {
    /// The zkLogin proof and claim inputs.
    pub inputs: ZkLoginInputs,

    /// Epoch bound carried with the proof. Presumably the last epoch in which this
    /// authenticator is accepted; TODO confirm against the verifier.
    pub max_epoch: EpochId,

    /// The signature accompanying the proof.
    pub signature: SimpleSignature,
}
35
/// Inputs of a zkLogin proof: the proof points, the encoded issuer claim, a base64 header
/// string and the address seed.
///
/// NOTE: serde is derived, so the field declaration order below is the wire order for
/// order-sensitive formats such as BCS (used by `ZkLoginAuthenticator`'s binary encoding).
/// Do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginInputs {
    /// The proof itself (points `a`, `b`, `c`).
    pub proof_points: ZkLoginProof,
    /// Encoded claim holding the `"iss"` key; decoded by `ZkLoginInputs::iss`.
    pub iss_base64_details: ZkLoginClaim,
    /// Base64 string not interpreted in this file (presumably the encoded JWT header; confirm).
    pub header_base64: String,
    /// Address seed; copied into the `ZkLoginPublicIdentifier` by `public_identifier`.
    pub address_seed: Bn254FieldElement,
}
60
61impl ZkLoginInputs {
62 #[cfg(feature = "serde")]
63 #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
64 pub fn iss(&self) -> Result<String, InvalidZkLoginClaimError> {
65 const ISS: &str = "iss";
66
67 let iss = self.iss_base64_details.verify_extended_claim(ISS)?;
68
69 if iss.len() > 255 {
70 Err(InvalidZkLoginClaimError::new("invalid iss: too long"))
71 } else {
72 Ok(iss)
73 }
74 }
75
76 #[cfg(feature = "serde")]
77 #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
78 pub fn public_identifier(&self) -> Result<ZkLoginPublicIdentifier, InvalidZkLoginClaimError> {
79 let iss = self.iss()?;
80 Ok(ZkLoginPublicIdentifier {
81 iss,
82 address_seed: self.address_seed.clone(),
83 })
84 }
85}
86
/// A base64url-encoded fragment holding one JSON claim, plus its alignment offset.
///
/// Decoded by `ZkLoginClaim::verify_extended_claim` (serde feature).
///
/// NOTE: serde is derived, so field order is part of the wire format for BCS.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginClaim {
    /// Base64url characters (`A-Z a-z 0-9 - _`); at least two are required to decode.
    pub value: String,
    /// Position, modulo 4, of `value`'s first character in the base64 string it was cut
    /// from. Only 0, 1 and 2 are accepted by the decoder.
    pub index_mod_4: u8,
}
106
/// Error describing why a zkLogin claim could not be decoded or validated.
///
/// Displays as `invalid zklogin claim: <message>`.
#[derive(Debug)]
pub struct InvalidZkLoginClaimError(String);

#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl InvalidZkLoginClaimError {
    /// Wraps any message convertible into a `String`.
    fn new<T: Into<String>>(err: T) -> Self {
        let message: String = err.into();
        Self(message)
    }
}

impl std::fmt::Display for InvalidZkLoginClaimError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid zklogin claim: ")?;
        f.write_str(&self.0)
    }
}

/// No underlying cause is exposed; `source()` uses the default (`None`).
impl std::error::Error for InvalidZkLoginClaimError {}
125
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl ZkLoginClaim {
    /// Decodes this claim and returns the string value stored under `expected_key`.
    ///
    /// `value` is a base64url fragment and `index_mod_4` is the position, modulo 4, of its
    /// first character within the full base64 string it was cut from (presumably the JWT
    /// payload; confirm with callers). Together with the fragment length, that offset tells
    /// how many bits at each end belong to bytes only partially covered by the fragment.
    /// Those bits are discarded so that only whole bytes remain.
    ///
    /// The decoded text must end with `}` or `,`. That final character is dropped and the
    /// rest is wrapped in `{`/`}` so the fragment parses as a standalone JSON object.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidZkLoginClaimError`] if any of the following holds:
    /// - the fragment is shorter than two characters or contains non-base64url characters;
    /// - its offsets do not align to whole bytes;
    /// - it decodes to invalid UTF-8;
    /// - it does not end with `}` or `,`;
    /// - the wrapped text is not valid JSON;
    /// - `expected_key` is missing or not mapped to a JSON string.
    fn verify_extended_claim(
        &self,
        expected_key: &str,
    ) -> Result<String, InvalidZkLoginClaimError> {
        /// Expands base64url text into one `0`/`1` entry per bit, most significant bit first.
        fn base64_to_bitarray(input: &str) -> Result<Vec<u8>, InvalidZkLoginClaimError> {
            use itertools::Itertools;

            const BASE64_URL_CHARSET: &str =
                "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

            input
                .chars()
                .map(|c| {
                    // The charset is ASCII, so the byte offset from `find` is the 6-bit value.
                    BASE64_URL_CHARSET
                        .find(c)
                        .map(|index| index as u8)
                        .map(|index| (0..6).rev().map(move |i| (index >> i) & 1))
                        .ok_or_else(|| {
                            InvalidZkLoginClaimError::new("base64_to_bitarry invalid input")
                        })
                })
                .flatten_ok()
                .collect()
        }

        /// Packs `0`/`1` bit entries (most significant bit first) into bytes.
        fn bitarray_to_bytearray(bits: &[u8]) -> Result<Vec<u8>, InvalidZkLoginClaimError> {
            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginClaimError::new(
                    "bitarray_to_bytearray invalid input",
                ));
            }
            Ok(bits
                .chunks_exact(8)
                .map(|chunk| chunk.iter().fold(0u8, |byte, &bit| (byte << 1) | bit))
                .collect())
        }

        /// Decodes a base64url fragment whose first character sits at `index_mod_4`
        /// (mod 4) in the enclosing base64 string, keeping only whole bytes.
        fn decode_base64_url(
            s: &str,
            index_mod_4: u8,
        ) -> Result<String, InvalidZkLoginClaimError> {
            if s.len() < 2 {
                return Err(InvalidZkLoginClaimError::new(
                    "Base64 string smaller than 2",
                ));
            }
            let mut bits = base64_to_bitarray(s)?;

            // Leading bits that belong to a byte starting before this fragment.
            let leading = match index_mod_4 {
                0 => 0,
                1 => 2,
                2 => 4,
                _ => return Err(InvalidZkLoginClaimError::new("Invalid first_char_offset")),
            };
            bits.drain(..leading);

            // Position (mod 4) of the last character. Computed in `usize` rather than `u8`
            // so that long fragments cannot overflow (a `u8` sum panics in debug builds for
            // lengths of 254+). `s` is pure ASCII here (`base64_to_bitarray` rejected
            // anything else), so `len()` counts characters, and `len() >= 2` rules out
            // underflow.
            let last_char_offset = (usize::from(index_mod_4) + s.len() - 1) % 4;

            // Trailing bits that belong to a byte continuing past this fragment.
            let trailing = match last_char_offset {
                3 => 0,
                2 => 2,
                1 => 4,
                _ => return Err(InvalidZkLoginClaimError::new("Invalid last_char_offset")),
            };
            // At least 12 bits came in and at most 4 were drained, so this cannot underflow.
            bits.truncate(bits.len() - trailing);

            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginClaimError::new("Invalid bits length"));
            }

            String::from_utf8(bitarray_to_bytearray(&bits)?)
                .map_err(|_| InvalidZkLoginClaimError::new("Invalid UTF8 string"))
        }

        let extended_claim = decode_base64_url(&self.value, self.index_mod_4)?;

        // Drop the single-byte terminator (`}` or `,`) and re-wrap as a JSON object.
        let body = extended_claim
            .strip_suffix(|c: char| c == '}' || c == ',')
            .ok_or_else(|| InvalidZkLoginClaimError::new("Invalid extended claim"))?;
        let json_str = format!("{{{}}}", body);

        let mut json: serde_json::Value = serde_json::from_str(&json_str)
            .map_err(|e| InvalidZkLoginClaimError::new(e.to_string()))?;
        let claim = json
            .as_object_mut()
            .and_then(|object| object.get_mut(expected_key))
            .map(serde_json::Value::take);

        match claim {
            Some(serde_json::Value::String(s)) => Ok(s),
            _ => Err(InvalidZkLoginClaimError::new("invalid extended claim")),
        }
    }
}
244
/// The zkLogin proof: two G1 points (`a`, `c`) and one G2 point (`b`).
///
/// Presumably the Groth16 proof elements A, B, C; confirm with the verifier.
/// NOTE: serde is derived, so field order is part of the wire format for BCS.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginProof {
    /// First G1 point.
    pub a: CircomG1,
    /// G2 point.
    pub b: CircomG2,
    /// Second G1 point.
    pub c: CircomG1,
}
265
/// A G1 point as three field elements.
///
/// Presumably projective coordinates in Circom/snarkjs layout (TODO confirm). Serialized by
/// hand as a serde sequence of three decimal strings (see the `serialization` module).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG1(pub [Bn254FieldElement; 3]);

/// A G2 point as three pairs of field elements.
///
/// Presumably projective coordinates over the quadratic extension, in Circom/snarkjs layout
/// (TODO confirm). Serialized by hand as a sequence of three two-element sequences.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG2(pub [[Bn254FieldElement; 2]; 3]);
296
/// Public identifier of a zkLogin account: the issuer (`iss`) plus the address seed.
///
/// Fields are private so code outside this module must use [`ZkLoginPublicIdentifier::new`],
/// which rejects issuers longer than 255 bytes. The binary serialization stores the issuer
/// length in a single byte. NOTE(review): the derived proptest `Arbitrary` impl builds the
/// fields directly and does not go through `new`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginPublicIdentifier {
    /// Issuer string; at most 255 bytes when built via `new`.
    iss: String,
    /// Address seed.
    address_seed: Bn254FieldElement,
}
352
353impl ZkLoginPublicIdentifier {
354 pub fn new(iss: String, address_seed: Bn254FieldElement) -> Option<Self> {
355 if iss.len() > 255 {
356 None
357 } else {
358 Some(Self { iss, address_seed })
359 }
360 }
361
362 pub fn iss(&self) -> &str {
363 &self.iss
364 }
365
366 pub fn address_seed(&self) -> &Bn254FieldElement {
367 &self.address_seed
368 }
369}
370
/// A JSON Web Key restricted to the `kty`, `e`, `n` and `alg` members.
///
/// The member names match the RSA public-key parameters of a JWK (RFC 7517/7518).
/// Presumably these are OIDC provider keys used for zkLogin; confirm with callers.
/// NOTE: serde is derived, so field order is part of the wire format for BCS.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Jwk {
    /// Key type (`kty`).
    pub kty: String,

    /// Public exponent (`e`), kept as a string (presumably base64url per RFC 7518; confirm).
    pub e: String,

    /// Modulus (`n`), kept as a string (presumably base64url per RFC 7518; confirm).
    pub n: String,

    /// Algorithm (`alg`).
    pub alg: String,
}

/// Identifies a [`Jwk`] by issuer and key id; derives `Hash`/`Eq` so it can key a map.
/// NOTE: serde is derived, so field order is part of the wire format for BCS.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct JwkId {
    /// Issuer (`iss`).
    pub iss: String,

    /// Key id (`kid`).
    pub kid: String,
}
426
/// A field element stored as 32 big-endian bytes; `Default` is zero.
///
/// No check against a field modulus is performed by the constructors in this file.
/// Serialized (serde feature) as a decimal string via `Display`/`FromStr`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Bn254FieldElement([u8; 32]);
443
444impl Bn254FieldElement {
445 pub const fn new(bytes: [u8; 32]) -> Self {
446 Self(bytes)
447 }
448
449 pub const fn from_str_radix_10(s: &str) -> Result<Self, Bn254FieldElementParseError> {
450 let u256 = match U256::from_str_radix(s, 10) {
451 Ok(u256) => u256,
452 Err(e) => return Err(Bn254FieldElementParseError(e)),
453 };
454 let be = u256.to_be();
455 Ok(Self(*be.digits()))
456 }
457
458 pub fn unpadded(&self) -> &[u8] {
459 let mut buf = self.0.as_slice();
460
461 while !buf.is_empty() && buf[0] == 0 {
462 buf = &buf[1..];
463 }
464
465 if buf.is_empty() {
467 &self.0[31..]
468 } else {
469 buf
470 }
471 }
472
473 pub fn padded(&self) -> &[u8] {
474 &self.0
475 }
476}
477
478impl std::fmt::Display for Bn254FieldElement {
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 let u256 = U256::from_be(U256::from_digits(self.0));
481 let radix10 = u256.to_str_radix(10);
482 f.write_str(&radix10)
483 }
484}
485
486#[derive(Debug)]
487pub struct Bn254FieldElementParseError(bnum::errors::ParseIntError);
488
489impl std::fmt::Display for Bn254FieldElementParseError {
490 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 write!(f, "unable to parse radix10 encoded value {}", self.0)
492 }
493}
494
495impl std::error::Error for Bn254FieldElementParseError {}
496
497impl std::str::FromStr for Bn254FieldElement {
498 type Err = Bn254FieldElementParseError;
499
500 fn from_str(s: &str) -> Result<Self, Self::Err> {
501 let u256 = U256::from_str_radix(s, 10).map_err(Bn254FieldElementParseError)?;
502 let be = u256.to_be();
503 Ok(Self(*be.digits()))
504 }
505}
506
#[cfg(test)]
mod test {
    //! Unit and property tests for `Bn254FieldElement` parsing, formatting and byte views.

    use super::Bn254FieldElement;
    use num_bigint::BigUint;
    use proptest::prelude::*;
    use std::str::FromStr;
    use test_strategy::proptest;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn unpadded_slice() {
        // All-zero keeps a single byte.
        let seed = Bn254FieldElement([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        // One leading zero byte is stripped.
        let mut seed = Bn254FieldElement([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());

        // Only the least-significant byte set.
        let mut seed = Bn254FieldElement([0; 32]);
        seed.0[31] = 5;
        assert_eq!(seed.unpadded(), [5].as_slice());
    }

    /// Values wider than 256 bits must be rejected (or parsed) without panicking.
    #[proptest]
    fn dont_crash_on_large_inputs(
        #[strategy(proptest::collection::vec(any::<u8>(), 33..1024))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        let _ = Bn254FieldElement::from_str(&radix10);
    }

    /// Any value fitting in 32 bytes round-trips through decimal, and the byte views agree
    /// with the reference big-integer encoding.
    #[proptest]
    fn valid_address_seeds(
        #[strategy(proptest::collection::vec(any::<u8>(), 1..=32))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        let seed = Bn254FieldElement::from_str(&radix10).unwrap();
        assert_eq!(radix10, seed.to_string());
        assert_eq!(Bn254FieldElement::from_str_radix_10(&radix10).unwrap(), seed);

        // `to_bytes_be` is the minimal big-endian encoding (zero encodes as `[0]`), which is
        // exactly what `unpadded` promises.
        assert_eq!(seed.unpadded(), big_int.to_bytes_be().as_slice());
        assert_eq!(seed.padded().len(), 32);
    }
}
553
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
mod serialization {
    //! Hand-written serde implementations for the zkLogin types in the parent module.
    //!
    //! Several impls branch on `is_human_readable()`. Human-readable formats (e.g. JSON) get
    //! a structured representation; binary formats (e.g. BCS) get a compact byte encoding.

    use crate::SignatureScheme;

    use super::*;
    use serde::Deserialize;
    use serde::Deserializer;
    use serde::Serialize;
    use serde::Serializer;
    use serde_with::Bytes;
    use serde_with::DeserializeAs;
    use serde_with::SerializeAs;
    use std::borrow::Cow;

    /// Human-readable: a struct with `iss` and `address_seed` fields.
    ///
    /// Binary: one byte string laid out as
    /// `[iss_len: u8][iss: UTF-8 bytes][address_seed: 32 bytes]`.
    impl Serialize for ZkLoginPublicIdentifier {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                #[derive(serde_derive::Serialize)]
                struct Readable<'a> {
                    iss: &'a str,
                    address_seed: &'a Bn254FieldElement,
                }
                let readable = Readable {
                    iss: &self.iss,
                    address_seed: &self.address_seed,
                };
                readable.serialize(serializer)
            } else {
                let iss_bytes = self.iss.as_bytes();
                // The length prefix is a single byte. `ZkLoginPublicIdentifier::new` enforces
                // this, but the private fields can also be set directly from within this file
                // (e.g. by derived impls). Reject an over-long `iss` rather than silently
                // truncating the prefix and emitting bytes that cannot be decoded back.
                let iss_len = u8::try_from(iss_bytes.len()).map_err(|_| {
                    <S::Error as serde::ser::Error>::custom(
                        "invalid zklogin public identifier: iss too long",
                    )
                })?;

                let mut buf =
                    Vec::with_capacity(1 + iss_bytes.len() + self.address_seed.0.len());
                buf.push(iss_len);
                buf.extend_from_slice(iss_bytes);
                buf.extend_from_slice(&self.address_seed.0);

                serializer.serialize_bytes(&buf)
            }
        }
    }

    /// Inverse of the `Serialize` impl above. Both paths re-validate through
    /// `ZkLoginPublicIdentifier::new`.
    impl<'de> Deserialize<'de> for ZkLoginPublicIdentifier {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                #[derive(serde_derive::Deserialize)]
                struct Readable {
                    iss: String,
                    address_seed: Bn254FieldElement,
                }

                let Readable { iss, address_seed } = Deserialize::deserialize(deserializer)?;
                Self::new(iss, address_seed)
                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
            } else {
                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
                // First byte: issuer length.
                let iss_len = *bytes
                    .first()
                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
                let iss_bytes = bytes
                    .get(1..(1 + iss_len as usize))
                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
                let iss = std::str::from_utf8(iss_bytes).map_err(serde::de::Error::custom)?;
                // Everything after the issuer must be exactly the 32-byte address seed.
                let address_seed_bytes = bytes
                    .get((1 + iss_len as usize)..)
                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;

                let address_seed = <[u8; 32]>::try_from(address_seed_bytes)
                    .map_err(serde::de::Error::custom)
                    .map(Bn254FieldElement)?;

                Self::new(iss.into(), address_seed)
                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
            }
        }
    }

    /// Borrowing view of `ZkLoginAuthenticator` used for serialization.
    /// Field order is the BCS wire order and must match `Authenticator`.
    #[derive(serde_derive::Serialize)]
    struct AuthenticatorRef<'a> {
        inputs: &'a ZkLoginInputs,
        max_epoch: EpochId,
        signature: &'a SimpleSignature,
    }

    /// Owned counterpart of `AuthenticatorRef` used for deserialization.
    #[derive(serde_derive::Deserialize)]
    struct Authenticator {
        inputs: ZkLoginInputs,
        max_epoch: EpochId,
        signature: SimpleSignature,
    }

    /// Human-readable: a struct of the three fields. Binary: the bytes from `to_bytes`.
    impl Serialize for ZkLoginAuthenticator {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                let authenticator_ref = AuthenticatorRef {
                    inputs: &self.inputs,
                    max_epoch: self.max_epoch,
                    signature: &self.signature,
                };

                authenticator_ref.serialize(serializer)
            } else {
                let bytes = self.to_bytes();
                serializer.serialize_bytes(&bytes)
            }
        }
    }

    impl<'de> Deserialize<'de> for ZkLoginAuthenticator {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                let Authenticator {
                    inputs,
                    max_epoch,
                    signature,
                } = Authenticator::deserialize(deserializer)?;
                Ok(Self {
                    inputs,
                    max_epoch,
                    signature,
                })
            } else {
                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
                Self::from_serialized_bytes(bytes)
            }
        }
    }

    impl ZkLoginAuthenticator {
        /// Encodes as a `SignatureScheme::ZkLogin` flag byte followed by the BCS encoding of
        /// `(inputs, max_epoch, signature)`.
        pub(crate) fn to_bytes(&self) -> Vec<u8> {
            let authenticator_ref = AuthenticatorRef {
                inputs: &self.inputs,
                max_epoch: self.max_epoch,
                signature: &self.signature,
            };

            let mut buf = Vec::new();
            buf.push(SignatureScheme::ZkLogin as u8);

            bcs::serialize_into(&mut buf, &authenticator_ref).expect("serialization cannot fail");
            buf
        }

        /// Decodes bytes in the layout produced by `to_bytes`.
        ///
        /// # Errors
        ///
        /// Fails in any of these cases:
        /// - the input is empty;
        /// - the flag byte is not a known scheme, or is not `ZkLogin`;
        /// - the remainder is not valid BCS for the authenticator fields.
        pub(crate) fn from_serialized_bytes<T: AsRef<[u8]>, E: serde::de::Error>(
            bytes: T,
        ) -> Result<Self, E> {
            let bytes = bytes.as_ref();
            let flag = SignatureScheme::from_byte(
                *bytes
                    .first()
                    .ok_or_else(|| serde::de::Error::custom("missing signature scheme flag"))?,
            )
            .map_err(serde::de::Error::custom)?;
            if flag != SignatureScheme::ZkLogin {
                return Err(serde::de::Error::custom("invalid zklogin flag"));
            }
            // Non-empty was checked above, so slicing past the flag byte cannot panic.
            let bcs_bytes = &bytes[1..];

            let Authenticator {
                inputs,
                max_epoch,
                signature,
            } = bcs::from_bytes(bcs_bytes).map_err(serde::de::Error::custom)?;
            Ok(Self {
                inputs,
                max_epoch,
                signature,
            })
        }
    }

    /// Serialized as a decimal string through the `Display` impl.
    impl Serialize for Bn254FieldElement {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serde_with::DisplayFromStr::serialize_as(self, serializer)
        }
    }

    /// Parsed from a decimal string through the `FromStr` impl.
    impl<'de> Deserialize<'de> for Bn254FieldElement {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            serde_with::DisplayFromStr::deserialize_as(deserializer)
        }
    }

    /// Encoded as a serde *sequence* of three elements. Serde's built-in array impl would
    /// use a tuple instead, which differs in length-prefixed formats such as BCS.
    /// Presumably deliberate for wire compatibility; keep the manual impl.
    impl Serialize for CircomG1 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeSeq;
            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
            for element in &self.0 {
                seq.serialize_element(element)?;
            }
            seq.end()
        }
    }

    impl<'de> Deserialize<'de> for CircomG1 {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let inner = <Vec<_>>::deserialize(deserializer)?;
            Ok(Self(inner.try_into().map_err(|_| {
                serde::de::Error::custom("expected array of length 3")
            })?))
        }
    }

    /// Encoded as a sequence of three two-element sequences (see the `CircomG1` note on
    /// sequences versus tuples).
    impl Serialize for CircomG2 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeSeq;

            /// Serializes one coordinate pair as a two-element sequence.
            struct Inner<'a>(&'a [Bn254FieldElement; 2]);

            impl Serialize for Inner<'_> {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
                    for element in self.0 {
                        seq.serialize_element(element)?;
                    }
                    seq.end()
                }
            }

            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
            for element in &self.0 {
                seq.serialize_element(&Inner(element))?;
            }
            seq.end()
        }
    }

    impl<'de> Deserialize<'de> for CircomG2 {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let vecs = <Vec<Vec<Bn254FieldElement>>>::deserialize(deserializer)?;
            let mut inner: [[Bn254FieldElement; 2]; 3] = Default::default();

            if vecs.len() != 3 {
                return Err(serde::de::Error::custom(
                    "vector of three vectors each being a vector of two strings",
                ));
            }

            for (i, v) in vecs.into_iter().enumerate() {
                if v.len() != 2 {
                    return Err(serde::de::Error::custom(
                        "vector of three vectors each being a vector of two strings",
                    ));
                }

                for (j, point) in v.into_iter().enumerate() {
                    inner[i][j] = point;
                }
            }

            Ok(Self(inner))
        }
    }
}