1use std::convert::Infallible;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use eyre::{Result, eyre};
9use mysten_network::{
10 config::Config,
11 metrics::{
12 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
13 MetricsHandler,
14 },
15 multiaddr::{Multiaddr, Protocol},
16};
17use tokio_rustls::rustls::ServerConfig;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20 body::Body,
21 codegen::http::{Request, Response},
22 server::NamedService,
23};
24use tower::{Layer, Service, ServiceBuilder, ServiceExt};
25use tower_http::propagate_header::PropagateHeaderLayer;
26use tower_http::set_header::SetRequestHeaderLayer;
27use tower_http::trace::TraceLayer;
28
/// Fallback gRPC request timeout (five minutes).
///
/// `ServerBuilder::bind` hands this to the `GrpcTimeout` layer whenever
/// `Config::request_timeout` is `None`.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(5 * 60);
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32 config: Config,
33 metrics_provider: M,
34 router: tonic::service::Routes,
35 health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40 let (health_reporter, health_service) = tonic_health::server::health_reporter();
41 let router = tonic::service::Routes::new(health_service);
42
43 Self {
44 config: config.to_owned(),
45 metrics_provider,
46 router,
47 health_reporter,
48 }
49 }
50
51 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52 self.health_reporter.clone()
53 }
54
55 pub fn add_service<S>(mut self, svc: S) -> Self
57 where
58 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59 + NamedService
60 + Clone
61 + Send
62 + Sync
63 + 'static,
64 S::Future: Send + 'static,
65 {
66 self.router = self.router.add_service(svc);
67 self
68 }
69
70 pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71 let http_config = self
72 .config
73 .http_config()
74 .allow_insecure(true);
77
78 let request_timeout = self
79 .config
80 .request_timeout
81 .unwrap_or(DEFAULT_GRPC_REQUEST_TIMEOUT);
82 let metrics_provider = self.metrics_provider;
83 let metrics = MetricsHandler::new(metrics_provider.clone());
84 let request_metrics = TraceLayer::new_for_grpc()
85 .on_request(metrics.clone())
86 .on_response(metrics.clone())
87 .on_failure(metrics);
88
89 fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
90 let path = request.uri().path();
91 Some(HeaderValue::from_str(path).unwrap())
92 }
93
94 let limiting_layers = ServiceBuilder::new()
95 .option_layer(
96 self.config
97 .load_shed
98 .unwrap_or_default()
99 .then_some(tower::load_shed::LoadShedLayer::new()),
100 )
101 .option_layer(
102 self.config
103 .global_concurrency_limit
104 .map(tower::limit::GlobalConcurrencyLimitLayer::new),
105 );
106 let route_layers = ServiceBuilder::new()
107 .map_request(|mut request: http::Request<_>| {
108 if let Some(connect_info) = request.extensions().get::<sui_http::ConnectInfo>() {
109 let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
110 local_addr: Some(connect_info.local_addr),
111 remote_addr: Some(connect_info.remote_addr),
112 };
113 request.extensions_mut().insert(tonic_connect_info);
114 }
115 request
116 })
117 .layer(RequestLifetimeLayer { metrics_provider })
118 .layer(SetRequestHeaderLayer::overriding(
119 GRPC_ENDPOINT_PATH_HEADER.clone(),
120 add_path_to_request_header,
121 ))
122 .layer(request_metrics)
123 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
124 .layer_fn(move |service| {
125 mysten_network::grpc_timeout::GrpcTimeout::new(service, request_timeout)
126 });
127
128 let mut builder = sui_http::Builder::new().config(http_config);
129
130 if let Some(tls_config) = tls_config {
131 builder = builder.tls_config(tls_config);
132 }
133
134 let server_handle = builder
135 .serve(
136 addr,
137 limiting_layers.service(
138 self.router
139 .into_axum_router()
140 .layer(route_layers)
141 .into_service()
142 .map_err(tower::BoxError::from),
143 ),
144 )
145 .map_err(|e| eyre!(e))?;
146
147 let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
148 Ok(Server {
149 server: server_handle,
150 local_addr,
151 health_reporter: self.health_reporter,
152 })
153 }
154}
155
/// TLS server name string exported by this module.
///
/// NOTE(review): not referenced in this file; presumably callers pass it as the server name
/// when building `sui_tls` client/server configs -- confirm at the call sites.
pub const SUI_TLS_SERVER_NAME: &str = "sui";
158
159pub struct Server {
160 server: sui_http::ServerHandle,
161 local_addr: Multiaddr,
162 health_reporter: tonic_health::server::HealthReporter,
163}
164
165impl Server {
166 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
167 self.server.wait_for_shutdown().await;
168 Ok(())
169 }
170
171 pub fn local_addr(&self) -> &Multiaddr {
172 &self.local_addr
173 }
174
175 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
176 self.health_reporter.clone()
177 }
178
179 pub fn handle(&self) -> &sui_http::ServerHandle {
180 &self.server
181 }
182}
183
184fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
185 addr.replace(1, |protocol| {
186 if let Protocol::Tcp(_) = protocol {
187 Some(Protocol::Tcp(port))
188 } else {
189 panic!("expected tcp protocol at index 1");
190 }
191 })
192 .expect("tcp protocol at index 1")
193}
194
195#[derive(Clone)]
196struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
197 metrics_provider: M,
198}
199
200impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
201 type Service = RequestLifetime<M, S>;
202
203 fn layer(&self, inner: S) -> Self::Service {
204 RequestLifetime {
205 inner,
206 metrics_provider: self.metrics_provider.clone(),
207 path: None,
208 }
209 }
210}
211
212#[derive(Clone)]
213struct RequestLifetime<M: MetricsCallbackProvider, S> {
214 inner: S,
215 metrics_provider: M,
216 path: Option<String>,
217}
218
219impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
220 for RequestLifetime<M, S>
221where
222 S: Service<Request<RequestBody>>,
223{
224 type Response = S::Response;
225 type Error = S::Error;
226 type Future = S::Future;
227
228 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
229 self.inner.poll_ready(cx)
230 }
231
232 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
233 if self.path.is_none() {
234 let path = request.uri().path().to_string();
235 self.metrics_provider.on_start(&path);
236 self.path = Some(path);
237 }
238 self.inner.call(request)
239 }
240}
241
242impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
243 fn drop(&mut self) {
244 if let Some(path) = &self.path {
245 self.metrics_provider.on_drop(path)
246 }
247 }
248}
249
#[cfg(test)]
mod test {
    use fastcrypto::ed25519::Ed25519KeyPair;
    use fastcrypto::traits::KeyPair;
    use mysten_network::Multiaddr;
    use mysten_network::config::Config;
    use mysten_network::metrics::MetricsCallbackProvider;
    use std::ops::Deref;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::pb::HealthCheckRequest;
    use tonic_health::pb::health_client::HealthClient;

    /// Runs one health check against a freshly started server.
    ///
    /// Steps:
    /// 1. Binds a TLS-enabled server on `address` with `metrics` installed.
    /// 2. Connects a client to the server's resolved address.
    /// 3. Sends one `Health/Check` request for `service`.
    /// 4. Shuts the server down and returns the check's outcome.
    ///
    /// Bind and connect failures panic; the health-check result is left to the caller.
    async fn run_health_check<M: MetricsCallbackProvider>(
        metrics: M,
        address: Multiaddr,
        service: &str,
    ) -> Result<(), tonic::Status> {
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server = super::ServerBuilder::from_config(&config, metrics)
            .bind(
                &address,
                Some(sui_tls::create_rustls_server_config(
                    keypair.copy().private(),
                    "test".to_string(),
                )),
            )
            .await
            .unwrap();

        let bound_address = server.local_addr().to_owned();
        let channel = config
            .connect(
                &bound_address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        let outcome = client
            .check(HealthCheckRequest {
                service: service.to_owned(),
            })
            .await
            .map(|_| ());

        server.server.shutdown().await;
        outcome
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once `on_response` has observed the expected successful health check.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();

        run_health_check(metrics.clone(), address, "").await.unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once `on_response` has observed the expected `NotFound` health check.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                // gRPC reports errors in the grpc-status trailer; the HTTP status stays 200.
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();

        // The check is expected to fail; only the metrics callbacks are asserted on.
        let _ = run_health_check(metrics.clone(), address, "non-existing-service").await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Verifies a server bound on `address` answers a plain health check.
    async fn test_multiaddr(address: Multiaddr) {
        run_health_check(
            mysten_network::metrics::DefaultMetricsCallbackProvider::default(),
            address,
            "",
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}