sui_network/validator/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::convert::Infallible;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use eyre::{Result, eyre};
9use mysten_network::{
10    config::Config,
11    metrics::{
12        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
13        MetricsHandler,
14    },
15    multiaddr::{Multiaddr, Protocol},
16};
17use tokio_rustls::rustls::ServerConfig;
18use tonic::codegen::http::HeaderValue;
19use tonic::{
20    body::Body,
21    codegen::http::{Request, Response},
22    server::NamedService,
23};
24use tower::{Layer, Service, ServiceBuilder, ServiceExt};
25use tower_http::propagate_header::PropagateHeaderLayer;
26use tower_http::set_header::SetRequestHeaderLayer;
27use tower_http::trace::TraceLayer;
28
/// Request timeout applied when the config does not specify `request_timeout` (five minutes).
pub const DEFAULT_GRPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(5 * 60);
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32    config: Config,
33    metrics_provider: M,
34    router: tonic::service::Routes,
35    health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40        let (health_reporter, health_service) = tonic_health::server::health_reporter();
41        let router = tonic::service::Routes::new(health_service);
42
43        Self {
44            config: config.to_owned(),
45            metrics_provider,
46            router,
47            health_reporter,
48        }
49    }
50
51    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52        self.health_reporter.clone()
53    }
54
55    /// Add a new service to this Server.
56    pub fn add_service<S>(mut self, svc: S) -> Self
57    where
58        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59            + NamedService
60            + Clone
61            + Send
62            + Sync
63            + 'static,
64        S::Future: Send + 'static,
65    {
66        self.router = self.router.add_service(svc);
67        self
68    }
69
70    pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71        let http_config = self
72            .config
73            .http_config()
74            // Temporarily continue allowing clients to connection without TLS even when the server
75            // is configured with a tls_config
76            .allow_insecure(true);
77
78        let request_timeout = self
79            .config
80            .request_timeout
81            .unwrap_or(DEFAULT_GRPC_REQUEST_TIMEOUT);
82        let metrics_provider = self.metrics_provider;
83        let metrics = MetricsHandler::new(metrics_provider.clone());
84        let request_metrics = TraceLayer::new_for_grpc()
85            .on_request(metrics.clone())
86            .on_response(metrics.clone())
87            .on_failure(metrics);
88
89        fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
90            let path = request.uri().path();
91            Some(HeaderValue::from_str(path).unwrap())
92        }
93
94        let limiting_layers = ServiceBuilder::new()
95            .option_layer(
96                self.config
97                    .load_shed
98                    .unwrap_or_default()
99                    .then_some(tower::load_shed::LoadShedLayer::new()),
100            )
101            .option_layer(
102                self.config
103                    .global_concurrency_limit
104                    .map(tower::limit::GlobalConcurrencyLimitLayer::new),
105            );
106        let route_layers = ServiceBuilder::new()
107            .map_request(|mut request: http::Request<_>| {
108                if let Some(connect_info) = request.extensions().get::<sui_http::ConnectInfo>() {
109                    let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
110                        local_addr: Some(connect_info.local_addr),
111                        remote_addr: Some(connect_info.remote_addr),
112                    };
113                    request.extensions_mut().insert(tonic_connect_info);
114                }
115                request
116            })
117            .layer(RequestLifetimeLayer { metrics_provider })
118            .layer(SetRequestHeaderLayer::overriding(
119                GRPC_ENDPOINT_PATH_HEADER.clone(),
120                add_path_to_request_header,
121            ))
122            .layer(request_metrics)
123            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
124            .layer_fn(move |service| {
125                mysten_network::grpc_timeout::GrpcTimeout::new(service, request_timeout)
126            });
127
128        let mut builder = sui_http::Builder::new().config(http_config);
129
130        if let Some(tls_config) = tls_config {
131            builder = builder.tls_config(tls_config);
132        }
133
134        let server_handle = builder
135            .serve(
136                addr,
137                limiting_layers.service(
138                    self.router
139                        .into_axum_router()
140                        .layer(route_layers)
141                        .into_service()
142                        .map_err(tower::BoxError::from),
143                ),
144            )
145            .map_err(|e| eyre!(e))?;
146
147        let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
148        Ok(Server {
149            server: server_handle,
150            local_addr,
151            health_reporter: self.health_reporter,
152        })
153    }
154}
155
/// Server name used for TLS on the public Sui validator interface.
pub const SUI_TLS_SERVER_NAME: &'static str = "sui";
158
159pub struct Server {
160    server: sui_http::ServerHandle,
161    local_addr: Multiaddr,
162    health_reporter: tonic_health::server::HealthReporter,
163}
164
165impl Server {
166    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
167        self.server.wait_for_shutdown().await;
168        Ok(())
169    }
170
171    pub fn local_addr(&self) -> &Multiaddr {
172        &self.local_addr
173    }
174
175    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
176        self.health_reporter.clone()
177    }
178
179    pub fn handle(&self) -> &sui_http::ServerHandle {
180        &self.server
181    }
182}
183
184fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
185    addr.replace(1, |protocol| {
186        if let Protocol::Tcp(_) = protocol {
187            Some(Protocol::Tcp(port))
188        } else {
189            panic!("expected tcp protocol at index 1");
190        }
191    })
192    .expect("tcp protocol at index 1")
193}
194
195#[derive(Clone)]
196struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
197    metrics_provider: M,
198}
199
200impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
201    type Service = RequestLifetime<M, S>;
202
203    fn layer(&self, inner: S) -> Self::Service {
204        RequestLifetime {
205            inner,
206            metrics_provider: self.metrics_provider.clone(),
207            path: None,
208        }
209    }
210}
211
212#[derive(Clone)]
213struct RequestLifetime<M: MetricsCallbackProvider, S> {
214    inner: S,
215    metrics_provider: M,
216    path: Option<String>,
217}
218
219impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
220    for RequestLifetime<M, S>
221where
222    S: Service<Request<RequestBody>>,
223{
224    type Response = S::Response;
225    type Error = S::Error;
226    type Future = S::Future;
227
228    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
229        self.inner.poll_ready(cx)
230    }
231
232    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
233        if self.path.is_none() {
234            let path = request.uri().path().to_string();
235            self.metrics_provider.on_start(&path);
236            self.path = Some(path);
237        }
238        self.inner.call(request)
239    }
240}
241
242impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
243    fn drop(&mut self) {
244        if let Some(path) = &self.path {
245            self.metrics_provider.on_drop(path)
246        }
247    }
248}
249
#[cfg(test)]
mod test {
    use fastcrypto::ed25519::Ed25519KeyPair;
    use fastcrypto::traits::KeyPair;
    use mysten_network::Multiaddr;
    use mysten_network::config::Config;
    use mysten_network::metrics::MetricsCallbackProvider;
    use std::ops::Deref;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tonic::Code;
    use tonic_health::pb::HealthCheckRequest;
    use tonic_health::pb::health_client::HealthClient;

    /// gRPC path the metrics callbacks are expected to observe for health checks.
    const HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// A successful health check must reach `on_response` with HTTP 200 and gRPC `Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Flag flipped to `true` once `on_response` has been called.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, HEALTH_CHECK_PATH);
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, HEALTH_CHECK_PATH);
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                *self.metrics_called.lock().unwrap() = true;
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let bind_address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server_tls =
            sui_tls::create_rustls_server_config(keypair.copy().private(), "test".to_string());
        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(&bind_address, Some(server_tls))
            .await
            .unwrap();

        // Connect to the address the server reports, since the bind address asked for port 0.
        let bound_address = server.local_addr().to_owned();
        let client_tls = sui_tls::create_rustls_client_config(
            keypair.public().to_owned(),
            "test".to_string(),
            None,
        );
        let channel = config.connect(&bound_address, client_tls).await.unwrap();
        let mut client = HealthClient::new(channel);

        let request = HealthCheckRequest {
            service: "".to_owned(),
        };
        client.check(request).await.unwrap();

        server.server.shutdown().await;

        let response_seen = *metrics.metrics_called.lock().unwrap().deref();
        assert!(response_seen);
    }

    /// A health check for an unknown service must reach `on_response` with gRPC `NotFound`.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Flag flipped to `true` once `on_response` has been called.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, HEALTH_CHECK_PATH);
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, HEALTH_CHECK_PATH);
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is not_found, which is what we expect to get in this case
                assert_eq!(grpc_status_code, Code::NotFound);
                *self.metrics_called.lock().unwrap() = true;
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let bind_address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server_tls =
            sui_tls::create_rustls_server_config(keypair.copy().private(), "test".to_string());
        let server = super::ServerBuilder::from_config(&config, metrics.clone())
            .bind(&bind_address, Some(server_tls))
            .await
            .unwrap();

        let bound_address = server.local_addr().to_owned();
        let client_tls = sui_tls::create_rustls_client_config(
            keypair.public().to_owned(),
            "test".to_string(),
            None,
        );
        let channel = config.connect(&bound_address, client_tls).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist
        // that should give us back an error with code 5 (not_found)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let request = HealthCheckRequest {
            service: "non-existing-service".to_owned(),
        };
        let _ = client.check(request).await;

        server.server.shutdown().await;

        let response_seen = *metrics.metrics_called.lock().unwrap().deref();
        assert!(response_seen);
    }

    /// Binds a server on `address`, runs one health check against it, then shuts it down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let keypair = Ed25519KeyPair::generate(&mut rand::thread_rng());

        let server_tls =
            sui_tls::create_rustls_server_config(keypair.copy().private(), "test".to_string());
        let builder = super::ServerBuilder::from_config(
            &config,
            mysten_network::metrics::DefaultMetricsCallbackProvider::default(),
        );
        let server_handle = builder.bind(&address, Some(server_tls)).await.unwrap();

        let bound_address = server_handle.local_addr().to_owned();
        let client_tls = sui_tls::create_rustls_client_config(
            keypair.public().to_owned(),
            "test".to_string(),
            None,
        );
        let channel = config.connect(&bound_address, client_tls).await.unwrap();
        let mut client = HealthClient::new(channel);

        let request = HealthCheckRequest {
            service: "".to_owned(),
        };
        client.check(request).await.unwrap();

        server_handle.server.shutdown().await;
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}