sui_graphql_rpc/server/
version.rs1use axum::{
5 body::Body,
6 extract::State,
7 http::{HeaderName, HeaderValue, Request},
8 middleware::Next,
9 response::Response,
10};
11
12use crate::config::Version;
13
/// Name of the HTTP response header (`x-sui-rpc-version`) that `set_version_middleware` inserts
/// into every response to report the version of the RPC service.
14pub(crate) static VERSION_HEADER: HeaderName = HeaderName::from_static("x-sui-rpc-version");
15
16pub(crate) async fn set_version_middleware(
19 State(version): State<Version>,
20 request: Request<Body>,
21 next: Next,
22) -> Response {
23 let mut response = next.run(request).await;
24 let headers = response.headers_mut();
25 headers.insert(
26 VERSION_HEADER.clone(),
27 HeaderValue::from_static(version.full),
28 );
29 response
30}
31
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;
    use crate::{
        config::{ConnectionConfig, ServiceConfig, Version},
        metrics::Metrics,
        server::builder::AppState,
    };
    use axum::{body::Body, middleware, routing::get, Router};
    use http::StatusCode;
    use mysten_metrics;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Builds the `Metrics` value that `AppState::new` requires, using the default registry
    /// obtained from `mysten_metrics::start_prometheus_server`.
    fn metrics() -> Metrics {
        // NOTE(review): the address is fixed, and every test calls this helper; confirm that
        // `start_prometheus_server` tolerates the port already being in use when tests run in
        // parallel.
        let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
        let registry = mysten_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    /// Builds a router with two trivial `GET` routes (`/` and `/graphql`), both wrapped in
    /// `set_version_middleware` whose state is the `Version` held by a freshly built `AppState`.
    fn service() -> Router {
        // Every argument is used exactly once, so each is moved into `AppState::new` rather
        // than cloned.
        let state = AppState::new(
            ConnectionConfig::default(),
            ServiceConfig::default(),
            metrics(),
            CancellationToken::new(),
            Version::for_testing(),
        );

        Router::new()
            .route("/", get(|| async { "Hello, Versioning!" }))
            .route("/graphql", get(|| async { "Hello, Versioning!" }))
            .layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
    }

    /// A body-less request to the `/graphql` route.
    fn graphql_request() -> Request<Body> {
        Request::builder()
            .uri("/graphql")
            .body(Body::empty())
            .unwrap()
    }

    /// A body-less request to the `/` route.
    fn plain_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Sends `request` through a fresh router and asserts that the response is `200 OK` and
    /// carries `VERSION_HEADER` set to the full string of the testing version.
    async fn assert_version_header(request: Request<Body>) {
        let version = Version::for_testing();
        let response = service().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn default_graphql_route() {
        assert_version_header(graphql_request()).await;
    }

    #[tokio::test]
    async fn default_plain_route() {
        assert_version_header(plain_request()).await;
    }
}