sui_http/middleware/
grpc_timeout.rs1use http::HeaderMap;
8use http::HeaderName;
9use http::HeaderValue;
10use http::Request;
11use http::Response;
12use pin_project_lite::pin_project;
13use std::future::Future;
14use std::pin::Pin;
15use std::task::Context;
16use std::task::Poll;
17use std::task::ready;
18use std::time::Duration;
19use tokio::time::Sleep;
20use tower::Service;
21
// Header names used by the gRPC-over-HTTP/2 protocol.

/// Request header carrying the client's deadline, e.g. `5S` or `100m`.
const GRPC_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("grpc-timeout");
/// Response header carrying the numeric gRPC status code.
const GRPC_STATUS_HEADER: HeaderName = HeaderName::from_static("grpc-status");
/// Response header carrying the gRPC status message.
const GRPC_MESSAGE_HEADER: HeaderName = HeaderName::from_static("grpc-message");

/// `content-type` attached to the synthesized deadline-exceeded response.
const GRPC_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/grpc");
/// gRPC status code 4 (`DEADLINE_EXCEEDED`).
const GRPC_DEADLINE_EXCEEDED_CODE: HeaderValue = HeaderValue::from_static("4");
/// "Timeout expired", percent-encoded because `grpc-message` values are percent-encoded.
const GRPC_DEADLINE_EXCEEDED_MESSAGE: HeaderValue = HeaderValue::from_static("Timeout%20expired");
29
/// Tower middleware that bounds how long a wrapped gRPC service may take.
///
/// Each request's effective deadline is the shorter of the client's
/// `grpc-timeout` header (when present and parseable) and `server_timeout`
/// (when set). If neither is available, no deadline is enforced.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
    /// The service being wrapped.
    inner: S,
    /// Optional server-side cap applied to every request.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally capping every request at `server_timeout`.
    pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
44
45impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
46where
47 S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
48{
49 type Response = Response<MaybeEmptyBody<ResponseBody>>;
50 type Error = S::Error;
51 type Future = ResponseFuture<S::Future>;
52
53 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54 self.inner.poll_ready(cx).map_err(Into::into)
55 }
56
57 fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
58 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
59 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
60 None
61 });
62
63 let timeout_duration = match (client_timeout, self.server_timeout) {
65 (None, None) => None,
66 (Some(dur), None) => Some(dur),
67 (None, Some(dur)) => Some(dur),
68 (Some(header), Some(server)) => {
69 let shorter_duration = std::cmp::min(header, server);
70 Some(shorter_duration)
71 }
72 };
73
74 ResponseFuture {
75 inner: self.inner.call(req),
76 sleep: timeout_duration.map(tokio::time::sleep),
77 }
78 }
79}
80
pin_project! {
    /// Future returned by [`GrpcTimeout`]'s `Service::call`.
    ///
    /// Resolves with the inner service's response, or with a synthesized
    /// `DEADLINE_EXCEEDED` gRPC response if the `sleep` timer completes first.
    pub struct ResponseFuture<F> {
        // The wrapped service's response future.
        #[pin]
        inner: F,
        // Deadline timer; `None` means no timeout is enforced for this request.
        #[pin]
        sleep: Option<Sleep>,
    }
}
89
90impl<F, ResponseBody, E> Future for ResponseFuture<F>
91where
92 F: Future<Output = Result<Response<ResponseBody>, E>>,
93{
94 type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
95
96 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 let this = self.project();
98
99 if let Poll::Ready(result) = this.inner.poll(cx) {
100 return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
101 }
102
103 if let Some(sleep) = this.sleep.as_pin_mut() {
104 ready!(sleep.poll(cx));
105 let mut response = Response::new(MaybeEmptyBody::empty());
106
107 response
108 .headers_mut()
109 .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
110 response
111 .headers_mut()
112 .insert(GRPC_STATUS_HEADER, GRPC_DEADLINE_EXCEEDED_CODE);
113 response
114 .headers_mut()
115 .insert(GRPC_MESSAGE_HEADER, GRPC_DEADLINE_EXCEEDED_MESSAGE);
116
117 return Poll::Ready(Ok(response));
118 }
119
120 Poll::Pending
121 }
122}
123
pin_project! {
    /// Response body that is either the inner service's body or empty.
    ///
    /// The empty form is used for the synthesized deadline-exceeded response,
    /// whose gRPC status is conveyed entirely through headers.
    pub struct MaybeEmptyBody<B> {
        // `Some` wraps a real body; `None` represents an empty body.
        #[pin]
        inner: Option<B>,
    }
}
130
131impl<B> MaybeEmptyBody<B> {
132 fn full(inner: B) -> Self {
133 Self { inner: Some(inner) }
134 }
135
136 fn empty() -> Self {
137 Self { inner: None }
138 }
139}
140
141impl<B> http_body::Body for MaybeEmptyBody<B>
142where
143 B: http_body::Body + Send,
144{
145 type Data = B::Data;
146 type Error = B::Error;
147
148 fn poll_frame(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
152 match self.project().inner.as_pin_mut() {
153 Some(b) => b.poll_frame(cx),
154 None => Poll::Ready(None),
155 }
156 }
157
158 fn is_end_stream(&self) -> bool {
159 match &self.inner {
160 Some(b) => b.is_end_stream(),
161 None => true,
162 }
163 }
164
165 fn size_hint(&self) -> http_body::SizeHint {
166 match &self.inner {
167 Some(body) => body.size_hint(),
168 None => http_body::SizeHint::with_exact(0),
169 }
170 }
171}
172
/// Seconds per minute; scales the `M` unit of a `grpc-timeout` value.
const SECONDS_IN_MINUTE: u64 = 60;
/// Seconds per hour; scales the `H` unit of a `grpc-timeout` value.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
175
176fn try_parse_grpc_timeout(
181 headers: &HeaderMap<HeaderValue>,
182) -> Result<Option<Duration>, &HeaderValue> {
183 let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
184 return Ok(None);
185 };
186
187 let (timeout_value, timeout_unit) = val
188 .to_str()
189 .map_err(|_| val)
190 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
191 .split_at(val.len() - 1);
197
198 if timeout_value.len() > 8 {
201 return Err(val);
202 }
203
204 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
205
206 let duration = match timeout_unit {
207 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
209 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
211 "S" => Duration::from_secs(timeout_value),
213 "m" => Duration::from_millis(timeout_value),
215 "u" => Duration::from_micros(timeout_value),
217 "n" => Duration::from_nanos(timeout_value),
219 _ => return Err(val),
220 };
221
222 Ok(Some(duration))
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map with an optional `grpc-timeout` value and parses it.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        }

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    // The spec allows up to 8 digits; the largest value must parse and must
    // not overflow when scaled to hours.
    #[test]
    fn test_max_digits() {
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_empty_value() {
        let err = setup_map_try_parse(Some("")).unwrap_err();
        assert_eq!(err, HeaderValue::from_static(""));
    }

    #[test]
    fn test_missing_digits() {
        let err = setup_map_try_parse(Some("S")).unwrap_err();
        assert_eq!(err, HeaderValue::from_static("S"));
    }

    #[test]
    fn test_missing_unit() {
        let err = setup_map_try_parse(Some("5")).unwrap_err();
        assert_eq!(err, HeaderValue::from_static("5"));
    }
}