sui_network/validator/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::convert::Infallible;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use eyre::{Result, eyre};
9use mysten_network::{
10    config::Config,
11    metrics::{
12        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
13        MetricsHandler,
14    },
15    multiaddr::{Multiaddr, Protocol},
16};
17use tokio_rustls::rustls::ServerConfig;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20    body::Body,
21    codegen::http::{Request, Response},
22    server::NamedService,
23};
24use tower::{Layer, Service, ServiceBuilder, ServiceExt};
25use tower_http::propagate_header::PropagateHeaderLayer;
26use tower_http::set_header::SetRequestHeaderLayer;
27use tower_http::trace::TraceLayer;
28
/// Request timeout (300 seconds) used by `ServerBuilder::bind` when the config's
/// `request_timeout` is not set.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
30
/// Builder for a gRPC [`Server`].
///
/// Collects the network configuration, a metrics callback provider and the routed
/// services; `bind` then wraps the routes in the middleware stack and starts serving.
pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
    /// Owned copy of the configuration passed to `from_config`.
    config: Config,
    /// Receives the per-request metrics callbacks installed by `bind`.
    metrics_provider: M,
    /// Routes for every registered service; starts with only the health service.
    router: tonic::service::Routes,
    /// Reporter handle returned by `tonic_health::server::health_reporter()` together
    /// with the health service.
    health_reporter: tonic_health::server::HealthReporter,
}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40        let (health_reporter, health_service) = tonic_health::server::health_reporter();
41        let router = tonic::service::Routes::new(health_service);
42
43        Self {
44            config: config.to_owned(),
45            metrics_provider,
46            router,
47            health_reporter,
48        }
49    }
50
51    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52        self.health_reporter.clone()
53    }
54
55    /// Add a new service to this Server.
56    pub fn add_service<S>(mut self, svc: S) -> Self
57    where
58        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59            + NamedService
60            + Clone
61            + Send
62            + Sync
63            + 'static,
64        S::Future: Send + 'static,
65    {
66        self.router = self.router.add_service(svc);
67        self
68    }
69
70    pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71        let request_timeout = self
72            .config
73            .request_timeout
74            .unwrap_or(DEFAULT_GRPC_REQUEST_TIMEOUT);
75        let metrics_provider = self.metrics_provider;
76        let metrics = MetricsHandler::new(metrics_provider.clone());
77        let request_metrics = TraceLayer::new_for_grpc()
78            .on_request(metrics.clone())
79            .on_response(metrics.clone())
80            .on_failure(metrics);
81
82        fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
83            let path = request.uri().path();
84            Some(HeaderValue::from_str(path).unwrap())
85        }
86
87        let limiting_layers = ServiceBuilder::new()
88            .option_layer(
89                self.config
90                    .load_shed
91                    .unwrap_or_default()
92                    .then_some(tower::load_shed::LoadShedLayer::new()),
93            )
94            .option_layer(
95                self.config
96                    .global_concurrency_limit
97                    .map(tower::limit::GlobalConcurrencyLimitLayer::new),
98            );
99        let route_layers = ServiceBuilder::new()
100            .map_request(|mut request: http::Request<_>| {
101                if let Some(connect_info) = request.extensions().get::<sui_http::ConnectInfo>() {
102                    let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
103                        local_addr: Some(connect_info.local_addr),
104                        remote_addr: Some(connect_info.remote_addr),
105                    };
106                    request.extensions_mut().insert(tonic_connect_info);
107                }
108                request
109            })
110            .layer(RequestLifetimeLayer { metrics_provider })
111            .layer(SetRequestHeaderLayer::overriding(
112                GRPC_ENDPOINT_PATH_HEADER.clone(),
113                add_path_to_request_header,
114            ))
115            .layer(request_metrics)
116            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
117            .layer_fn(move |service| {
118                mysten_network::grpc_timeout::GrpcTimeout::new(service, request_timeout)
119            });
120
121        let mut builder = sui_http::Builder::new().config(self.config.http_config());
122
123        if let Some(tls_config) = tls_config {
124            builder = builder.tls_config(tls_config);
125        }
126
127        let server_handle = builder
128            .serve(
129                addr,
130                limiting_layers.service(
131                    self.router
132                        .into_axum_router()
133                        .layer(route_layers)
134                        .into_service()
135                        .map_err(tower::BoxError::from),
136                ),
137            )
138            .map_err(|e| eyre!(e))?;
139
140        let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
141        Ok(Server {
142            server: server_handle,
143            local_addr,
144            health_reporter: self.health_reporter,
145        })
146    }
147}
148
/// TLS server name to use for the public Sui validator interface.
///
/// NOTE(review): not referenced in the visible portion of this file (the tests here
/// pass `"test"` as the server name) — confirm where this constant is consumed.
pub const SUI_TLS_SERVER_NAME: &str = "sui";
151
/// A running server produced by `ServerBuilder::bind`.
pub struct Server {
    /// Handle returned by `sui_http::Builder::serve`.
    server: sui_http::ServerHandle,
    /// The address passed to `bind`, with its TCP port replaced by the bound port.
    local_addr: Multiaddr,
    /// Health reporter carried over from the builder.
    health_reporter: tonic_health::server::HealthReporter,
}
157
impl Server {
    /// Awaits `wait_for_shutdown` on the underlying server handle, then returns `Ok(())`.
    ///
    /// There is no error path in this body; the `Result` return type is kept for callers.
    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
        self.server.wait_for_shutdown().await;
        Ok(())
    }

    /// Returns the address recorded at bind time (with the actual bound TCP port).
    pub fn local_addr(&self) -> &Multiaddr {
        &self.local_addr
    }

    /// Returns a clone of the health reporter carried over from the builder.
    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
        self.health_reporter.clone()
    }

    /// Borrows the underlying `sui_http` server handle.
    pub fn handle(&self) -> &sui_http::ServerHandle {
        &self.server
    }
}
176
177fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
178    addr.replace(1, |protocol| {
179        if let Protocol::Tcp(_) = protocol {
180            Some(Protocol::Tcp(port))
181        } else {
182            panic!("expected tcp protocol at index 1");
183        }
184    })
185    .expect("tcp protocol at index 1")
186}
187
/// Layer that wraps a service in [`RequestLifetime`], handing each wrapper a clone of
/// the metrics provider.
#[derive(Clone)]
struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
    /// Provider cloned into every `RequestLifetime` this layer creates.
    metrics_provider: M,
}
192
193impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
194    type Service = RequestLifetime<M, S>;
195
196    fn layer(&self, inner: S) -> Self::Service {
197        RequestLifetime {
198            inner,
199            metrics_provider: self.metrics_provider.clone(),
200            path: None,
201        }
202    }
203}
204
/// Service wrapper that reports `on_start` for the first request it is called with and
/// `on_drop` (with the same path) when the wrapper itself is dropped.
///
/// NOTE(review): `Clone` is derived and copies `path`, so cloning an instance that has
/// already handled a call would report `on_drop` once per copy — confirm this matches
/// how the surrounding stack clones services.
#[derive(Clone)]
struct RequestLifetime<M: MetricsCallbackProvider, S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Receives the `on_start` / `on_drop` callbacks.
    metrics_provider: M,
    /// Path of the first request seen by this instance; `None` until `call` runs.
    path: Option<String>,
}
211
212impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
213    for RequestLifetime<M, S>
214where
215    S: Service<Request<RequestBody>>,
216{
217    type Response = S::Response;
218    type Error = S::Error;
219    type Future = S::Future;
220
221    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
222        self.inner.poll_ready(cx)
223    }
224
225    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
226        if self.path.is_none() {
227            let path = request.uri().path().to_string();
228            self.metrics_provider.on_start(&path);
229            self.path = Some(path);
230        }
231        self.inner.call(request)
232    }
233}
234
235impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
236    fn drop(&mut self) {
237        if let Some(path) = &self.path {
238            self.metrics_provider.on_drop(path)
239        }
240    }
241}
242
#[cfg(test)]
mod test {
    use fastcrypto::ed25519::Ed25519KeyPair;
    use fastcrypto::traits::KeyPair;
    use mysten_network::Multiaddr;
    use mysten_network::config::Config;
    use mysten_network::metrics::MetricsCallbackProvider;
    use std::ops::Deref;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::pb::HealthCheckRequest;
    use tonic_health::pb::health_client::HealthClient;

    /// A successful health check must reach the metrics callbacks with the health
    /// endpoint path, HTTP status 200 and gRPC code `Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` once `on_response` has been invoked; shared across
            /// clones of the provider so the test can observe it.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0: the bound address is read back from the server below.
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(
                &address,
                Some(sui_tls::create_rustls_server_config(
                    keypair.copy().private(),
                    "test".to_string(),
                )),
            )
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// A health check for an unknown service must still reach the metrics callbacks,
    /// with HTTP status 200 and gRPC code `NotFound`.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` once `on_response` has been invoked; shared across
            /// clones of the provider so the test can observe it.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is not_found , which is what we expect to get in this case
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0: the bound address is read back from the server below.
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(
                &address,
                Some(sui_tls::create_rustls_server_config(
                    keypair.copy().private(),
                    "test".to_string(),
                )),
            )
            .await
            .unwrap();
        let address = server.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist
        // that should give us back an error with code 5 (not_found)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        server.server.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a TLS server on `address`, connects to the address the server reports,
    /// issues one health check and shuts the server down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server_handle = super::ServerBuilder::from_config(
            &config,
            mysten_network::metrics::DefaultMetricsCallbackProvider::default(),
        )
        .bind(
            &address,
            Some(sui_tls::create_rustls_server_config(
                keypair.copy().private(),
                "test".to_string(),
            )),
        )
        .await
        .unwrap();
        let address = server_handle.local_addr().to_owned();
        let channel = config
            .connect(
                &address,
                sui_tls::create_rustls_client_config(
                    keypair.public().to_owned(),
                    "test".to_string(),
                    None,
                ),
            )
            .await
            .unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server_handle.server.shutdown().await;
    }

    /// Round trip over a `/dns/...` multiaddr.
    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }

    /// Round trip over an `/ip4/...` multiaddr.
    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }

    /// Round trip over an `/ip6/...` multiaddr.
    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/https".parse().unwrap();
        test_multiaddr(address).await;
    }
}