mysten_network/
grpc_timeout.rs1use http::{HeaderMap, HeaderValue, Request, Response};
8use pin_project_lite::pin_project;
9use std::{
10 future::Future,
11 pin::Pin,
12 task::{Context, Poll, ready},
13 time::Duration,
14};
15use tokio::time::Sleep;
16use tonic::Status;
17use tower::Service;
18
/// Name of the request header from which the client-supplied timeout is parsed.
const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
20
/// Middleware that puts a deadline on every request handled by the wrapped service.
///
/// The effective deadline of a request is the smaller of `server_timeout` and the timeout the
/// client sent in the `grpc-timeout` request header (when that header is present and valid).
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Upper bound on every request's deadline; a client can shorten it but never extend it.
    server_timeout: Duration,
}
28
29impl<S> GrpcTimeout<S> {
30 pub fn new(inner: S, server_timeout: Duration) -> Self {
31 Self {
32 inner,
33 server_timeout,
34 }
35 }
36}
37
38impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
39where
40 S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
41{
42 type Response = Response<MaybeEmptyBody<ResponseBody>>;
43 type Error = S::Error;
44 type Future = ResponseFuture<S::Future>;
45
46 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47 self.inner.poll_ready(cx).map_err(Into::into)
48 }
49
50 fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
51 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
52 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
53 None
54 });
55
56 let resp_timeout = match client_timeout {
58 None => self.server_timeout,
59 Some(d) => self.server_timeout.min(d),
60 };
61
62 ResponseFuture {
63 inner: self.inner.call(req),
64 sleep: tokio::time::sleep(resp_timeout),
65 }
66 }
67}
68
69pin_project! {
70 pub struct ResponseFuture<F> {
71 #[pin]
72 inner: F,
73 #[pin]
74 sleep: Sleep,
75 }
76}
77
78impl<F, ResponseBody, E> Future for ResponseFuture<F>
79where
80 F: Future<Output = Result<Response<ResponseBody>, E>>,
81{
82 type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let this = self.project();
86
87 if let Poll::Ready(result) = this.inner.poll(cx) {
88 return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
89 }
90
91 ready!(this.sleep.poll(cx));
92 let response = Status::deadline_exceeded("Timeout expired")
93 .into_http()
94 .map(|()| MaybeEmptyBody::empty());
95 Poll::Ready(Ok(response))
96 }
97}
98
99pin_project! {
100 pub struct MaybeEmptyBody<B> {
101 #[pin]
102 inner: Option<B>,
103 }
104}
105
106impl<B> MaybeEmptyBody<B> {
107 fn full(inner: B) -> Self {
108 Self { inner: Some(inner) }
109 }
110
111 fn empty() -> Self {
112 Self { inner: None }
113 }
114}
115
116impl<B> http_body::Body for MaybeEmptyBody<B>
117where
118 B: http_body::Body + Send,
119{
120 type Data = B::Data;
121 type Error = B::Error;
122
123 fn poll_frame(
124 self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
127 match self.project().inner.as_pin_mut() {
128 Some(b) => b.poll_frame(cx),
129 None => Poll::Ready(None),
130 }
131 }
132
133 fn is_end_stream(&self) -> bool {
134 match &self.inner {
135 Some(b) => b.is_end_stream(),
136 None => true,
137 }
138 }
139
140 fn size_hint(&self) -> http_body::SizeHint {
141 match &self.inner {
142 Some(body) => body.size_hint(),
143 None => http_body::SizeHint::with_exact(0),
144 }
145 }
146}
147
/// Seconds per minute, used for the `M` unit of `grpc-timeout`.
const SECONDS_IN_MINUTE: u64 = 60;
/// Seconds per hour, used for the `H` unit of `grpc-timeout`.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
150
151fn try_parse_grpc_timeout(
156 headers: &HeaderMap<HeaderValue>,
157) -> Result<Option<Duration>, &HeaderValue> {
158 let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
159 return Ok(None);
160 };
161
162 let (timeout_value, timeout_unit) = val
163 .to_str()
164 .map_err(|_| val)
165 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
166 .split_at(val.len() - 1);
172
173 if timeout_value.len() > 8 {
176 return Err(val);
177 }
178
179 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
180
181 let duration = match timeout_unit {
182 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
184 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
186 "S" => Duration::from_secs(timeout_value),
188 "m" => Duration::from_millis(timeout_value),
190 "u" => Duration::from_micros(timeout_value),
192 "n" => Duration::from_nanos(timeout_value),
194 _ => return Err(val),
195 };
196
197 Ok(Some(duration))
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `try_parse_grpc_timeout` over a header map containing `grpc-timeout: <raw>`. When
    /// `raw` is `None`, the map has no such header. Any error is cloned so it outlives the map.
    fn setup_map_try_parse(raw: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut headers = HeaderMap::new();
        if let Some(text) = raw {
            let value = HeaderValue::from_str(text).unwrap();
            headers.insert(GRPC_TIMEOUT_HEADER, value);
        }
        try_parse_grpc_timeout(&headers).map_err(HeaderValue::clone)
    }

    /// Parses `raw` and unwraps it to a timeout. On a parse error, the panic message contains
    /// the offending header value.
    fn parse_present(raw: &str) -> Duration {
        setup_map_try_parse(Some(raw)).unwrap().unwrap()
    }

    #[test]
    fn test_hours() {
        assert_eq!(parse_present("3H"), Duration::from_secs(3 * 60 * 60));
    }

    #[test]
    fn test_minutes() {
        assert_eq!(parse_present("1M"), Duration::from_secs(60));
    }

    #[test]
    fn test_seconds() {
        assert_eq!(parse_present("42S"), Duration::from_secs(42));
    }

    #[test]
    fn test_milliseconds() {
        assert_eq!(parse_present("13m"), Duration::from_millis(13));
    }

    #[test]
    fn test_microseconds() {
        assert_eq!(parse_present("2u"), Duration::from_micros(2));
    }

    #[test]
    fn test_nanoseconds() {
        assert_eq!(parse_present("82n"), Duration::from_nanos(82));
    }

    #[test]
    fn test_header_not_present() {
        let outcome = setup_map_try_parse(None).unwrap();
        assert_eq!(outcome, None);
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        parse_present("82f");
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        parse_present("123456789H");
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        parse_present("oneH");
    }
}