1use std::convert::Infallible;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use eyre::{Result, eyre};
9use mysten_network::{
10 config::Config,
11 metrics::{
12 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
13 MetricsHandler,
14 },
15 multiaddr::{Multiaddr, Protocol},
16};
17use tokio_rustls::rustls::ServerConfig;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20 body::Body,
21 codegen::http::{Request, Response},
22 server::NamedService,
23};
24use tower::{Layer, Service, ServiceBuilder, ServiceExt};
25use tower_http::propagate_header::PropagateHeaderLayer;
26use tower_http::set_header::SetRequestHeaderLayer;
27use tower_http::trace::TraceLayer;
28
/// Request timeout used by [`ServerBuilder::bind`] when the server config does
/// not set `request_timeout` (five minutes).
pub const DEFAULT_GRPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(5 * 60);
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32 config: Config,
33 metrics_provider: M,
34 router: tonic::service::Routes,
35 health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40 let (health_reporter, health_service) = tonic_health::server::health_reporter();
41 let router = tonic::service::Routes::new(health_service);
42
43 Self {
44 config: config.to_owned(),
45 metrics_provider,
46 router,
47 health_reporter,
48 }
49 }
50
51 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52 self.health_reporter.clone()
53 }
54
55 pub fn add_service<S>(mut self, svc: S) -> Self
57 where
58 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59 + NamedService
60 + Clone
61 + Send
62 + Sync
63 + 'static,
64 S::Future: Send + 'static,
65 {
66 self.router = self.router.add_service(svc);
67 self
68 }
69
70 pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71 let request_timeout = self
72 .config
73 .request_timeout
74 .unwrap_or(DEFAULT_GRPC_REQUEST_TIMEOUT);
75 let metrics_provider = self.metrics_provider;
76 let metrics = MetricsHandler::new(metrics_provider.clone());
77 let request_metrics = TraceLayer::new_for_grpc()
78 .on_request(metrics.clone())
79 .on_response(metrics.clone())
80 .on_failure(metrics);
81
82 fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
83 let path = request.uri().path();
84 Some(HeaderValue::from_str(path).unwrap())
85 }
86
87 let limiting_layers = ServiceBuilder::new()
88 .option_layer(
89 self.config
90 .load_shed
91 .unwrap_or_default()
92 .then_some(tower::load_shed::LoadShedLayer::new()),
93 )
94 .option_layer(
95 self.config
96 .global_concurrency_limit
97 .map(tower::limit::GlobalConcurrencyLimitLayer::new),
98 );
99 let route_layers = ServiceBuilder::new()
100 .map_request(|mut request: http::Request<_>| {
101 if let Some(connect_info) = request.extensions().get::<sui_http::ConnectInfo>() {
102 let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
103 local_addr: Some(connect_info.local_addr),
104 remote_addr: Some(connect_info.remote_addr),
105 };
106 request.extensions_mut().insert(tonic_connect_info);
107 }
108 request
109 })
110 .layer(RequestLifetimeLayer { metrics_provider })
111 .layer(SetRequestHeaderLayer::overriding(
112 GRPC_ENDPOINT_PATH_HEADER.clone(),
113 add_path_to_request_header,
114 ))
115 .layer(request_metrics)
116 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
117 .layer_fn(move |service| {
118 mysten_network::grpc_timeout::GrpcTimeout::new(service, request_timeout)
119 });
120
121 let mut builder = sui_http::Builder::new().config(self.config.http_config());
122
123 if let Some(tls_config) = tls_config {
124 builder = builder.tls_config(tls_config);
125 }
126
127 let server_handle = builder
128 .serve(
129 addr,
130 limiting_layers.service(
131 self.router
132 .into_axum_router()
133 .layer(route_layers)
134 .into_service()
135 .map_err(tower::BoxError::from),
136 ),
137 )
138 .map_err(|e| eyre!(e))?;
139
140 let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
141 Ok(Server {
142 server: server_handle,
143 local_addr,
144 health_reporter: self.health_reporter,
145 })
146 }
147}
148
/// TLS server name for Sui gRPC endpoints.
///
/// NOTE(review): not referenced in this part of the file; presumably passed to
/// the rustls config helpers elsewhere — confirm.
pub const SUI_TLS_SERVER_NAME: &str = "sui";
151
152pub struct Server {
153 server: sui_http::ServerHandle,
154 local_addr: Multiaddr,
155 health_reporter: tonic_health::server::HealthReporter,
156}
157
158impl Server {
159 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
160 self.server.wait_for_shutdown().await;
161 Ok(())
162 }
163
164 pub fn local_addr(&self) -> &Multiaddr {
165 &self.local_addr
166 }
167
168 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
169 self.health_reporter.clone()
170 }
171
172 pub fn handle(&self) -> &sui_http::ServerHandle {
173 &self.server
174 }
175}
176
177fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
178 addr.replace(1, |protocol| {
179 if let Protocol::Tcp(_) = protocol {
180 Some(Protocol::Tcp(port))
181 } else {
182 panic!("expected tcp protocol at index 1");
183 }
184 })
185 .expect("tcp protocol at index 1")
186}
187
188#[derive(Clone)]
189struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
190 metrics_provider: M,
191}
192
193impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
194 type Service = RequestLifetime<M, S>;
195
196 fn layer(&self, inner: S) -> Self::Service {
197 RequestLifetime {
198 inner,
199 metrics_provider: self.metrics_provider.clone(),
200 path: None,
201 }
202 }
203}
204
205#[derive(Clone)]
206struct RequestLifetime<M: MetricsCallbackProvider, S> {
207 inner: S,
208 metrics_provider: M,
209 path: Option<String>,
210}
211
212impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
213 for RequestLifetime<M, S>
214where
215 S: Service<Request<RequestBody>>,
216{
217 type Response = S::Response;
218 type Error = S::Error;
219 type Future = S::Future;
220
221 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
222 self.inner.poll_ready(cx)
223 }
224
225 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
226 if self.path.is_none() {
227 let path = request.uri().path().to_string();
228 self.metrics_provider.on_start(&path);
229 self.path = Some(path);
230 }
231 self.inner.call(request)
232 }
233}
234
235impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
236 fn drop(&mut self) {
237 if let Some(path) = &self.path {
238 self.metrics_provider.on_drop(path)
239 }
240 }
241}
242
#[cfg(test)]
mod test {
    // End-to-end tests: each test binds a TLS server on an ephemeral port
    // (tcp/0), connects a health-check client to it, and checks the
    // outcome (metrics callbacks, or just a successful round trip).
    use fastcrypto::ed25519::Ed25519KeyPair;
    use fastcrypto::traits::KeyPair;
    use mysten_network::Multiaddr;
    use mysten_network::config::Config;
    use mysten_network::metrics::MetricsCallbackProvider;
    use std::ops::Deref;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::pb::HealthCheckRequest;
    use tonic_health::pb::health_client::HealthClient;

    // A successful health check must reach the metrics callbacks with the
    // health-check path, HTTP 200 and gRPC `Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        // Records whether `on_response` ran; shared with the server through
        // the cloned provider.
        #[derive(Clone)]
        struct Metrics {
            metrics_called: Arc<Mutex<bool>>,
        }

        // These assertions run inside the server's metrics layers rather than
        // the test body; the test body checks only the `metrics_called` flag.
        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0: the OS picks a free port; `local_addr()` reports it below.
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(
                &address,
                Some(sui_tls::create_rustls_server_config(
                    keypair.copy().private(),
                    "test".to_string(),
                )),
            )
            .await
            .unwrap();

        // The client is given the server's public key and the same name
        // ("test"), presumably so it can verify the server's certificate.
        let address = server.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        // An empty service name queries the overall server health.
        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    // A health check for an unknown service fails with gRPC `NotFound`; the
    // metrics layer must still see it, with HTTP status 200.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                // gRPC errors still use HTTP 200; the failure is carried in
                // the gRPC status.
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(
                &address,
                Some(sui_tls::create_rustls_server_config(
                    keypair.copy().private(),
                    "test".to_string(),
                )),
            )
            .await
            .unwrap();
        let address = server.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        // The call is expected to fail (unknown service), so its result is
        // deliberately ignored; only the metrics callbacks are checked.
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        server.server.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    // Shared helper: bind a TLS server at `address` (port 0), perform one
    // successful health check against it, then shut it down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server_handle = super::ServerBuilder::from_config(
            &config,
            mysten_network::metrics::DefaultMetricsCallbackProvider::default(),
        )
        .bind(
            &address,
            Some(sui_tls::create_rustls_server_config(
                keypair.copy().private(),
                "test".to_string(),
            )),
        )
        .await
        .unwrap();
        let address = server_handle.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server_handle.server.shutdown().await;
    }

    // Binding via a DNS name.
    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }

    // Binding via an IPv4 loopback address.
    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }

    // Binding via an IPv6 loopback address.
    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }
}