// sui_graphql_rpc/extensions/feature_gate.rs
use std::sync::Arc;
5
6use async_graphql::{
7 extensions::{Extension, ExtensionContext, ExtensionFactory, NextResolve, ResolveInfo},
8 ServerError, ServerResult, Value,
9};
10use async_trait::async_trait;
11
12use crate::{
13 config::ServiceConfig,
14 error::{code, graphql_error},
15 functional_group::functional_group,
16};
17
18pub(crate) struct FeatureGate;
19
20impl ExtensionFactory for FeatureGate {
21 fn create(&self) -> Arc<dyn Extension> {
22 Arc::new(FeatureGate)
23 }
24}
25
26#[async_trait]
27impl Extension for FeatureGate {
28 async fn resolve(
29 &self,
30 ctx: &ExtensionContext<'_>,
31 info: ResolveInfo<'_>,
32 next: NextResolve<'_>,
33 ) -> ServerResult<Option<Value>> {
34 let ResolveInfo {
35 parent_type,
36 name,
37 is_for_introspection,
38 ..
39 } = &info;
40
41 let ServiceConfig {
42 disabled_features, ..
43 } = ctx.data().map_err(|_| {
44 graphql_error(
45 code::INTERNAL_SERVER_ERROR,
46 "Unable to fetch service configuration",
47 )
48 })?;
49
50 if let Some(group) = functional_group(parent_type, name) {
56 if disabled_features.contains(&group) {
57 return if *is_for_introspection {
58 Ok(None)
59 } else {
60 Err(ServerError::new(
61 format!(
62 "Cannot query field \"{name}\" on type \"{parent_type}\". \
63 Feature {} is disabled.",
64 group.name(),
65 ),
66 None,
70 ))
71 };
72 }
73 }
74
75 next.run(ctx, info).await
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use async_graphql::{EmptySubscription, Schema};
    use expect_test::expect;

    use crate::{functional_group::FunctionalGroup, mutation::Mutation, types::query::Query};

    use super::*;

    /// Query selecting `protocolConfig`, which `test_accessing_a_disabled_field` shows is gated
    /// by the `SystemState` functional group.
    const PROTOCOL_CONFIG_QUERY: &str =
        "{ protocolConfig(protocolVersion: 1) { protocolVersion } }";

    /// Builds the full schema with `config` as its only data and `FeatureGate` as its only
    /// extension.
    fn gated_schema(config: ServiceConfig) -> Schema<Query, Mutation, EmptySubscription> {
        Schema::build(Query, Mutation, EmptySubscription)
            .data(config)
            .extension(FeatureGate)
            .finish()
    }

    #[tokio::test]
    // NOTE(review): with the feature enabled, the gate lets the query through to the real
    // resolver. This test appears to rely on that resolver panicking because the schema only
    // carries a `ServiceConfig` (no other data such as a data provider). Confirm this, and
    // consider pinning `expected = "..."` so that an unrelated panic cannot make the test pass.
    #[should_panic]
    async fn test_accessing_an_enabled_field() {
        gated_schema(ServiceConfig::default())
            .execute(PROTOCOL_CONFIG_QUERY)
            .await;
    }

    #[tokio::test]
    async fn test_accessing_a_disabled_field() {
        let config = ServiceConfig {
            disabled_features: BTreeSet::from_iter([FunctionalGroup::SystemState]),
            ..Default::default()
        };

        let response = gated_schema(config).execute(PROTOCOL_CONFIG_QUERY).await;
        let errors = response.into_result().unwrap_err();
        let messages: Vec<_> = errors.into_iter().map(|err| err.message).collect();

        expect![[r#"
            [
                "Cannot query field \"protocolConfig\" on type \"Query\". Feature \"system-state\" is disabled.",
            ]"#]]
        .assert_eq(&format!("{messages:#?}"));
    }
}