sui_rpc/_serde/
mod.rs

1#![allow(dead_code)]
2
3mod well_known_types;
4pub use well_known_types::*;
5
6/// Re-export base64
7pub use base64;
8
9use base64::engine::DecodePaddingMode;
10use base64::engine::GeneralPurpose;
11use base64::engine::GeneralPurposeConfig;
12use base64::Engine;
13use serde::de::Visitor;
14use serde::Deserialize;
15use std::borrow::Cow;
16use std::str::FromStr;
17
/// Wrapper that deserializes a number supplied either natively (for example a
/// JSON number) or as a string holding its textual form (for example `"123"`).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NumberDeserialize<T>(pub T);
21
22#[derive(Deserialize)]
23#[serde(untagged)]
24enum Content<'a, T> {
25    #[serde(borrow)]
26    Str(Cow<'a, str>),
27    Number(T),
28}
29
30impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
31where
32    T: FromStr + serde::Deserialize<'de>,
33    <T as FromStr>::Err: std::error::Error,
34{
35    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
36    where
37        D: serde::Deserializer<'de>,
38    {
39        let content = Content::deserialize(deserializer)?;
40        Ok(Self(match content {
41            Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
42            Content::Number(v) => v,
43        }))
44    }
45}
46
47struct Base64Visitor;
48
49impl<'de> Visitor<'de> for Base64Visitor {
50    type Value = Vec<u8>;
51
52    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        formatter.write_str("a base64 string")
54    }
55
56    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
57    where
58        E: serde::de::Error,
59    {
60        const INDIFFERENT_PAD: GeneralPurposeConfig =
61            GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent);
62        const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
63            GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
64        const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
65            GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
66
67        let decoded = STANDARD_INDIFFERENT_PAD
68            .decode(s)
69            .or_else(|e| match e {
70                // Either standard or URL-safe base64 encoding are accepted
71                //
72                // The difference being URL-safe uses `-` and `_` instead of `+` and `/`
73                //
74                // Therefore if we error out on those characters, try again with
75                // the URL-safe character set
76                base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
77                    URL_SAFE_INDIFFERENT_PAD.decode(s)
78                }
79                _ => Err(e),
80            })
81            .map_err(serde::de::Error::custom)?;
82        Ok(decoded)
83    }
84}
85
/// Wrapper that deserializes a base64-encoded string into any `T` that can be
/// built from the decoded `Vec<u8>` (for example `Vec<u8>` itself).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesDeserialize<T>(pub T);
88
89impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
90where
91    T: From<Vec<u8>>,
92{
93    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94    where
95        D: serde::Deserializer<'de>,
96    {
97        Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use bytes::Bytes;
    use rand::prelude::*;
    use serde::de::value::BorrowedStrDeserializer;
    use serde::de::value::Error;

    /// Round-trips random payloads through standard/URL-safe, padded/unpadded
    /// encodings and checks that both `Bytes` and `Vec<u8>` targets decode them.
    #[test]
    fn test_bytes() {
        // Acquire the RNG handle once instead of on every iteration.
        let mut rng = thread_rng();
        for _ in 0..20 {
            let len: usize = rng.gen_range(50..100);
            let raw: Vec<u8> = (0..len).map(|_| rng.gen()).collect();

            for config in [
                base64::engine::general_purpose::STANDARD,
                base64::engine::general_purpose::STANDARD_NO_PAD,
                base64::engine::general_purpose::URL_SAFE,
                base64::engine::general_purpose::URL_SAFE_NO_PAD,
            ] {
                let encoded = config.encode(&raw);

                let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
                let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
                let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;

                assert_eq!(raw.as_slice(), &a);
                assert_eq!(raw.as_slice(), &b);
            }
        }
    }

    #[test]
    fn test_bytes_empty_string() {
        let deserializer = BorrowedStrDeserializer::<'_, Error>::new("");
        let decoded: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_bytes_rejects_invalid_input() {
        // "!!!!" is invalid in every alphabet; "ab+-" mixes a standard-only
        // symbol (`+`) with a URL-safe-only one (`-`), so both attempts fail.
        for input in ["!!!!", "ab+-"] {
            let deserializer = BorrowedStrDeserializer::<'_, Error>::new(input);
            assert!(BytesDeserialize::<Vec<u8>>::deserialize(deserializer).is_err());
        }
    }

    #[test]
    fn test_number_from_string_or_number() {
        let deserializer = BorrowedStrDeserializer::<'_, Error>::new("-42");
        let from_str: i64 = NumberDeserialize::deserialize(deserializer).unwrap().0;
        assert_eq!(from_str, -42);

        let from_json_str: NumberDeserialize<u64> =
            serde_json::from_value(serde_json::json!("18446744073709551615")).unwrap();
        assert_eq!(from_json_str.0, u64::MAX);

        let from_json_num: NumberDeserialize<u64> =
            serde_json::from_value(serde_json::json!(7)).unwrap();
        assert_eq!(from_json_num.0, 7);
    }

    #[test]
    fn test_number_rejects_unparsable_string() {
        let deserializer = BorrowedStrDeserializer::<'_, Error>::new("not a number");
        assert!(NumberDeserialize::<u64>::deserialize(deserializer).is_err());
    }

    #[test]
    fn value() {
        let v = serde_json::json!({
            "foo": 4,
            "bar": "abc",
            "baz": [1, 2, 3],
            "foobar": null,
        });
        let proto: ValueDeserializer = serde_json::from_value(v).unwrap();
        println!(
            "{}",
            serde_json::to_string_pretty(&ValueSerializer(&proto.0)).unwrap()
        );
    }
}