sui_http/middleware/
grpc_timeout.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Ported from `tonic` crate
5// SPDX-License-Identifier: MIT
6
7use http::HeaderMap;
8use http::HeaderName;
9use http::HeaderValue;
10use http::Request;
11use http::Response;
12use pin_project_lite::pin_project;
13use std::future::Future;
14use std::pin::Pin;
15use std::task::Context;
16use std::task::Poll;
17use std::task::ready;
18use std::time::Duration;
19use tokio::time::Sleep;
20use tower::Service;
21
22const GRPC_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("grpc-timeout");
23const GRPC_STATUS_HEADER: HeaderName = HeaderName::from_static("grpc-status");
24const GRPC_MESSAGE_HEADER: HeaderName = HeaderName::from_static("grpc-message");
25
26const GRPC_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/grpc");
27const GRPC_DEADLINE_EXCEEDED_CODE: HeaderValue = HeaderValue::from_static("4");
28const GRPC_DEADLINE_EXCEEDED_MESSAGE: HeaderValue = HeaderValue::from_static("Timeout%20expired");
29
/// Middleware that bounds how long the wrapped service may take to produce a response.
///
/// The effective limit for a request is the shorter of `server_timeout` and the
/// value of the request's `grpc-timeout` header, when either is present.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
    // The wrapped service.
    inner: S,
    // Server-side upper bound applied to every request; `None` means no server limit.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, applying `server_timeout` (if any) to every request.
    pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        Self { inner, server_timeout }
    }
}
44
45impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
46where
47    S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
48{
49    type Response = Response<MaybeEmptyBody<ResponseBody>>;
50    type Error = S::Error;
51    type Future = ResponseFuture<S::Future>;
52
53    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54        self.inner.poll_ready(cx).map_err(Into::into)
55    }
56
57    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
58        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
59            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
60            None
61        });
62
63        // Use the shorter of the two durations, if either are set
64        let timeout_duration = match (client_timeout, self.server_timeout) {
65            (None, None) => None,
66            (Some(dur), None) => Some(dur),
67            (None, Some(dur)) => Some(dur),
68            (Some(header), Some(server)) => {
69                let shorter_duration = std::cmp::min(header, server);
70                Some(shorter_duration)
71            }
72        };
73
74        ResponseFuture {
75            inner: self.inner.call(req),
76            sleep: timeout_duration.map(tokio::time::sleep),
77        }
78    }
79}
80
81pin_project! {
82    pub struct ResponseFuture<F> {
83        #[pin]
84        inner: F,
85        #[pin]
86        sleep: Option<Sleep>,
87    }
88}
89
90impl<F, ResponseBody, E> Future for ResponseFuture<F>
91where
92    F: Future<Output = Result<Response<ResponseBody>, E>>,
93{
94    type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
95
96    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97        let this = self.project();
98
99        if let Poll::Ready(result) = this.inner.poll(cx) {
100            return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
101        }
102
103        if let Some(sleep) = this.sleep.as_pin_mut() {
104            ready!(sleep.poll(cx));
105            let mut response = Response::new(MaybeEmptyBody::empty());
106
107            response
108                .headers_mut()
109                .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
110            response
111                .headers_mut()
112                .insert(GRPC_STATUS_HEADER, GRPC_DEADLINE_EXCEEDED_CODE);
113            response
114                .headers_mut()
115                .insert(GRPC_MESSAGE_HEADER, GRPC_DEADLINE_EXCEEDED_MESSAGE);
116
117            return Poll::Ready(Ok(response));
118        }
119
120        Poll::Pending
121    }
122}
123
124pin_project! {
125    pub struct MaybeEmptyBody<B> {
126        #[pin]
127        inner: Option<B>,
128    }
129}
130
131impl<B> MaybeEmptyBody<B> {
132    fn full(inner: B) -> Self {
133        Self { inner: Some(inner) }
134    }
135
136    fn empty() -> Self {
137        Self { inner: None }
138    }
139}
140
141impl<B> http_body::Body for MaybeEmptyBody<B>
142where
143    B: http_body::Body + Send,
144{
145    type Data = B::Data;
146    type Error = B::Error;
147
148    fn poll_frame(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
152        match self.project().inner.as_pin_mut() {
153            Some(b) => b.poll_frame(cx),
154            None => Poll::Ready(None),
155        }
156    }
157
158    fn is_end_stream(&self) -> bool {
159        match &self.inner {
160            Some(b) => b.is_end_stream(),
161            None => true,
162        }
163    }
164
165    fn size_hint(&self) -> http_body::SizeHint {
166        match &self.inner {
167            Some(body) => body.size_hint(),
168            None => http_body::SizeHint::with_exact(0),
169        }
170    }
171}
172
173const SECONDS_IN_HOUR: u64 = 60 * 60;
174const SECONDS_IN_MINUTE: u64 = 60;
175
176/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
177/// the value we attempted to parse.
178///
179/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
180fn try_parse_grpc_timeout(
181    headers: &HeaderMap<HeaderValue>,
182) -> Result<Option<Duration>, &HeaderValue> {
183    let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
184        return Ok(None);
185    };
186
187    let (timeout_value, timeout_unit) = val
188        .to_str()
189        .map_err(|_| val)
190        .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
191        // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
192        // `split_at` will never panic from trying to split in the middle of a character.
193        // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
194        //
195        // `len - 1` also wont panic since we just checked `s.is_empty`.
196        .split_at(val.len() - 1);
197
198    // gRPC spec specifies `TimeoutValue` will be at most 8 digits
199    // Caping this at 8 digits also prevents integer overflow from ever occurring
200    if timeout_value.len() > 8 {
201        return Err(val);
202    }
203
204    let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
205
206    let duration = match timeout_unit {
207        // Hours
208        "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
209        // Minutes
210        "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
211        // Seconds
212        "S" => Duration::from_secs(timeout_value),
213        // Milliseconds
214        "m" => Duration::from_millis(timeout_value),
215        // Microseconds
216        "u" => Duration::from_micros(timeout_value),
217        // Nanoseconds
218        "n" => Duration::from_nanos(timeout_value),
219        _ => return Err(val),
220    };
221
222    Ok(Some(duration))
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to reduce the boilerplate of the test cases: builds a header map that
    /// optionally contains a `grpc-timeout` header with value `val`, parses it, and
    /// returns the result with the borrowed error turned into an owned value.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        }

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    /// Asserts that `raw` is rejected and that the rejected header value is handed back.
    fn assert_rejected(raw: &'static str) {
        assert_eq!(
            setup_map_try_parse(Some(raw)),
            Err(HeaderValue::from_static(raw))
        );
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        assert_rejected("82f");
    }

    #[test]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        assert_rejected("123456789H");
    }

    #[test]
    fn test_invalid_digits() {
        // TimeoutValue must consist of digits; "one" is not a number
        assert_rejected("oneH");
    }

    #[test]
    fn test_empty_value() {
        // An empty header value has neither a TimeoutValue nor a TimeoutUnit
        assert_rejected("");
    }

    #[test]
    fn test_missing_digits() {
        // A unit on its own carries no TimeoutValue
        assert_rejected("S");
    }
}