1use super::SimpleSignature;
2use crate::checkpoint::EpochId;
3use crate::u256::U256;
4
/// A zkLogin authenticator: the proof inputs, a maximum epoch, and a signature.
///
/// With the `serde` feature, serialization is hand-written in the
/// `serialization` module:
/// - Human-readable formats get a struct with `inputs`, `max_epoch` and `signature`.
/// - Binary formats get a byte string holding the `SignatureScheme::ZkLogin` flag
///   byte followed by the BCS encoding of those same three fields, in this order.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginAuthenticator {
    /// Proof points and public inputs (issuer claim, header, address seed).
    pub inputs: ZkLoginInputs,

    /// Maximum epoch carried with the proof. This type does not itself enforce it.
    pub max_epoch: EpochId,

    /// Signature carried alongside the proof.
    /// NOTE(review): presumably produced by the ephemeral key — confirm.
    pub signature: SimpleSignature,
}
35
/// Inputs accompanying a zkLogin proof.
///
/// Field order is part of the derived serialization (and therefore the BCS
/// layout used by `ZkLoginAuthenticator`); do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginInputs {
    /// The proof points `a`, `b` and `c`.
    pub proof_points: ZkLoginProof,
    /// Base64url fragment containing the `iss` claim; decoded by `ZkLoginInputs::iss`.
    pub iss_base64_details: ZkLoginClaim,
    /// Base64-encoded header.
    /// NOTE(review): presumably the JWT header — confirm.
    pub header_base64: String,
    /// Address seed, exposed through `ZkLoginInputs::public_identifier`.
    pub address_seed: Bn254FieldElement,
}
60
61impl ZkLoginInputs {
62 #[cfg(feature = "serde")]
63 #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
64 pub fn iss(&self) -> Result<String, InvalidZkLoginClaimError> {
65 const ISS: &str = "iss";
66
67 let iss = self.iss_base64_details.verify_extended_claim(ISS)?;
68
69 if iss.len() > 255 {
70 Err(InvalidZkLoginClaimError::new("invalid iss: too long"))
71 } else {
72 Ok(iss)
73 }
74 }
75
76 #[cfg(feature = "serde")]
77 #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
78 pub fn public_identifier(&self) -> Result<ZkLoginPublicIdentifier, InvalidZkLoginClaimError> {
79 let iss = self.iss()?;
80 Ok(ZkLoginPublicIdentifier {
81 iss,
82 address_seed: self.address_seed.clone(),
83 })
84 }
85}
86
/// A claim carried as a fragment of a larger base64url-encoded JSON payload.
///
/// Decoding (`verify_extended_claim`, `serde` feature only) works in three steps:
/// - it trims the bits of partially covered leading and trailing characters;
/// - it expects UTF-8 text ending in `}` or `,`;
/// - it parses that text as the body of a JSON object.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginClaim {
    /// The base64url-encoded fragment.
    pub value: String,
    /// Position of the fragment's first character in the enclosing base64
    /// string, modulo 4. Decoding accepts only 0, 1 or 2.
    pub index_mod_4: u8,
}
106
/// Error returned when a zkLogin claim cannot be decoded or validated.
///
/// Holds a short description of the failure; `Display` prefixes it with
/// `invalid zklogin claim: `.
#[derive(Debug)]
pub struct InvalidZkLoginClaimError(String);

#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl InvalidZkLoginClaimError {
    /// Wraps `err` as the error description.
    fn new<T: Into<String>>(err: T) -> Self {
        let description: String = err.into();
        Self(description)
    }
}

impl std::fmt::Display for InvalidZkLoginClaimError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid zklogin claim: ")?;
        f.write_str(&self.0)
    }
}

impl std::error::Error for InvalidZkLoginClaimError {}
125
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl ZkLoginClaim {
    /// Decodes this claim and returns the string stored under `expected_key`.
    ///
    /// Steps:
    /// 1. `value` is decoded as base64url.
    /// 2. `index_mod_4` gives the fragment's alignment inside the enclosing base64
    ///    string; bits of partially covered leading/trailing characters are dropped
    ///    so that only whole bytes remain.
    /// 3. The decoded UTF-8 text must end in `}` or `,`. That final character is
    ///    replaced by surrounding braces to form a standalone JSON object.
    /// 4. `expected_key` is looked up in that object.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidZkLoginClaimError`] if `value` contains non-base64url
    /// characters, is shorter than 2 characters, is misaligned, is not UTF-8, is
    /// not a JSON object fragment, or does not map `expected_key` to a JSON string.
    fn verify_extended_claim(
        &self,
        expected_key: &str,
    ) -> Result<String, InvalidZkLoginClaimError> {
        // Expands each base64url character into its 6 bits, most significant first.
        fn base64_to_bitarray(input: &str) -> Result<Vec<u8>, InvalidZkLoginClaimError> {
            use itertools::Itertools;

            const BASE64_URL_CHARSET: &str =
                "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

            input
                .chars()
                .map(|c| {
                    BASE64_URL_CHARSET
                        .find(c)
                        .map(|index| index as u8)
                        .map(|index| (0..6).rev().map(move |i| (index >> i) & 1))
                        .ok_or_else(|| {
                            InvalidZkLoginClaimError::new("base64_to_bitarray invalid input")
                        })
                })
                .flatten_ok()
                .collect()
        }

        // Packs bits (one 0/1 value per `u8`, most significant first) into bytes.
        fn bitarray_to_bytearray(bits: &[u8]) -> Result<Vec<u8>, InvalidZkLoginClaimError> {
            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginClaimError::new(
                    "bitarray_to_bytearray invalid input",
                ));
            }
            Ok(bits
                .chunks_exact(8)
                .map(|chunk| chunk.iter().fold(0u8, |byte, &bit| (byte << 1) | bit))
                .collect())
        }

        // Decodes a base64url fragment whose first character sits at position
        // `index_mod_4` (mod 4) of the enclosing base64 string. Bits belonging to
        // bytes only partially contained in the fragment are discarded.
        fn decode_base64_url(
            s: &str,
            index_mod_4: u8,
        ) -> Result<String, InvalidZkLoginClaimError> {
            if s.len() < 2 {
                return Err(InvalidZkLoginClaimError::new(
                    "Base64 string smaller than 2",
                ));
            }
            let mut bits = base64_to_bitarray(s)?;
            match index_mod_4 {
                0 => {}
                1 => {
                    bits.drain(..2);
                }
                2 => {
                    bits.drain(..4);
                }
                _ => {
                    return Err(InvalidZkLoginClaimError::new("Invalid first_char_offset"));
                }
            }

            // Computed in `usize`: narrowing `s.len()` to `u8` overflows (and panics
            // when overflow checks are enabled) for long, attacker-supplied claims.
            // `s.len() >= 2` was checked above, so the subtraction cannot underflow.
            let last_char_offset = (usize::from(index_mod_4) + s.len() - 1) % 4;
            match last_char_offset {
                3 => {}
                2 => {
                    bits.truncate(bits.len() - 2);
                }
                1 => {
                    bits.truncate(bits.len() - 4);
                }
                _ => {
                    return Err(InvalidZkLoginClaimError::new("Invalid last_char_offset"));
                }
            }

            if bits.len() % 8 != 0 {
                return Err(InvalidZkLoginClaimError::new("Invalid bits length"));
            }

            let bytes = bitarray_to_bytearray(&bits)?;
            String::from_utf8(bytes)
                .map_err(|_| InvalidZkLoginClaimError::new("Invalid UTF8 string"))
        }

        let extended_claim = decode_base64_url(&self.value, self.index_mod_4)?;

        // The fragment is a piece of a JSON object terminated by `}` or `,`.
        // Swap that terminator for surrounding braces to get a standalone object.
        let body = extended_claim
            .strip_suffix('}')
            .or_else(|| extended_claim.strip_suffix(','))
            .ok_or_else(|| InvalidZkLoginClaimError::new("Invalid extended claim"))?;
        let json_str = format!("{{{}}}", body);

        serde_json::from_str::<serde_json::Value>(&json_str)
            .map_err(|e| InvalidZkLoginClaimError::new(e.to_string()))?
            .as_object_mut()
            .and_then(|o| o.get_mut(expected_key))
            .map(serde_json::Value::take)
            .and_then(|v| match v {
                serde_json::Value::String(s) => Some(s),
                _ => None,
            })
            .ok_or_else(|| InvalidZkLoginClaimError::new("invalid extended claim"))
    }
}
244
/// The proof points of a zkLogin proof.
///
/// Field order is part of the derived serialization; do not reorder. Each point
/// serializes as nested sequences of decimal strings (see the `CircomG1` /
/// `CircomG2` serde impls).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginProof {
    /// Point `a` (three field elements).
    pub a: CircomG1,
    /// Point `b` (three pairs of field elements).
    pub b: CircomG2,
    /// Point `c` (three field elements).
    pub c: CircomG1,
}
265
/// A point given as three [`Bn254FieldElement`]s.
///
/// NOTE(review): presumably a BN254 G1 point in Circom's projective `[x, y, z]`
/// layout — confirm. With the `serde` feature it serializes as a sequence of
/// three decimal strings.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG1(pub [Bn254FieldElement; 3]);
280
/// A point given as three pairs of [`Bn254FieldElement`]s.
///
/// NOTE(review): presumably a BN254 G2 point in Circom's projective layout, where
/// each coordinate is an extension-field element given as two base-field values —
/// confirm. With the `serde` feature it serializes as a sequence of three
/// sequences of two decimal strings.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG2(pub [[Bn254FieldElement; 2]; 3]);
296
/// Public identifier derived from zkLogin inputs: the issuer plus the address seed.
///
/// Invariant: `iss` is at most 255 bytes, as enforced by
/// [`ZkLoginPublicIdentifier::new`]. The fields are private so that code outside
/// this module must go through `new`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginPublicIdentifier {
    iss: String,
    address_seed: Bn254FieldElement,
}
352
353impl ZkLoginPublicIdentifier {
354 pub fn new(iss: String, address_seed: Bn254FieldElement) -> Option<Self> {
355 if iss.len() > 255 {
356 None
357 } else {
358 Some(Self { iss, address_seed })
359 }
360 }
361
362 pub fn iss(&self) -> &str {
363 &self.iss
364 }
365
366 pub fn address_seed(&self) -> &Bn254FieldElement {
367 &self.address_seed
368 }
369}
370
/// A JSON Web Key.
///
/// NOTE(review): the member names follow RFC 7517/7518 (`kty`, `e`, `n`, `alg`),
/// which suggests an RSA public key — confirm against the producer of these
/// values. Field order is part of the derived serialization; do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Jwk {
    /// Key type (`kty`).
    pub kty: String,

    /// The `e` member (RSA public exponent under RFC 7518).
    pub e: String,

    /// The `n` member (RSA modulus under RFC 7518).
    pub n: String,

    /// Algorithm (`alg`).
    pub alg: String,
}
403
/// Identifies a JWK by its issuer and key id.
///
/// Field order is part of the derived serialization; do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct JwkId {
    /// Issuer (`iss`) the key belongs to.
    pub iss: String,

    /// Key identifier (`kid`).
    pub kid: String,
}
426
/// A 256-bit unsigned integer stored as 32 big-endian bytes, used to carry BN254
/// field elements.
///
/// NOTE(review): parsing does not check that the value is below the BN254 field
/// modulus; any value that fits in 256 bits is accepted. `Display` (and, with
/// `serde`, serialization) uses the base-10 representation.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Bn254FieldElement([u8; 32]);
443
444impl Bn254FieldElement {
445 pub const fn new(bytes: [u8; 32]) -> Self {
446 Self(bytes)
447 }
448
449 pub const fn from_str_radix_10(s: &str) -> Result<Self, Bn254FieldElementParseError> {
450 let u256 = match U256::from_str_radix(s, 10) {
451 Ok(u256) => u256,
452 Err(e) => return Err(Bn254FieldElementParseError(e)),
453 };
454 let be = u256.to_be();
455 Ok(Self(*be.digits()))
456 }
457
458 pub fn unpadded(&self) -> &[u8] {
459 let mut buf = self.0.as_slice();
460
461 while !buf.is_empty() && buf[0] == 0 {
462 buf = &buf[1..];
463 }
464
465 if buf.is_empty() {
467 &self.0[31..]
468 } else {
469 buf
470 }
471 }
472
473 pub fn padded(&self) -> &[u8] {
474 &self.0
475 }
476}
477
478impl std::fmt::Display for Bn254FieldElement {
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 let u256 = U256::from_be(U256::from_digits(self.0));
481 let radix10 = u256.to_str_radix(10);
482 f.write_str(&radix10)
483 }
484}
485
486#[derive(Debug)]
487pub struct Bn254FieldElementParseError(bnum::errors::ParseIntError);
488
489impl std::fmt::Display for Bn254FieldElementParseError {
490 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 write!(f, "unable to parse radix10 encoded value {}", self.0)
492 }
493}
494
495impl std::error::Error for Bn254FieldElementParseError {}
496
497impl std::str::FromStr for Bn254FieldElement {
498 type Err = Bn254FieldElementParseError;
499
500 fn from_str(s: &str) -> Result<Self, Self::Err> {
501 let u256 = U256::from_str_radix(s, 10).map_err(Bn254FieldElementParseError)?;
502 let be = u256.to_be();
503 Ok(Self(*be.digits()))
504 }
505}
506
#[cfg(test)]
mod test {
    use super::Bn254FieldElement;
    use num_bigint::BigUint;
    use proptest::prelude::*;
    use std::str::FromStr;
    use test_strategy::proptest;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// `unpadded` strips leading zero bytes but keeps a single byte for zero.
    #[test]
    fn unpadded_slice() {
        let zero = Bn254FieldElement([0u8; 32]);
        assert_eq!(zero.unpadded(), &[0u8][..]);

        let mut bytes = [1u8; 32];
        bytes[0] = 0;
        let value = Bn254FieldElement(bytes);
        assert_eq!(value.unpadded(), &[1u8; 31][..]);
    }

    /// Parsing decimal strings built from 33..1024 random bytes must not panic.
    #[proptest]
    fn dont_crash_on_large_inputs(
        #[strategy(proptest::collection::vec(any::<u8>(), 33..1024))] bytes: Vec<u8>,
    ) {
        let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);
        let _ = Bn254FieldElement::from_str(&decimal);
    }

    /// Values of at most 32 bytes round-trip through `FromStr` and `Display`.
    #[proptest]
    fn valid_address_seeds(
        #[strategy(proptest::collection::vec(any::<u8>(), 1..=32))] bytes: Vec<u8>,
    ) {
        let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);
        let seed = Bn254FieldElement::from_str(&decimal).unwrap();
        assert_eq!(decimal, seed.to_string());
        let _ = seed.unpadded();
    }
}
553
554#[cfg(feature = "serde")]
555#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
556mod serialization {
557 use crate::SignatureScheme;
558
559 use super::*;
560 use serde::Deserialize;
561 use serde::Deserializer;
562 use serde::Serialize;
563 use serde::Serializer;
564 use serde_with::Bytes;
565 use serde_with::DeserializeAs;
566 use serde_with::SerializeAs;
567 use std::borrow::Cow;
568
569 impl Serialize for ZkLoginPublicIdentifier {
571 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
572 where
573 S: Serializer,
574 {
575 if serializer.is_human_readable() {
576 #[derive(serde_derive::Serialize)]
577 struct Readable<'a> {
578 iss: &'a str,
579 address_seed: &'a Bn254FieldElement,
580 }
581 let readable = Readable {
582 iss: &self.iss,
583 address_seed: &self.address_seed,
584 };
585 readable.serialize(serializer)
586 } else {
587 let mut buf = Vec::new();
588 let iss_bytes = self.iss.as_bytes();
589 buf.push(iss_bytes.len() as u8);
590 buf.extend(iss_bytes);
591
592 buf.extend(&self.address_seed.0);
593
594 serializer.serialize_bytes(&buf)
595 }
596 }
597 }
598
599 impl<'de> Deserialize<'de> for ZkLoginPublicIdentifier {
600 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
601 where
602 D: Deserializer<'de>,
603 {
604 if deserializer.is_human_readable() {
605 #[derive(serde_derive::Deserialize)]
606 struct Readable {
607 iss: String,
608 address_seed: Bn254FieldElement,
609 }
610
611 let Readable { iss, address_seed } = Deserialize::deserialize(deserializer)?;
612 Self::new(iss, address_seed)
613 .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
614 } else {
615 let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
616 let iss_len = *bytes
617 .first()
618 .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
619 let iss_bytes = bytes
620 .get(1..(1 + iss_len as usize))
621 .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
622 let iss = std::str::from_utf8(iss_bytes).map_err(serde::de::Error::custom)?;
623 let address_seed_bytes = bytes
624 .get((1 + iss_len as usize)..)
625 .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
626
627 let address_seed = <[u8; 32]>::try_from(address_seed_bytes)
628 .map_err(serde::de::Error::custom)
629 .map(Bn254FieldElement)?;
630
631 Self::new(iss.into(), address_seed)
632 .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
633 }
634 }
635 }
636
    /// Borrowed view of a `ZkLoginAuthenticator`, used for serialization.
    ///
    /// Field order defines the BCS layout produced by `ZkLoginAuthenticator::to_bytes`
    /// and must stay in sync with `Authenticator`.
    #[derive(serde_derive::Serialize)]
    struct AuthenticatorRef<'a> {
        inputs: &'a ZkLoginInputs,
        // Formatted through `crate::_serde::ReadableDisplay`.
        // NOTE(review): presumably rendered via `Display` only for human-readable formats — confirm.
        #[cfg_attr(feature = "serde", serde(with = "crate::_serde::ReadableDisplay"))]
        max_epoch: EpochId,
        signature: &'a SimpleSignature,
    }

    /// Owned counterpart of `AuthenticatorRef`, used for deserialization.
    ///
    /// Field order must match `AuthenticatorRef`.
    #[derive(serde_derive::Deserialize)]
    struct Authenticator {
        inputs: ZkLoginInputs,
        // Must use the same `with` adapter as `AuthenticatorRef::max_epoch`.
        #[cfg_attr(feature = "serde", serde(with = "crate::_serde::ReadableDisplay"))]
        max_epoch: EpochId,
        signature: SimpleSignature,
    }
652
653 impl Serialize for ZkLoginAuthenticator {
654 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
655 where
656 S: Serializer,
657 {
658 if serializer.is_human_readable() {
659 let authenticator_ref = AuthenticatorRef {
660 inputs: &self.inputs,
661 max_epoch: self.max_epoch,
662 signature: &self.signature,
663 };
664
665 authenticator_ref.serialize(serializer)
666 } else {
667 let bytes = self.to_bytes();
668 serializer.serialize_bytes(&bytes)
669 }
670 }
671 }
672
673 impl<'de> Deserialize<'de> for ZkLoginAuthenticator {
674 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
675 where
676 D: Deserializer<'de>,
677 {
678 if deserializer.is_human_readable() {
679 let Authenticator {
680 inputs,
681 max_epoch,
682 signature,
683 } = Authenticator::deserialize(deserializer)?;
684 Ok(Self {
685 inputs,
686 max_epoch,
687 signature,
688 })
689 } else {
690 let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
691 Self::from_serialized_bytes(bytes)
692 }
693 }
694 }
695
696 impl ZkLoginAuthenticator {
697 pub(crate) fn to_bytes(&self) -> Vec<u8> {
698 let authenticator_ref = AuthenticatorRef {
699 inputs: &self.inputs,
700 max_epoch: self.max_epoch,
701 signature: &self.signature,
702 };
703
704 let mut buf = Vec::new();
705 buf.push(SignatureScheme::ZkLogin as u8);
706
707 bcs::serialize_into(&mut buf, &authenticator_ref).expect("serialization cannot fail");
708 buf
709 }
710
711 pub(crate) fn from_serialized_bytes<T: AsRef<[u8]>, E: serde::de::Error>(
712 bytes: T,
713 ) -> Result<Self, E> {
714 let bytes = bytes.as_ref();
715 let flag = SignatureScheme::from_byte(
716 *bytes
717 .first()
718 .ok_or_else(|| serde::de::Error::custom("missing signature scheme falg"))?,
719 )
720 .map_err(serde::de::Error::custom)?;
721 if flag != SignatureScheme::ZkLogin {
722 return Err(serde::de::Error::custom("invalid zklogin flag"));
723 }
724 let bcs_bytes = &bytes[1..];
725
726 let Authenticator {
727 inputs,
728 max_epoch,
729 signature,
730 } = bcs::from_bytes(bcs_bytes).map_err(serde::de::Error::custom)?;
731 Ok(Self {
732 inputs,
733 max_epoch,
734 signature,
735 })
736 }
737 }
738
739 impl Serialize for Bn254FieldElement {
741 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
742 where
743 S: serde::Serializer,
744 {
745 serde_with::DisplayFromStr::serialize_as(self, serializer)
746 }
747 }
748
749 impl<'de> Deserialize<'de> for Bn254FieldElement {
750 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
751 where
752 D: Deserializer<'de>,
753 {
754 serde_with::DisplayFromStr::deserialize_as(deserializer)
755 }
756 }
757
758 impl Serialize for CircomG1 {
759 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
760 where
761 S: serde::Serializer,
762 {
763 use serde::ser::SerializeSeq;
764 let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
765 for element in &self.0 {
766 seq.serialize_element(element)?;
767 }
768 seq.end()
769 }
770 }
771
772 impl<'de> Deserialize<'de> for CircomG1 {
773 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
774 where
775 D: Deserializer<'de>,
776 {
777 let inner = <Vec<_>>::deserialize(deserializer)?;
778 Ok(Self(inner.try_into().map_err(|_| {
779 serde::de::Error::custom("expected array of length 3")
780 })?))
781 }
782 }
783
784 impl Serialize for CircomG2 {
785 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
786 where
787 S: serde::Serializer,
788 {
789 use serde::ser::SerializeSeq;
790
791 struct Inner<'a>(&'a [Bn254FieldElement; 2]);
792
793 impl Serialize for Inner<'_> {
794 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
795 where
796 S: serde::Serializer,
797 {
798 let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
799 for element in self.0 {
800 seq.serialize_element(element)?;
801 }
802 seq.end()
803 }
804 }
805
806 let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
807 for element in &self.0 {
808 seq.serialize_element(&Inner(element))?;
809 }
810 seq.end()
811 }
812 }
813
814 impl<'de> Deserialize<'de> for CircomG2 {
815 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
816 where
817 D: Deserializer<'de>,
818 {
819 let vecs = <Vec<Vec<Bn254FieldElement>>>::deserialize(deserializer)?;
820 let mut inner: [[Bn254FieldElement; 2]; 3] = Default::default();
821
822 if vecs.len() != 3 {
823 return Err(serde::de::Error::custom(
824 "vector of three vectors each being a vector of two strings",
825 ));
826 }
827
828 for (i, v) in vecs.into_iter().enumerate() {
829 if v.len() != 2 {
830 return Err(serde::de::Error::custom(
831 "vector of three vectors each being a vector of two strings",
832 ));
833 }
834
835 for (j, point) in v.into_iter().enumerate() {
836 inner[i][j] = point;
837 }
838 }
839
840 Ok(Self(inner))
841 }
842 }
843}