1#![allow(dead_code)]
2
3mod well_known_types;
4pub use well_known_types::*;
5
6pub use base64;
8
9use base64::engine::DecodePaddingMode;
10use base64::engine::GeneralPurpose;
11use base64::engine::GeneralPurposeConfig;
12use base64::Engine;
13use serde::de::Visitor;
14use serde::Deserialize;
15use std::borrow::Cow;
16use std::str::FromStr;
17
/// Transparent wrapper around a `T` whose `Deserialize` impl accepts the value
/// either in its native serde form or as a string that is parsed with `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NumberDeserialize<T>(pub T);
21
/// Intermediate representation used by `NumberDeserialize`: the input is either
/// a string (to be parsed with `FromStr`) or a value `T` can deserialize natively.
///
/// Because the enum is `#[serde(untagged)]`, serde tries the variants in
/// declaration order, so `Str` is attempted before `Number`. Reordering the
/// variants changes which representation wins when both could match.
#[derive(Deserialize)]
#[serde(untagged)]
enum Content<'a, T> {
    /// String form. `#[serde(borrow)]` lets the `Cow` borrow from the input
    /// when the deserializer can hand out a borrowed `&str`.
    #[serde(borrow)]
    Str(Cow<'a, str>),
    /// Native form, decoded by `T`'s own `Deserialize` impl.
    Number(T),
}
29
30impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
31where
32 T: FromStr + serde::Deserialize<'de>,
33 <T as FromStr>::Err: std::error::Error,
34{
35 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
36 where
37 D: serde::Deserializer<'de>,
38 {
39 let content = Content::deserialize(deserializer)?;
40 Ok(Self(match content {
41 Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
42 Content::Number(v) => v,
43 }))
44 }
45}
46
/// serde `Visitor` that decodes a base64 string into a `Vec<u8>`.
struct Base64Visitor;
48
49impl<'de> Visitor<'de> for Base64Visitor {
50 type Value = Vec<u8>;
51
52 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 formatter.write_str("a base64 string")
54 }
55
56 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
57 where
58 E: serde::de::Error,
59 {
60 const INDIFFERENT_PAD: GeneralPurposeConfig =
61 GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent);
62 const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
63 GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
64 const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
65 GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
66
67 let decoded = STANDARD_INDIFFERENT_PAD
68 .decode(s)
69 .or_else(|e| match e {
70 base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
77 URL_SAFE_INDIFFERENT_PAD.decode(s)
78 }
79 _ => Err(e),
80 })
81 .map_err(serde::de::Error::custom)?;
82 Ok(decoded)
83 }
84}
85
/// Transparent wrapper around a `T` whose `Deserialize` impl reads a base64
/// string and converts the decoded bytes into `T` via `From<Vec<u8>>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BytesDeserialize<T>(pub T);
88
89impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
90where
91 T: From<Vec<u8>>,
92{
93 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94 where
95 D: serde::Deserializer<'de>,
96 {
97 Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use bytes::Bytes;
    use rand::prelude::*;
    use serde::de::value::BorrowedStrDeserializer;
    use serde::de::value::Error;

    /// Random payloads encoded with every standard / URL-safe, padded / unpadded
    /// engine must decode back to the original bytes, both into `Bytes` and into
    /// `Vec<u8>`.
    #[test]
    fn test_bytes() {
        let engines = [
            base64::engine::general_purpose::STANDARD,
            base64::engine::general_purpose::STANDARD_NO_PAD,
            base64::engine::general_purpose::URL_SAFE,
            base64::engine::general_purpose::URL_SAFE_NO_PAD,
        ];
        // `thread_rng` hands out the same thread-local generator each call, so a
        // single handle outside the loop is equivalent.
        let mut rng = thread_rng();

        for _ in 0..20 {
            let len: usize = rng.gen_range(50..100);
            let raw: Vec<u8> = (0..len).map(|_| rng.gen()).collect();

            for engine in &engines {
                let encoded = engine.encode(&raw);
                let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);

                let as_bytes: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
                let as_vec: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;

                assert_eq!(raw.as_slice(), &as_bytes);
                assert_eq!(raw.as_slice(), &as_vec);
            }
        }
    }

    /// Smoke test: a JSON object parses through `ValueDeserializer` and can be
    /// re-serialized via `ValueSerializer`. The output is printed, not asserted.
    #[test]
    fn value() {
        let json = serde_json::json!({
            "foo": 4,
            "bar": "abc",
            "baz": [1, 2, 3],
            "foobar": null,
        });
        let parsed: ValueDeserializer = serde_json::from_value(json).unwrap();
        let pretty = serde_json::to_string_pretty(&ValueSerializer(&parsed.0)).unwrap();
        println!("{}", pretty);
    }
}