1use bytes::{Buf, BufMut};
5use std::{io::Read, marker::PhantomData};
6use tonic::{
7 Status,
8 codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder},
9};
10
/// Upper bound, in bytes, on the decompressed output that `decompress_snappy` will read
/// for a single message (128 MiB).
const MAX_DECOMPRESSED_SIZE: u64 = 128 * 1024 * 1024;
13
14fn decompress_snappy<R: Read>(src: R) -> std::io::Result<Vec<u8>> {
17 decompress_snappy_bounded(src, MAX_DECOMPRESSED_SIZE)
18}
19
20fn decompress_snappy_bounded<R: Read>(src: R, max_allowed: u64) -> std::io::Result<Vec<u8>> {
24 let mut snappy_decoder = snap::read::FrameDecoder::new(src).take(max_allowed);
25 let mut bytes = Vec::new();
26 snappy_decoder.read_to_end(&mut bytes)?;
27 Ok(bytes)
28}
29
30#[derive(Debug)]
31pub struct BcsEncoder<T>(PhantomData<T>);
32
33impl<T: serde::Serialize> Encoder for BcsEncoder<T> {
34 type Item = T;
35 type Error = Status;
36
37 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
38 bcs::serialize_into(&mut buf.writer(), &item).map_err(|e| Status::internal(e.to_string()))
39 }
40}
41
42#[derive(Debug)]
43pub struct BcsDecoder<U>(PhantomData<U>);
44
45impl<U: serde::de::DeserializeOwned> Decoder for BcsDecoder<U> {
46 type Item = U;
47 type Error = Status;
48
49 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
50 if !buf.has_remaining() {
51 return Ok(None);
52 }
53
54 let chunk = buf.chunk();
55
56 let item: Self::Item =
57 bcs::from_bytes(chunk).map_err(|e| Status::internal(e.to_string()))?;
58 buf.advance(chunk.len());
59
60 Ok(Some(item))
61 }
62}
63
64#[derive(Debug, Clone)]
66pub struct BcsCodec<T, U>(PhantomData<(T, U)>);
67
68impl<T, U> Default for BcsCodec<T, U> {
69 fn default() -> Self {
70 Self(PhantomData)
71 }
72}
73
74impl<T, U> Codec for BcsCodec<T, U>
75where
76 T: serde::Serialize + Send + 'static,
77 U: serde::de::DeserializeOwned + Send + 'static,
78{
79 type Encode = T;
80 type Decode = U;
81 type Encoder = BcsEncoder<T>;
82 type Decoder = BcsDecoder<U>;
83
84 fn encoder(&mut self) -> Self::Encoder {
85 BcsEncoder(PhantomData)
86 }
87
88 fn decoder(&mut self) -> Self::Decoder {
89 BcsDecoder(PhantomData)
90 }
91}
92
93#[derive(Debug)]
94pub struct BcsSnappyEncoder<T>(PhantomData<T>);
95
96impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
97 type Item = T;
98 type Error = Status;
99
100 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
101 let mut snappy_encoder = snap::write::FrameEncoder::new(buf.writer());
102 bcs::serialize_into(&mut snappy_encoder, &item).map_err(|e| Status::internal(e.to_string()))
103 }
104}
105
106#[derive(Debug)]
107pub struct BcsSnappyDecoder<U>(PhantomData<U>);
108
109impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
110 type Item = U;
111 type Error = Status;
112
113 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
114 if !buf.has_remaining() {
115 return Ok(None);
116 }
117 let bytes = decompress_snappy(buf.reader()).map_err(|e| Status::internal(e.to_string()))?;
118 let item =
119 bcs::from_bytes(bytes.as_slice()).map_err(|e| Status::internal(e.to_string()))?;
120 Ok(Some(item))
121 }
122}
123
124#[derive(Debug, Clone)]
127pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);
128
129impl<T, U> Default for BcsSnappyCodec<T, U> {
130 fn default() -> Self {
131 Self(PhantomData)
132 }
133}
134
135impl<T, U> Codec for BcsSnappyCodec<T, U>
136where
137 T: serde::Serialize + Send + 'static,
138 U: serde::de::DeserializeOwned + Send + 'static,
139{
140 type Encode = T;
141 type Decode = U;
142 type Encoder = BcsSnappyEncoder<T>;
143 type Decoder = BcsSnappyDecoder<U>;
144
145 fn encoder(&mut self) -> Self::Encoder {
146 BcsSnappyEncoder(PhantomData)
147 }
148
149 fn decoder(&mut self) -> Self::Decoder {
150 BcsSnappyDecoder(PhantomData)
151 }
152}
153
154pub mod anemo {
156 use ::anemo::rpc::codec::{Codec, Decoder, Encoder};
157 use bytes::Buf;
158 use std::marker::PhantomData;
159
160 #[derive(Debug)]
161 pub struct BcsSnappyEncoder<T>(PhantomData<T>);
162
163 impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
164 type Item = T;
165 type Error = bcs::Error;
166
167 fn encode(&mut self, item: Self::Item) -> Result<bytes::Bytes, Self::Error> {
168 let mut buf = Vec::<u8>::new();
169 let mut snappy_encoder = snap::write::FrameEncoder::new(&mut buf);
170 bcs::serialize_into(&mut snappy_encoder, &item)?;
171 drop(snappy_encoder);
172 Ok(buf.into())
173 }
174 }
175
176 #[derive(Debug)]
177 pub struct BcsSnappyDecoder<U>(PhantomData<U>);
178
179 impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
180 type Item = U;
181 type Error = bcs::Error;
182
183 fn decode(&mut self, buf: bytes::Bytes) -> Result<Self::Item, Self::Error> {
184 let bytes = super::decompress_snappy(buf.reader())?;
185 bcs::from_bytes(bytes.as_slice())
186 }
187 }
188
189 #[derive(Debug, Clone)]
191 pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);
192
193 impl<T, U> Default for BcsSnappyCodec<T, U> {
194 fn default() -> Self {
195 Self(PhantomData)
196 }
197 }
198
199 impl<T, U> Codec for BcsSnappyCodec<T, U>
200 where
201 T: serde::Serialize + Send + 'static,
202 U: serde::de::DeserializeOwned + Send + 'static,
203 {
204 type Encode = T;
205 type Decode = U;
206 type Encoder = BcsSnappyEncoder<T>;
207 type Decoder = BcsSnappyDecoder<U>;
208
209 fn encoder(&mut self) -> Self::Encoder {
210 BcsSnappyEncoder(PhantomData)
211 }
212
213 fn decoder(&mut self) -> Self::Decoder {
214 BcsSnappyDecoder(PhantomData)
215 }
216
217 fn format_name(&self) -> &'static str {
218 "bcs"
219 }
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use ::anemo::rpc::codec::{
        Codec as AnemoCodec, Decoder as AnemoDecoder, Encoder as AnemoEncoder,
    };

    /// Compresses `raw` into a Snappy frame-format byte vector.
    fn snappy_compress(raw: &[u8]) -> Vec<u8> {
        let mut compressed = Vec::new();
        {
            // The encoder borrows `compressed`; it is dropped at the end of this scope,
            // before the buffer is returned.
            let mut writer = snap::write::FrameEncoder::new(&mut compressed);
            std::io::Write::write_all(&mut writer, raw).unwrap();
        }
        compressed
    }

    /// Encoding then decoding through the anemo codec yields the original value.
    #[test]
    fn anemo_roundtrip() {
        let mut codec: anemo::BcsSnappyCodec<Vec<u64>, Vec<u64>> =
            anemo::BcsSnappyCodec::default();
        let original: Vec<u64> = (1..=10).collect();
        let wire = codec.encoder().encode(original.clone()).unwrap();
        let roundtripped = codec.decoder().decode(wire).unwrap();
        assert_eq!(roundtripped, original);
    }

    /// The bounded helper stops at the requested number of decompressed bytes.
    #[test]
    fn bounded_helper_respects_output_limit() {
        let limit = 1024u64;
        let raw = vec![0u8; 2 * 1024 * 1024];
        let compressed = snappy_compress(&raw);
        let out = decompress_snappy_bounded(compressed.as_slice(), limit).unwrap();
        let produced = out.len() as u64;
        assert_eq!(produced, limit);
        assert!(produced < raw.len() as u64);
    }

    /// Drives the tonic decoder through `tonic::codec::Streaming` using a hand-built body.
    mod tonic_via_streaming {
        use super::super::*;
        use super::snappy_compress;
        use bytes::{BufMut, Bytes, BytesMut};
        use futures::StreamExt;
        use http_body::{Body as HttpBody, Frame};
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// HTTP body that yields a single pre-framed data frame and then ends.
        struct OneFrameBody(Option<Bytes>);

        impl OneFrameBody {
            /// Wraps `payload` behind a 5-byte header: a zero byte followed by the payload
            /// length written with `put_u32`.
            ///
            /// NOTE(review): presumably this mirrors the gRPC message prefix that
            /// `Streaming` parses; confirm against tonic.
            fn new(payload: Bytes) -> Self {
                let body_len = payload.len();
                let mut framed = BytesMut::with_capacity(5 + body_len);
                framed.put_u8(0);
                framed.put_u32(body_len as u32);
                framed.put_slice(&payload);
                Self(Some(framed.freeze()))
            }
        }

        impl HttpBody for OneFrameBody {
            type Data = Bytes;
            type Error = std::convert::Infallible;

            /// Hands out the stored frame on the first poll and `None` afterwards.
            fn poll_frame(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let next = self.0.take().map(|bytes| Ok(Frame::data(bytes)));
                Poll::Ready(next)
            }

            /// The stream is finished once the single frame has been taken.
            fn is_end_stream(&self) -> bool {
                self.0.is_none()
            }
        }

        /// A Snappy-compressed BCS payload decodes back to the original value via tonic.
        #[tokio::test]
        async fn tonic_roundtrip() {
            let mut codec: BcsSnappyCodec<Vec<u64>, Vec<u64>> = BcsSnappyCodec::default();
            let original: Vec<u64> = (1..=10).collect();
            let serialized = bcs::to_bytes(&original).unwrap();
            let body = OneFrameBody::new(Bytes::from(snappy_compress(&serialized)));
            let mut stream =
                tonic::codec::Streaming::new_request(codec.decoder(), body, None, None);
            let roundtripped = stream.next().await.unwrap().unwrap();
            assert_eq!(roundtripped, original);
        }
    }
}