// sui_graphql_rpc/extensions/directive_checker.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use async_graphql::{
7    extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery},
8    parser::types::{Directive, ExecutableDocument, Selection},
9    Positioned, ServerResult,
10};
11use async_graphql_value::Variables;
12use async_trait::async_trait;
13
14use crate::error::{code, graphql_error_at_pos};
15
/// Names (without the leading `@`) of the only directives a query is permitted to use.
const ALLOWED_DIRECTIVES: [&'static str; 2] = ["include", "skip"];
17
/// Extension factory that installs a check rejecting any query which uses a directive the
/// service does not accept and understand.
pub(crate) struct DirectiveChecker;

/// Extension instance handed out by `DirectiveChecker`; it performs the directive check while
/// the incoming query is being parsed.
struct DirectiveCheckerExt;
23
24impl ExtensionFactory for DirectiveChecker {
25    fn create(&self) -> Arc<dyn Extension> {
26        Arc::new(DirectiveCheckerExt)
27    }
28}
29
30#[async_trait]
31impl Extension for DirectiveCheckerExt {
32    async fn parse_query(
33        &self,
34        ctx: &ExtensionContext<'_>,
35        query: &str,
36        variables: &Variables,
37        next: NextParseQuery<'_>,
38    ) -> ServerResult<ExecutableDocument> {
39        let doc = next.run(ctx, query, variables).await?;
40
41        let mut selection_sets = vec![];
42        for fragment in doc.fragments.values() {
43            check_directives(&fragment.node.directives)?;
44            selection_sets.push(&fragment.node.selection_set);
45        }
46
47        for (_name, op) in doc.operations.iter() {
48            check_directives(&op.node.directives)?;
49
50            for var in &op.node.variable_definitions {
51                check_directives(&var.node.directives)?;
52            }
53
54            selection_sets.push(&op.node.selection_set);
55        }
56
57        while let Some(selection_set) = selection_sets.pop() {
58            for selection in &selection_set.node.items {
59                match &selection.node {
60                    Selection::Field(field) => {
61                        check_directives(&field.node.directives)?;
62                        selection_sets.push(&field.node.selection_set);
63                    }
64                    Selection::FragmentSpread(spread) => {
65                        check_directives(&spread.node.directives)?;
66                    }
67                    Selection::InlineFragment(fragment) => {
68                        check_directives(&fragment.node.directives)?;
69                        selection_sets.push(&fragment.node.selection_set);
70                    }
71                }
72            }
73        }
74
75        Ok(doc)
76    }
77}
78
79fn check_directives(directives: &[Positioned<Directive>]) -> ServerResult<()> {
80    for directive in directives {
81        let name = &directive.node.name.node;
82        if !ALLOWED_DIRECTIVES.contains(&name.as_str()) {
83            let supported: Vec<_> = ALLOWED_DIRECTIVES
84                .iter()
85                .map(|s| format!("`@{s}`"))
86                .collect();
87
88            return Err(graphql_error_at_pos(
89                code::BAD_USER_INPUT,
90                format!(
91                    "Directive `@{name}` is not supported. Supported directives are {}",
92                    supported.join(", "),
93                ),
94                directive.pos,
95            ));
96        }
97    }
98    Ok(())
99}