mysten_network/
codec.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use bytes::{Buf, BufMut};
5use std::{io::Read, marker::PhantomData};
6use tonic::{
7    Status,
8    codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder},
9};
10
/// Default cap, in bytes, on the total size of a decompressed payload (128 MiB).
const MAX_DECOMPRESSED_SIZE: u64 = 128 * 1024 * 1024;
13
14/// Decompress a snappy-framed stream, bounding output at
15/// [`MAX_DECOMPRESSED_SIZE`].
16fn decompress_snappy<R: Read>(src: R) -> std::io::Result<Vec<u8>> {
17    decompress_snappy_bounded(src, MAX_DECOMPRESSED_SIZE)
18}
19
20/// Decompress a snappy-framed stream, bounding total output at
21/// `max_allowed` bytes. Exposed separately from [`decompress_snappy`] so
22/// tests can exercise the bound directly.
23fn decompress_snappy_bounded<R: Read>(src: R, max_allowed: u64) -> std::io::Result<Vec<u8>> {
24    let mut snappy_decoder = snap::read::FrameDecoder::new(src).take(max_allowed);
25    let mut bytes = Vec::new();
26    snappy_decoder.read_to_end(&mut bytes)?;
27    Ok(bytes)
28}
29
30#[derive(Debug)]
31pub struct BcsEncoder<T>(PhantomData<T>);
32
33impl<T: serde::Serialize> Encoder for BcsEncoder<T> {
34    type Item = T;
35    type Error = Status;
36
37    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
38        bcs::serialize_into(&mut buf.writer(), &item).map_err(|e| Status::internal(e.to_string()))
39    }
40}
41
42#[derive(Debug)]
43pub struct BcsDecoder<U>(PhantomData<U>);
44
45impl<U: serde::de::DeserializeOwned> Decoder for BcsDecoder<U> {
46    type Item = U;
47    type Error = Status;
48
49    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
50        if !buf.has_remaining() {
51            return Ok(None);
52        }
53
54        let chunk = buf.chunk();
55
56        let item: Self::Item =
57            bcs::from_bytes(chunk).map_err(|e| Status::internal(e.to_string()))?;
58        buf.advance(chunk.len());
59
60        Ok(Some(item))
61    }
62}
63
64/// A [`Codec`] that implements `application/grpc+bcs` via the serde library.
65#[derive(Debug, Clone)]
66pub struct BcsCodec<T, U>(PhantomData<(T, U)>);
67
68impl<T, U> Default for BcsCodec<T, U> {
69    fn default() -> Self {
70        Self(PhantomData)
71    }
72}
73
74impl<T, U> Codec for BcsCodec<T, U>
75where
76    T: serde::Serialize + Send + 'static,
77    U: serde::de::DeserializeOwned + Send + 'static,
78{
79    type Encode = T;
80    type Decode = U;
81    type Encoder = BcsEncoder<T>;
82    type Decoder = BcsDecoder<U>;
83
84    fn encoder(&mut self) -> Self::Encoder {
85        BcsEncoder(PhantomData)
86    }
87
88    fn decoder(&mut self) -> Self::Decoder {
89        BcsDecoder(PhantomData)
90    }
91}
92
93#[derive(Debug)]
94pub struct BcsSnappyEncoder<T>(PhantomData<T>);
95
96impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
97    type Item = T;
98    type Error = Status;
99
100    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
101        let mut snappy_encoder = snap::write::FrameEncoder::new(buf.writer());
102        bcs::serialize_into(&mut snappy_encoder, &item).map_err(|e| Status::internal(e.to_string()))
103    }
104}
105
106#[derive(Debug)]
107pub struct BcsSnappyDecoder<U>(PhantomData<U>);
108
109impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
110    type Item = U;
111    type Error = Status;
112
113    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
114        if !buf.has_remaining() {
115            return Ok(None);
116        }
117        let bytes = decompress_snappy(buf.reader()).map_err(|e| Status::internal(e.to_string()))?;
118        let item =
119            bcs::from_bytes(bytes.as_slice()).map_err(|e| Status::internal(e.to_string()))?;
120        Ok(Some(item))
121    }
122}
123
124/// A [`Codec`] that implements `bcs` encoding/decoding and snappy compression/decompression
125/// via the serde library.
126#[derive(Debug, Clone)]
127pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);
128
129impl<T, U> Default for BcsSnappyCodec<T, U> {
130    fn default() -> Self {
131        Self(PhantomData)
132    }
133}
134
135impl<T, U> Codec for BcsSnappyCodec<T, U>
136where
137    T: serde::Serialize + Send + 'static,
138    U: serde::de::DeserializeOwned + Send + 'static,
139{
140    type Encode = T;
141    type Decode = U;
142    type Encoder = BcsSnappyEncoder<T>;
143    type Decoder = BcsSnappyDecoder<U>;
144
145    fn encoder(&mut self) -> Self::Encoder {
146        BcsSnappyEncoder(PhantomData)
147    }
148
149    fn decoder(&mut self) -> Self::Decoder {
150        BcsSnappyDecoder(PhantomData)
151    }
152}
153
154// Anemo variant of BCS codec using Snappy for compression.
155pub mod anemo {
156    use ::anemo::rpc::codec::{Codec, Decoder, Encoder};
157    use bytes::Buf;
158    use std::marker::PhantomData;
159
160    #[derive(Debug)]
161    pub struct BcsSnappyEncoder<T>(PhantomData<T>);
162
163    impl<T: serde::Serialize> Encoder for BcsSnappyEncoder<T> {
164        type Item = T;
165        type Error = bcs::Error;
166
167        fn encode(&mut self, item: Self::Item) -> Result<bytes::Bytes, Self::Error> {
168            let mut buf = Vec::<u8>::new();
169            let mut snappy_encoder = snap::write::FrameEncoder::new(&mut buf);
170            bcs::serialize_into(&mut snappy_encoder, &item)?;
171            drop(snappy_encoder);
172            Ok(buf.into())
173        }
174    }
175
176    #[derive(Debug)]
177    pub struct BcsSnappyDecoder<U>(PhantomData<U>);
178
179    impl<U: serde::de::DeserializeOwned> Decoder for BcsSnappyDecoder<U> {
180        type Item = U;
181        type Error = bcs::Error;
182
183        fn decode(&mut self, buf: bytes::Bytes) -> Result<Self::Item, Self::Error> {
184            let bytes = super::decompress_snappy(buf.reader())?;
185            bcs::from_bytes(bytes.as_slice())
186        }
187    }
188
189    /// A [`Codec`] that implements `bcs` encoding/decoding via the serde library.
190    #[derive(Debug, Clone)]
191    pub struct BcsSnappyCodec<T, U>(PhantomData<(T, U)>);
192
193    impl<T, U> Default for BcsSnappyCodec<T, U> {
194        fn default() -> Self {
195            Self(PhantomData)
196        }
197    }
198
199    impl<T, U> Codec for BcsSnappyCodec<T, U>
200    where
201        T: serde::Serialize + Send + 'static,
202        U: serde::de::DeserializeOwned + Send + 'static,
203    {
204        type Encode = T;
205        type Decode = U;
206        type Encoder = BcsSnappyEncoder<T>;
207        type Decoder = BcsSnappyDecoder<U>;
208
209        fn encoder(&mut self) -> Self::Encoder {
210            BcsSnappyEncoder(PhantomData)
211        }
212
213        fn decoder(&mut self) -> Self::Decoder {
214            BcsSnappyDecoder(PhantomData)
215        }
216
217        fn format_name(&self) -> &'static str {
218            "bcs"
219        }
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use ::anemo::rpc::codec::{
        Codec as AnemoCodec, Decoder as AnemoDecoder, Encoder as AnemoEncoder,
    };

    /// Compresses `raw` with the Snappy frame format, independently of the
    /// codecs under test.
    fn snappy_compress(raw: &[u8]) -> Vec<u8> {
        let mut compressed = Vec::new();
        {
            let mut writer = snap::write::FrameEncoder::new(&mut compressed);
            std::io::Write::write_all(&mut writer, raw).unwrap();
            // `writer` goes out of scope here, flushing its buffered frames
            // into `compressed` and releasing the borrow.
        }
        compressed
    }

    #[test]
    fn anemo_roundtrip() {
        let expected: Vec<u64> = (1..=10).collect();
        let mut codec = anemo::BcsSnappyCodec::<Vec<u64>, Vec<u64>>::default();
        let wire = codec.encoder().encode(expected.clone()).unwrap();
        let roundtripped = codec.decoder().decode(wire).unwrap();
        assert_eq!(roundtripped, expected);
    }

    #[test]
    fn bounded_helper_respects_output_limit() {
        // 2 MiB of zeros decompresses to far more than the 1 KiB limit, so the
        // bounded helper must hand back exactly `LIMIT` bytes.
        const LIMIT: u64 = 1024;
        let original = vec![0u8; 2 * 1024 * 1024];
        let frames = snappy_compress(&original);
        let out = decompress_snappy_bounded(frames.as_slice(), LIMIT).unwrap();
        assert_eq!(out.len() as u64, LIMIT);
        assert!(out.len() < original.len());
    }

    // Round trip through the tonic decoder. A `DecodeBuf` cannot be built
    // without tonic's internal API, so the decoder is driven through
    // `tonic::codec::Streaming::new_request` (the path the gRPC server uses),
    // fed by a hand-rolled `HttpBody`. This exercises the real
    // `BcsSnappyDecoder::decode`.
    mod tonic_via_streaming {
        use super::super::*;
        use super::snappy_compress;
        use bytes::{BufMut, Bytes, BytesMut};
        use futures::StreamExt;
        use http_body::{Body as HttpBody, Frame};
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// An `HttpBody` that yields exactly one gRPC length-prefixed message
        /// and then ends. gRPC framing is
        /// `[compression flag: u8][length: u32 big-endian][payload]`.
        struct OneFrameBody(Option<Bytes>);

        impl OneFrameBody {
            fn new(payload: Bytes) -> Self {
                const HEADER_LEN: usize = 5;
                let len = u32::try_from(payload.len()).expect("test payload fits in a u32");
                let mut framed = BytesMut::with_capacity(HEADER_LEN + payload.len());
                framed.put_u8(0); // not compressed at the gRPC layer
                framed.put_u32(len);
                framed.extend_from_slice(&payload);
                OneFrameBody(Some(framed.freeze()))
            }
        }

        impl HttpBody for OneFrameBody {
            type Data = Bytes;
            type Error = std::convert::Infallible;

            fn poll_frame(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let next = self.get_mut().0.take();
                Poll::Ready(next.map(|data| Ok(Frame::data(data))))
            }

            fn is_end_stream(&self) -> bool {
                self.0.is_none()
            }
        }

        #[tokio::test]
        async fn tonic_roundtrip() {
            let expected: Vec<u64> = (1..=10).collect();
            let compressed = snappy_compress(&bcs::to_bytes(&expected).unwrap());
            let mut codec = BcsSnappyCodec::<Vec<u64>, Vec<u64>>::default();
            let body = OneFrameBody::new(Bytes::from(compressed));
            let mut stream =
                tonic::codec::Streaming::new_request(codec.decoder(), body, None, None);
            let decoded = stream.next().await.unwrap().unwrap();
            assert_eq!(decoded, expected);
        }
    }
}