sui_graphql_rpc/extensions/
directive_checker.rs1use std::sync::Arc;
5
6use async_graphql::{
7 extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery},
8 parser::types::{Directive, ExecutableDocument, Selection},
9 Positioned, ServerResult,
10};
11use async_graphql_value::Variables;
12use async_trait::async_trait;
13
14use crate::error::{code, graphql_error_at_pos};
15
/// Names (without the leading `@`) of the only directives a query may contain.
///
/// `check_directives` rejects any other directive. This is a slice rather than a
/// fixed-length array, so adding or removing an entry does not also require
/// updating a hand-written length in the type.
const ALLOWED_DIRECTIVES: &[&str] = &["include", "skip"];
17
/// Extension factory whose extensions reject any parsed query that uses a
/// directive not listed in `ALLOWED_DIRECTIVES`.
pub(crate) struct DirectiveChecker;

/// The extension instance handed out by `DirectiveChecker::create`; it performs
/// the directive check in its `parse_query` hook.
struct DirectiveCheckerExt;
23
24impl ExtensionFactory for DirectiveChecker {
25 fn create(&self) -> Arc<dyn Extension> {
26 Arc::new(DirectiveCheckerExt)
27 }
28}
29
30#[async_trait]
31impl Extension for DirectiveCheckerExt {
32 async fn parse_query(
33 &self,
34 ctx: &ExtensionContext<'_>,
35 query: &str,
36 variables: &Variables,
37 next: NextParseQuery<'_>,
38 ) -> ServerResult<ExecutableDocument> {
39 let doc = next.run(ctx, query, variables).await?;
40
41 let mut selection_sets = vec![];
42 for fragment in doc.fragments.values() {
43 check_directives(&fragment.node.directives)?;
44 selection_sets.push(&fragment.node.selection_set);
45 }
46
47 for (_name, op) in doc.operations.iter() {
48 check_directives(&op.node.directives)?;
49
50 for var in &op.node.variable_definitions {
51 check_directives(&var.node.directives)?;
52 }
53
54 selection_sets.push(&op.node.selection_set);
55 }
56
57 while let Some(selection_set) = selection_sets.pop() {
58 for selection in &selection_set.node.items {
59 match &selection.node {
60 Selection::Field(field) => {
61 check_directives(&field.node.directives)?;
62 selection_sets.push(&field.node.selection_set);
63 }
64 Selection::FragmentSpread(spread) => {
65 check_directives(&spread.node.directives)?;
66 }
67 Selection::InlineFragment(fragment) => {
68 check_directives(&fragment.node.directives)?;
69 selection_sets.push(&fragment.node.selection_set);
70 }
71 }
72 }
73 }
74
75 Ok(doc)
76 }
77}
78
79fn check_directives(directives: &[Positioned<Directive>]) -> ServerResult<()> {
80 for directive in directives {
81 let name = &directive.node.name.node;
82 if !ALLOWED_DIRECTIVES.contains(&name.as_str()) {
83 let supported: Vec<_> = ALLOWED_DIRECTIVES
84 .iter()
85 .map(|s| format!("`@{s}`"))
86 .collect();
87
88 return Err(graphql_error_at_pos(
89 code::BAD_USER_INPUT,
90 format!(
91 "Directive `@{name}` is not supported. Supported directives are {}",
92 supported.join(", "),
93 ),
94 directive.pos,
95 ));
96 }
97 }
98 Ok(())
99}