// sui_graphql_rpc/extensions/feature_gate.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

4use std::sync::Arc;
5
6use async_graphql::{
7    extensions::{Extension, ExtensionContext, ExtensionFactory, NextResolve, ResolveInfo},
8    ServerError, ServerResult, Value,
9};
10use async_trait::async_trait;
11
12use crate::{
13    config::ServiceConfig,
14    error::{code, graphql_error},
15    functional_group::functional_group,
16};
17
18pub(crate) struct FeatureGate;
19
20impl ExtensionFactory for FeatureGate {
21    fn create(&self) -> Arc<dyn Extension> {
22        Arc::new(FeatureGate)
23    }
24}
25
26#[async_trait]
27impl Extension for FeatureGate {
28    async fn resolve(
29        &self,
30        ctx: &ExtensionContext<'_>,
31        info: ResolveInfo<'_>,
32        next: NextResolve<'_>,
33    ) -> ServerResult<Option<Value>> {
34        let ResolveInfo {
35            parent_type,
36            name,
37            is_for_introspection,
38            ..
39        } = &info;
40
41        let ServiceConfig {
42            disabled_features, ..
43        } = ctx.data().map_err(|_| {
44            graphql_error(
45                code::INTERNAL_SERVER_ERROR,
46                "Unable to fetch service configuration",
47            )
48        })?;
49
50        // TODO: Is there a way to set `is_visible` on `MetaField` and `MetaType` in a generic way
51        // after building the schema? (to a function which reads the `ServiceConfig` from the
52        // `Context`). This is (probably) required to hide disabled types and interfaces in the
53        // schema.
54
55        if let Some(group) = functional_group(parent_type, name) {
56            if disabled_features.contains(&group) {
57                return if *is_for_introspection {
58                    Ok(None)
59                } else {
60                    Err(ServerError::new(
61                        format!(
62                            "Cannot query field \"{name}\" on type \"{parent_type}\". \
63                             Feature {} is disabled.",
64                            group.name(),
65                        ),
66                        // TODO: Fork `async-graphl` to add field position information to
67                        // `ResolveInfo`, so the error can take advantage of it.  Similarly for
68                        // utilising the `path_node` to set the error path.
69                        None,
70                    ))
71                };
72            }
73        }
74
75        next.run(ctx, info).await
76    }
77}
78
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use async_graphql::{EmptySubscription, Schema};
    use expect_test::expect;

    use crate::{functional_group::FunctionalGroup, mutation::Mutation, types::query::Query};

    use super::*;

    /// Query touching `protocolConfig`, a field gated behind [`FunctionalGroup::SystemState`].
    const PROTOCOL_CONFIG_QUERY: &str =
        "{ protocolConfig(protocolVersion: 1) { protocolVersion } }";

    /// Builds the schema under test with `config` as its only data and [`FeatureGate`]
    /// installed as an extension.
    fn gated_schema(config: ServiceConfig) -> Schema<Query, Mutation, EmptySubscription> {
        Schema::build(Query, Mutation, EmptySubscription)
            .data(config)
            .extension(FeatureGate)
            .finish()
    }

    #[tokio::test]
    #[should_panic] // because it tries to access the data provider, which isn't there
    async fn test_accessing_an_enabled_field() {
        gated_schema(ServiceConfig::default())
            .execute(PROTOCOL_CONFIG_QUERY)
            .await;
    }

    #[tokio::test]
    async fn test_accessing_a_disabled_field() {
        let config = ServiceConfig {
            disabled_features: BTreeSet::from_iter([FunctionalGroup::SystemState]),
            ..Default::default()
        };

        let failures = gated_schema(config)
            .execute(PROTOCOL_CONFIG_QUERY)
            .await
            .into_result()
            .unwrap_err();
        let messages: Vec<_> = failures.into_iter().map(|failure| failure.message).collect();

        let expected = expect![[r#"
            [
                "Cannot query field \"protocolConfig\" on type \"Query\". Feature \"system-state\" is disabled.",
            ]"#]];
        expected.assert_eq(&format!("{messages:#?}"));
    }
}
124}