sui_graphql_rpc/server/
version.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
3
4use axum::{
5    body::Body,
6    extract::State,
7    http::{HeaderName, HeaderValue, Request},
8    middleware::Next,
9    response::Response,
10};
11
12use crate::config::Version;
13
14pub(crate) static VERSION_HEADER: HeaderName = HeaderName::from_static("x-sui-rpc-version");
15
16/// Mark every outgoing response with a header indicating the precise version of the RPC that was
17/// used (including the patch version and sha).
18pub(crate) async fn set_version_middleware(
19    State(version): State<Version>,
20    request: Request<Body>,
21    next: Next,
22) -> Response {
23    let mut response = next.run(request).await;
24    let headers = response.headers_mut();
25    headers.insert(
26        VERSION_HEADER.clone(),
27        HeaderValue::from_static(version.full),
28    );
29    response
30}
31
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;
    use crate::{
        config::{ConnectionConfig, ServiceConfig, Version},
        metrics::Metrics,
        server::builder::AppState,
    };
    use axum::{body::Body, middleware, routing::get, Router};
    use http::StatusCode;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Builds a `Metrics` instance on top of a freshly started Prometheus registry.
    fn metrics() -> Metrics {
        // Port 0 lets the OS pick any free port. Every test builds its own service (and therefore
        // its own metrics server) and tests may run concurrently, so a fixed port would make them
        // compete for one address and depend on that port being free on the host.
        let binding_address: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let registry = mysten_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    /// Builds a router with two trivial routes (`/` and `/graphql`) wrapped in
    /// `set_version_middleware`, using the version held by a default-configured `AppState`.
    fn service() -> Router {
        let state = AppState::new(
            ConnectionConfig::default(),
            ServiceConfig::default(),
            metrics(),
            CancellationToken::new(),
            Version::for_testing(),
        );

        Router::new()
            .route("/", get(|| async { "Hello, Versioning!" }))
            .route("/graphql", get(|| async { "Hello, Versioning!" }))
            .layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
    }

    /// An empty-bodied GET request for the `/graphql` route.
    fn graphql_request() -> Request<Body> {
        Request::builder()
            .uri("/graphql")
            .body(Body::empty())
            .unwrap()
    }

    /// An empty-bodied GET request for the `/` route.
    fn plain_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Sends `request` through a fresh test service and checks that the response is `200 OK` and
    /// carries the testing version in `VERSION_HEADER`.
    async fn assert_version_header(request: Request<Body>) {
        let version = Version::for_testing();
        let response = service().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn default_graphql_route() {
        assert_version_header(graphql_request()).await;
    }

    #[tokio::test]
    async fn default_plain_route() {
        assert_version_header(plain_request()).await;
    }
}