1use super::SimpleSignature;
2use crate::checkpoint::EpochId;
3use crate::u256::U256;
4
5#[derive(Debug, Clone, PartialEq, Eq)]
24#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
25pub struct ZkLoginAuthenticator {
26 pub inputs: ZkLoginInputs,
28
29 pub max_epoch: EpochId,
31
32 pub signature: SimpleSignature,
34}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
49#[cfg_attr(
50 feature = "serde",
51 derive(serde_derive::Serialize, serde_derive::Deserialize)
52)]
53#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
54pub struct ZkLoginInputs {
55 pub proof_points: ZkLoginProof,
56 pub iss_base64_details: ZkLoginClaim,
57 pub header_base64: String,
58 pub address_seed: Bn254FieldElement,
59}
60
/// A claim value paired with an `index_mod_4` offset.
///
/// NOTE(review): the offset presumably locates the value inside a base64
/// string (the value modulo 4 of its index); confirm with the zkLogin spec.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginClaim {
    /// The claim text.
    pub value: String,
    /// Offset indicator stored as a single byte.
    pub index_mod_4: u8,
}
80
81#[derive(Debug, Clone, PartialEq, Eq)]
91#[cfg_attr(
92 feature = "serde",
93 derive(serde_derive::Serialize, serde_derive::Deserialize)
94)]
95#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
96pub struct ZkLoginProof {
97 pub a: CircomG1,
98 pub b: CircomG2,
99 pub c: CircomG1,
100}
101
102#[derive(Clone, Debug, PartialEq, Eq)]
114#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
115pub struct CircomG1(pub [Bn254FieldElement; 3]);
116
117#[derive(Clone, Debug, PartialEq, Eq)]
130#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
131pub struct CircomG2(pub [[Bn254FieldElement; 2]; 3]);
132
133#[derive(Clone, Debug, PartialEq, Eq)]
183#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
184pub struct ZkLoginPublicIdentifier {
185 iss: String,
186 address_seed: Bn254FieldElement,
187}
188
189impl ZkLoginPublicIdentifier {
190 pub fn new(iss: String, address_seed: Bn254FieldElement) -> Option<Self> {
191 if iss.len() > 255 {
192 None
193 } else {
194 Some(Self { iss, address_seed })
195 }
196 }
197
198 pub fn iss(&self) -> &str {
199 &self.iss
200 }
201
202 pub fn address_seed(&self) -> &Bn254FieldElement {
203 &self.address_seed
204 }
205}
206
/// A JSON Web Key, holding the four string members `kty`, `e`, `n` and `alg`.
///
/// NOTE(review): `e`/`n` are the RSA exponent/modulus members per RFC 7518;
/// nothing here validates them.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Jwk {
    /// Key type (`kty` member).
    pub kty: String,

    /// The `e` member.
    pub e: String,

    /// The `n` member.
    pub n: String,

    /// Algorithm (`alg` member).
    pub alg: String,
}
239
/// Identifies a [`Jwk`] by issuer (`iss`) and key id (`kid`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct JwkId {
    /// Issuer of the key.
    pub iss: String,

    /// Key identifier within that issuer.
    pub kid: String,
}
262
/// A 256-bit value stored as 32 big-endian bytes.
///
/// Parsing does not check the value against the BN254 field modulus; any
/// 256-bit integer is accepted. `Default` is the all-zero value.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Bn254FieldElement([u8; 32]);
279
280impl Bn254FieldElement {
281 pub const fn new(bytes: [u8; 32]) -> Self {
282 Self(bytes)
283 }
284
285 pub const fn from_str_radix_10(s: &str) -> Result<Self, Bn254FieldElementParseError> {
286 let u256 = match U256::from_str_radix(s, 10) {
287 Ok(u256) => u256,
288 Err(e) => return Err(Bn254FieldElementParseError(e)),
289 };
290 let be = u256.to_be();
291 Ok(Self(*be.digits()))
292 }
293
294 pub fn unpadded(&self) -> &[u8] {
295 let mut buf = self.0.as_slice();
296
297 while !buf.is_empty() && buf[0] == 0 {
298 buf = &buf[1..];
299 }
300
301 if buf.is_empty() {
303 &self.0[31..]
304 } else {
305 buf
306 }
307 }
308
309 pub fn padded(&self) -> &[u8] {
310 &self.0
311 }
312}
313
314impl std::fmt::Display for Bn254FieldElement {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 let u256 = U256::from_be(U256::from_digits(self.0));
317 let radix10 = u256.to_str_radix(10);
318 f.write_str(&radix10)
319 }
320}
321
322#[derive(Debug)]
323pub struct Bn254FieldElementParseError(bnum::errors::ParseIntError);
324
325impl std::fmt::Display for Bn254FieldElementParseError {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 write!(f, "unable to parse radix10 encoded value {}", self.0)
328 }
329}
330
331impl std::error::Error for Bn254FieldElementParseError {}
332
333impl std::str::FromStr for Bn254FieldElement {
334 type Err = Bn254FieldElementParseError;
335
336 fn from_str(s: &str) -> Result<Self, Self::Err> {
337 let u256 = U256::from_str_radix(s, 10).map_err(Bn254FieldElementParseError)?;
338 let be = u256.to_be();
339 Ok(Self(*be.digits()))
340 }
341}
342
#[cfg(test)]
mod test {
    use super::Bn254FieldElement;
    use num_bigint::BigUint;
    use proptest::prelude::*;
    use std::str::FromStr;
    use test_strategy::proptest;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn unpadded_slice() {
        let seed = Bn254FieldElement([0; 32]);
        let zero: [u8; 1] = [0];
        assert_eq!(seed.unpadded(), zero.as_slice());

        let mut seed = Bn254FieldElement([1; 32]);
        seed.0[0] = 0;
        assert_eq!(seed.unpadded(), [1; 31].as_slice());
    }

    /// Zero parses to the default value, prints as "0" and unpads to a single byte.
    #[test]
    fn zero_round_trips() {
        let zero = Bn254FieldElement::from_str("0").unwrap();
        assert_eq!(zero, Bn254FieldElement::default());
        assert_eq!(zero.to_string(), "0");
        assert_eq!(zero.unpadded(), [0u8].as_slice());
        assert_eq!(zero.padded(), [0u8; 32].as_slice());
    }

    /// The const parser and the `FromStr` impl must agree.
    #[test]
    fn const_parser_matches_from_str() {
        let via_const = Bn254FieldElement::from_str_radix_10("258").unwrap();
        let via_trait = Bn254FieldElement::from_str("258").unwrap();
        assert_eq!(via_const, via_trait);
        assert_eq!(via_const.unpadded(), [1u8, 2].as_slice());
    }

    #[proptest]
    fn dont_crash_on_large_inputs(
        #[strategy(proptest::collection::vec(any::<u8>(), 33..1024))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        let _ = Bn254FieldElement::from_str(&radix10);
    }

    #[proptest]
    fn valid_address_seeds(
        #[strategy(proptest::collection::vec(any::<u8>(), 1..=32))] bytes: Vec<u8>,
    ) {
        let big_int = BigUint::from_bytes_be(&bytes);
        let radix10 = big_int.to_str_radix(10);

        let seed = Bn254FieldElement::from_str(&radix10).unwrap();
        assert_eq!(radix10, seed.to_string());

        // Both byte views must decode back to the original integer.
        let unpadded = seed.unpadded();
        assert_eq!(BigUint::from_bytes_be(unpadded), big_int);
        assert_eq!(BigUint::from_bytes_be(seed.padded()), big_int);

        // `unpadded` keeps a leading zero only for the zero value itself.
        assert!(unpadded.len() == 1 || unpadded[0] != 0);
    }
}
389
/// Hand-written `serde` support for the zkLogin types.
///
/// `ZkLoginPublicIdentifier` and `ZkLoginAuthenticator` switch on
/// `is_human_readable`: structured output for human-readable formats, compact
/// byte strings for binary ones. `Bn254FieldElement` is always encoded through
/// its base-10 `Display`/`FromStr` form.
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
mod serialization {
    use crate::SignatureScheme;

    use super::*;
    use serde::Deserialize;
    use serde::Deserializer;
    use serde::Serialize;
    use serde::Serializer;
    use serde_with::Bytes;
    use serde_with::DeserializeAs;
    use serde_with::SerializeAs;
    use std::borrow::Cow;

    /// Error message for every structural failure when decoding a
    /// `ZkLoginPublicIdentifier`.
    const INVALID_PUBLIC_IDENTIFIER: &str = "invalid zklogin public identifier";

    /// Human-readable formats get a `{ iss, address_seed }` struct. Binary
    /// formats get one byte string laid out as
    /// `[iss_len: u8][iss: iss_len bytes][address_seed: 32 bytes]`.
    impl Serialize for ZkLoginPublicIdentifier {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                #[derive(serde_derive::Serialize)]
                struct Readable<'a> {
                    iss: &'a str,
                    address_seed: &'a Bn254FieldElement,
                }
                let readable = Readable {
                    iss: &self.iss,
                    address_seed: &self.address_seed,
                };
                readable.serialize(serializer)
            } else {
                let iss_bytes = self.iss.as_bytes();
                // `ZkLoginPublicIdentifier::new` rejects issuers longer than 255 bytes,
                // so this only fails if that invariant was broken. Report it instead of
                // silently truncating the length prefix.
                let iss_len = u8::try_from(iss_bytes.len())
                    .map_err(<S::Error as serde::ser::Error>::custom)?;

                let mut buf =
                    Vec::with_capacity(1 + iss_bytes.len() + self.address_seed.0.len());
                buf.push(iss_len);
                buf.extend_from_slice(iss_bytes);
                buf.extend_from_slice(&self.address_seed.0);

                serializer.serialize_bytes(&buf)
            }
        }
    }

    impl<'de> Deserialize<'de> for ZkLoginPublicIdentifier {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                #[derive(serde_derive::Deserialize)]
                struct Readable {
                    iss: String,
                    address_seed: Bn254FieldElement,
                }

                let Readable { iss, address_seed } = Deserialize::deserialize(deserializer)?;
                Self::new(iss, address_seed)
                    .ok_or_else(|| serde::de::Error::custom(INVALID_PUBLIC_IDENTIFIER))
            } else {
                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;

                // Layout: [iss_len: u8][iss: iss_len bytes][address_seed: 32 bytes].
                let (&iss_len, rest) = bytes
                    .split_first()
                    .ok_or_else(|| serde::de::Error::custom(INVALID_PUBLIC_IDENTIFIER))?;
                let iss_len = usize::from(iss_len);
                if rest.len() < iss_len {
                    return Err(serde::de::Error::custom(INVALID_PUBLIC_IDENTIFIER));
                }
                let (iss_bytes, address_seed_bytes) = rest.split_at(iss_len);

                let iss = std::str::from_utf8(iss_bytes).map_err(serde::de::Error::custom)?;
                // Exactly 32 bytes must remain; anything else is rejected here.
                let address_seed = <[u8; 32]>::try_from(address_seed_bytes)
                    .map_err(serde::de::Error::custom)
                    .map(Bn254FieldElement)?;

                Self::new(iss.into(), address_seed)
                    .ok_or_else(|| serde::de::Error::custom(INVALID_PUBLIC_IDENTIFIER))
            }
        }
    }

    /// Borrowed view of a `ZkLoginAuthenticator`; field order defines the BCS layout.
    #[derive(serde_derive::Serialize)]
    struct AuthenticatorRef<'a> {
        inputs: &'a ZkLoginInputs,
        #[serde(with = "crate::_serde::ReadableDisplay")]
        max_epoch: EpochId,
        signature: &'a SimpleSignature,
    }

    /// Owned mirror of `AuthenticatorRef`, used when deserializing.
    #[derive(serde_derive::Deserialize)]
    struct Authenticator {
        inputs: ZkLoginInputs,
        #[serde(with = "crate::_serde::ReadableDisplay")]
        max_epoch: EpochId,
        signature: SimpleSignature,
    }

    impl Authenticator {
        /// Moves the decoded fields into the public type.
        fn into_zklogin(self) -> ZkLoginAuthenticator {
            let Self {
                inputs,
                max_epoch,
                signature,
            } = self;
            ZkLoginAuthenticator {
                inputs,
                max_epoch,
                signature,
            }
        }
    }

    /// Human-readable formats get the `AuthenticatorRef` struct; binary formats
    /// get the byte string produced by `ZkLoginAuthenticator::to_bytes`.
    impl Serialize for ZkLoginAuthenticator {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                self.authenticator_ref().serialize(serializer)
            } else {
                serializer.serialize_bytes(&self.to_bytes())
            }
        }
    }

    impl<'de> Deserialize<'de> for ZkLoginAuthenticator {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                Authenticator::deserialize(deserializer).map(Authenticator::into_zklogin)
            } else {
                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
                Self::from_serialized_bytes(bytes)
            }
        }
    }

    impl ZkLoginAuthenticator {
        /// Borrows the fields in the shape shared by both serialization paths.
        fn authenticator_ref(&self) -> AuthenticatorRef<'_> {
            AuthenticatorRef {
                inputs: &self.inputs,
                max_epoch: self.max_epoch,
                signature: &self.signature,
            }
        }

        /// Encodes as the zkLogin signature-scheme flag byte followed by BCS.
        pub(crate) fn to_bytes(&self) -> Vec<u8> {
            let mut buf = vec![SignatureScheme::ZkLogin as u8];
            bcs::serialize_into(&mut buf, &self.authenticator_ref())
                .expect("serialization cannot fail");
            buf
        }

        /// Decodes the format produced by [`ZkLoginAuthenticator::to_bytes`].
        ///
        /// # Errors
        ///
        /// Fails when `bytes` is empty, when the leading flag byte is not a known
        /// signature scheme or not the zkLogin scheme, or when the remaining bytes
        /// are not valid BCS for the authenticator fields.
        pub(crate) fn from_serialized_bytes<T: AsRef<[u8]>, E: serde::de::Error>(
            bytes: T,
        ) -> Result<Self, E> {
            let (&flag_byte, bcs_bytes) = bytes
                .as_ref()
                .split_first()
                .ok_or_else(|| E::custom("missing signature scheme flag"))?;
            let flag = SignatureScheme::from_byte(flag_byte).map_err(E::custom)?;
            if flag != SignatureScheme::ZkLogin {
                return Err(E::custom("invalid zklogin flag"));
            }

            bcs::from_bytes::<Authenticator>(bcs_bytes)
                .map(Authenticator::into_zklogin)
                .map_err(E::custom)
        }
    }

    /// Always encoded as a base-10 string via `Display`, in every format.
    impl Serialize for Bn254FieldElement {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serde_with::DisplayFromStr::serialize_as(self, serializer)
        }
    }

    /// Always decoded from a base-10 string via `FromStr`, in every format.
    impl<'de> Deserialize<'de> for Bn254FieldElement {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            serde_with::DisplayFromStr::deserialize_as(deserializer)
        }
    }

    impl Serialize for CircomG1 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeSeq;
            // Emitted as a sequence rather than via serde's array impl, which
            // would produce a fixed-size tuple (no length prefix in BCS).
            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
            for element in &self.0 {
                seq.serialize_element(element)?;
            }
            seq.end()
        }
    }

    impl<'de> Deserialize<'de> for CircomG1 {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let inner = <Vec<_>>::deserialize(deserializer)?;
            Ok(Self(inner.try_into().map_err(|_| {
                serde::de::Error::custom("expected array of length 3")
            })?))
        }
    }

    impl Serialize for CircomG2 {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeSeq;

            /// Serializes one coordinate pair as a sequence (not a tuple).
            struct Inner<'a>(&'a [Bn254FieldElement; 2]);

            impl Serialize for Inner<'_> {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
                    for element in self.0 {
                        seq.serialize_element(element)?;
                    }
                    seq.end()
                }
            }

            // Sequence of sequences, mirroring the `Vec<Vec<_>>` shape the
            // deserializer below expects.
            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
            for element in &self.0 {
                seq.serialize_element(&Inner(element))?;
            }
            seq.end()
        }
    }

    impl<'de> Deserialize<'de> for CircomG2 {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            const EXPECTED_SHAPE: &str =
                "vector of three vectors each being a vector of two strings";

            let vecs = <Vec<Vec<Bn254FieldElement>>>::deserialize(deserializer)?;
            let mut inner: [[Bn254FieldElement; 2]; 3] = Default::default();

            if vecs.len() != 3 {
                return Err(serde::de::Error::custom(EXPECTED_SHAPE));
            }

            for (i, v) in vecs.into_iter().enumerate() {
                if v.len() != 2 {
                    return Err(serde::de::Error::custom(EXPECTED_SHAPE));
                }

                for (j, point) in v.into_iter().enumerate() {
                    inner[i][j] = point;
                }
            }

            Ok(Self(inner))
        }
    }
}