mysten_network/
grpc_timeout.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Ported from `tonic` crate
5// SPDX-License-Identifier: MIT
6
7use http::{HeaderMap, HeaderValue, Request, Response};
8use pin_project_lite::pin_project;
9use std::{
10    future::Future,
11    pin::Pin,
12    task::{Context, Poll, ready},
13    time::Duration,
14};
15use tokio::time::Sleep;
16use tonic::Status;
17use tower::Service;
18
/// Name of the request header carrying the client's requested timeout
/// (parsed by `try_parse_grpc_timeout`).
const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
20
/// Applies timeouts on incoming requests, specified by their header and server default.
///
/// The effective timeout for a request is the shorter of `server_timeout` and the client's
/// `grpc-timeout` header (when present and parseable). A request that does not complete in
/// time is answered with a `DEADLINE_EXCEEDED` status and an empty body.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Apply a max timeout for all requests to limit their total memory usage.
    /// A client-supplied timeout can only shorten this, never extend it.
    server_timeout: Duration,
}
28
29impl<S> GrpcTimeout<S> {
30    pub fn new(inner: S, server_timeout: Duration) -> Self {
31        Self {
32            inner,
33            server_timeout,
34        }
35    }
36}
37
38impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
39where
40    S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
41{
42    type Response = Response<MaybeEmptyBody<ResponseBody>>;
43    type Error = S::Error;
44    type Future = ResponseFuture<S::Future>;
45
46    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47        self.inner.poll_ready(cx).map_err(Into::into)
48    }
49
50    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
51        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
52            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
53            None
54        });
55
56        // Use the shorter of the two durations, if either are set
57        let resp_timeout = match client_timeout {
58            None => self.server_timeout,
59            Some(d) => self.server_timeout.min(d),
60        };
61
62        ResponseFuture {
63            inner: self.inner.call(req),
64            sleep: tokio::time::sleep(resp_timeout),
65        }
66    }
67}
68
pin_project! {
    /// Response future returned by [`GrpcTimeout`].
    ///
    /// Resolves to the wrapped service's response, or to a `DEADLINE_EXCEEDED` status with an
    /// empty body if `sleep` completes while the service is still pending.
    pub struct ResponseFuture<F> {
        // The wrapped service's response future; polled first on every wake-up.
        #[pin]
        inner: F,
        // Timer armed with the effective timeout at the moment the request was dispatched.
        #[pin]
        sleep: Sleep,
    }
}
77
78impl<F, ResponseBody, E> Future for ResponseFuture<F>
79where
80    F: Future<Output = Result<Response<ResponseBody>, E>>,
81{
82    type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let this = self.project();
86
87        if let Poll::Ready(result) = this.inner.poll(cx) {
88            return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
89        }
90
91        ready!(this.sleep.poll(cx));
92        let response = Status::deadline_exceeded("Timeout expired")
93            .into_http()
94            .map(|()| MaybeEmptyBody::empty());
95        Poll::Ready(Ok(response))
96    }
97}
98
pin_project! {
    /// Response body that is either the wrapped service's body or empty.
    ///
    /// Lets [`GrpcTimeout`] return one body type for both outcomes: a completed response keeps
    /// its original body, while a timed-out response carries no body at all.
    pub struct MaybeEmptyBody<B> {
        // `Some` wraps a real body; `None` is an empty body that ends immediately.
        #[pin]
        inner: Option<B>,
    }
}
105
106impl<B> MaybeEmptyBody<B> {
107    fn full(inner: B) -> Self {
108        Self { inner: Some(inner) }
109    }
110
111    fn empty() -> Self {
112        Self { inner: None }
113    }
114}
115
116impl<B> http_body::Body for MaybeEmptyBody<B>
117where
118    B: http_body::Body + Send,
119{
120    type Data = B::Data;
121    type Error = B::Error;
122
123    fn poll_frame(
124        self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
127        match self.project().inner.as_pin_mut() {
128            Some(b) => b.poll_frame(cx),
129            None => Poll::Ready(None),
130        }
131    }
132
133    fn is_end_stream(&self) -> bool {
134        match &self.inner {
135            Some(b) => b.is_end_stream(),
136            None => true,
137        }
138    }
139
140    fn size_hint(&self) -> http_body::SizeHint {
141        match &self.inner {
142            Some(body) => body.size_hint(),
143            None => http_body::SizeHint::with_exact(0),
144        }
145    }
146}
147
// Unit multipliers used when converting `grpc-timeout` hour/minute values into seconds.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
const SECONDS_IN_MINUTE: u64 = 60;
150
151/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
152/// the value we attempted to parse.
153///
154/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
155fn try_parse_grpc_timeout(
156    headers: &HeaderMap<HeaderValue>,
157) -> Result<Option<Duration>, &HeaderValue> {
158    let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
159        return Ok(None);
160    };
161
162    let (timeout_value, timeout_unit) = val
163        .to_str()
164        .map_err(|_| val)
165        .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
166        // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
167        // `split_at` will never panic from trying to split in the middle of a character.
168        // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
169        //
170        // `len - 1` also wont panic since we just checked `s.is_empty`.
171        .split_at(val.len() - 1);
172
173    // gRPC spec specifies `TimeoutValue` will be at most 8 digits
174    // Caping this at 8 digits also prevents integer overflow from ever occurring
175    if timeout_value.len() > 8 {
176        return Err(val);
177    }
178
179    let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
180
181    let duration = match timeout_unit {
182        // Hours
183        "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
184        // Minutes
185        "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
186        // Seconds
187        "S" => Duration::from_secs(timeout_value),
188        // Milliseconds
189        "m" => Duration::from_millis(timeout_value),
190        // Microseconds
191        "u" => Duration::from_micros(timeout_value),
192        // Nanoseconds
193        "n" => Duration::from_nanos(timeout_value),
194        _ => return Err(val),
195    };
196
197    Ok(Some(duration))
198}
199
200#[cfg(test)]
201mod tests {
202    use super::*;
203
204    // Helper function to reduce the boiler plate of our test cases
205    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
206        let mut hm = HeaderMap::new();
207        if let Some(v) = val {
208            let hv = HeaderValue::from_str(v).unwrap();
209            hm.insert(GRPC_TIMEOUT_HEADER, hv);
210        };
211
212        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
213    }
214
215    #[test]
216    fn test_hours() {
217        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
218        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
219    }
220
221    #[test]
222    fn test_minutes() {
223        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
224        assert_eq!(Duration::from_secs(60), parsed_duration);
225    }
226
227    #[test]
228    fn test_seconds() {
229        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
230        assert_eq!(Duration::from_secs(42), parsed_duration);
231    }
232
233    #[test]
234    fn test_milliseconds() {
235        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
236        assert_eq!(Duration::from_millis(13), parsed_duration);
237    }
238
239    #[test]
240    fn test_microseconds() {
241        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
242        assert_eq!(Duration::from_micros(2), parsed_duration);
243    }
244
245    #[test]
246    fn test_nanoseconds() {
247        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
248        assert_eq!(Duration::from_nanos(82), parsed_duration);
249    }
250
251    #[test]
252    fn test_header_not_present() {
253        let parsed_duration = setup_map_try_parse(None).unwrap();
254        assert!(parsed_duration.is_none());
255    }
256
257    #[test]
258    #[should_panic(expected = "82f")]
259    fn test_invalid_unit() {
260        // "f" is not a valid TimeoutUnit
261        setup_map_try_parse(Some("82f")).unwrap().unwrap();
262    }
263
264    #[test]
265    #[should_panic(expected = "123456789H")]
266    fn test_too_many_digits() {
267        // gRPC spec states TimeoutValue will be at most 8 digits
268        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
269    }
270
271    #[test]
272    #[should_panic(expected = "oneH")]
273    fn test_invalid_digits() {
274        // gRPC spec states TimeoutValue will be at most 8 digits
275        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
276    }
277}