sui_graphql_macros/
schema.rs

1//! GraphQL schema parsing and type lookup.
2
3use graphql_parser::schema as gql;
4use graphql_parser::schema::Definition;
5use graphql_parser::schema::TypeDefinition;
6use std::collections::HashMap;
7use std::sync::LazyLock;
8
9/// The embedded Sui GraphQL schema.
10const SCHEMA_SDL: &str = include_str!("../schema/sui.graphql");
11
12/// Parsed and indexed schema, cached for reuse across macro invocations.
13static SCHEMA: LazyLock<Result<Schema, String>> =
14    LazyLock::new(|| Schema::parse(SCHEMA_SDL).map_err(|e| format!("Failed to parse schema: {e}")));
15
16/// A parsed GraphQL schema with type lookup.
17#[derive(Debug)]
18pub struct Schema {
19    types: HashMap<String, TypeInfo>,
20}
21
22/// Information about a GraphQL type.
23#[derive(Debug)]
24pub struct TypeInfo {
25    pub name: String,
26    pub fields: HashMap<String, FieldInfo>,
27    /// For union types, the member type names. `None` for non-union types.
28    pub union_types: Option<Vec<String>>,
29}
30
/// Summary of one field declared on an object or interface type.
#[derive(Debug)]
pub struct FieldInfo {
    /// The field's name.
    pub name: String,
    /// Name of the field's underlying named type, with list and non-null wrappers removed.
    pub type_name: String,
    /// `true` when the field's type is wrapped in a list at any nesting level.
    pub is_list: bool,
}
38
39impl Schema {
40    /// Load the embedded Sui GraphQL schema.
41    ///
42    /// Returns an `Arc<Schema>` for cheap sharing without deep copying.
43    pub fn load() -> Result<&'static Schema, syn::Error> {
44        SCHEMA
45            .as_ref()
46            .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), e.clone()))
47    }
48
49    /// Load a custom schema from SDL content.
50    ///
51    /// Used when `#[response(schema = "...")]` specifies a custom schema.
52    pub fn from_sdl(sdl: &str) -> Result<Self, syn::Error> {
53        Self::parse(sdl).map_err(|e| {
54            syn::Error::new(
55                proc_macro2::Span::call_site(),
56                format!("Failed to parse custom schema: {e}"),
57            )
58        })
59    }
60
61    /// Parse a GraphQL SDL schema.
62    fn parse(sdl: &str) -> Result<Self, graphql_parser::schema::ParseError> {
63        let doc = gql::parse_schema::<String>(sdl)?;
64        let mut types = HashMap::new();
65
66        for def in doc.definitions {
67            if let Definition::TypeDefinition(type_def) = def {
68                let type_info = Self::parse_type_definition(type_def);
69                types.insert(type_info.name.clone(), type_info);
70            }
71        }
72
73        Ok(Schema { types })
74    }
75
76    fn parse_type_definition(def: TypeDefinition<String>) -> TypeInfo {
77        match def {
78            TypeDefinition::Object(obj) => TypeInfo {
79                name: obj.name.clone(),
80                fields: obj
81                    .fields
82                    .iter()
83                    .map(|f| {
84                        let info = Self::parse_field(f);
85                        (info.name.clone(), info)
86                    })
87                    .collect(),
88                union_types: None,
89            },
90            TypeDefinition::Interface(i) => TypeInfo {
91                name: i.name.clone(),
92                fields: i
93                    .fields
94                    .iter()
95                    .map(|f| {
96                        let info = Self::parse_field(f);
97                        (info.name.clone(), info)
98                    })
99                    .collect(),
100                union_types: None,
101            },
102            TypeDefinition::Scalar(s) => TypeInfo {
103                name: s.name.clone(),
104                fields: HashMap::new(),
105                union_types: None,
106            },
107            TypeDefinition::Enum(e) => TypeInfo {
108                name: e.name.clone(),
109                fields: HashMap::new(),
110                union_types: None,
111            },
112            TypeDefinition::InputObject(io) => TypeInfo {
113                name: io.name.clone(),
114                fields: HashMap::new(),
115                union_types: None,
116            },
117            TypeDefinition::Union(u) => TypeInfo {
118                name: u.name.clone(),
119                fields: HashMap::new(),
120                union_types: Some(u.types),
121            },
122        }
123    }
124
125    fn parse_field(field: &gql::Field<String>) -> FieldInfo {
126        let (type_name, is_list) = Self::parse_type(&field.field_type);
127        FieldInfo {
128            name: field.name.clone(),
129            type_name,
130            is_list,
131        }
132    }
133
134    /// Parse a GraphQL type, extracting the base type name and list status.
135    fn parse_type(ty: &gql::Type<String>) -> (String, bool) {
136        match ty {
137            gql::Type::NamedType(name) => (name.clone(), false),
138            gql::Type::NonNullType(inner) => Self::parse_type(inner),
139            gql::Type::ListType(inner) => {
140                let (name, _) = Self::parse_type(inner);
141                (name, true)
142            }
143        }
144    }
145
146    /// Look up a field on a type.
147    pub fn field(&self, type_name: &str, field_name: &str) -> Option<&FieldInfo> {
148        self.types.get(type_name)?.fields.get(field_name)
149    }
150
151    /// Get all field names for a type.
152    pub fn field_names(&self, type_name: &str) -> Vec<&str> {
153        self.types
154            .get(type_name)
155            .map(|t| t.fields.keys().map(|s| s.as_str()).collect())
156            .unwrap_or_default()
157    }
158
159    /// Check if a type exists in the schema.
160    pub fn has_type(&self, type_name: &str) -> bool {
161        self.types.contains_key(type_name)
162    }
163
164    /// Get all type names in the schema.
165    pub fn type_names(&self) -> Vec<&str> {
166        self.types.keys().map(|s| s.as_str()).collect()
167    }
168
169    /// Check if a type is a union type.
170    pub fn is_union(&self, type_name: &str) -> bool {
171        self.types
172            .get(type_name)
173            .is_some_and(|t| t.union_types.is_some())
174    }
175
176    /// Get the member type names of a union.
177    pub fn union_types(&self, type_name: &str) -> Vec<&str> {
178        self.types
179            .get(type_name)
180            .and_then(|t| t.union_types.as_ref())
181            .map(|v| v.iter().map(|s| s.as_str()).collect())
182            .unwrap_or_default()
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse the shared test-fixture schema used by the union tests.
    fn test_schema() -> Schema {
        let sdl = include_str!("../tests/test_schema.graphql");
        Schema::from_sdl(sdl).unwrap()
    }

    /// Small inline schema so field-level behavior can be pinned without fixture files.
    fn inline_schema() -> Schema {
        Schema::from_sdl(
            "type Query { name: String! items: [Item!]! nested: [[Item]] }
             type Item { id: ID }
             scalar DateTime",
        )
        .unwrap()
    }

    #[test]
    fn test_has_type() {
        let schema = Schema::load().unwrap();
        // Standard GraphQL root types
        assert!(schema.has_type("Query"));
        assert!(schema.has_type("Mutation"));
        // Common Sui types
        assert!(schema.has_type("Object"));
        assert!(schema.has_type("DynamicField"));
        assert!(schema.has_type("MoveObject"));
        // Non-existent types
        assert!(!schema.has_type("NonExistent"));
        assert!(!schema.has_type("query")); // Case-sensitive
    }

    #[test]
    fn test_type_names() {
        let schema = Schema::load().unwrap();
        let type_names = schema.type_names();
        // Should contain standard types
        assert!(type_names.contains(&"Query"));
        assert!(type_names.contains(&"Mutation"));
        assert!(type_names.contains(&"Object"));
    }

    #[test]
    fn test_union_types() {
        let schema = test_schema();

        assert!(schema.is_union("DynamicFieldValue"));
        let members = schema.union_types("DynamicFieldValue");
        assert!(members.contains(&"MoveObject"));
        assert!(members.contains(&"MoveValue"));

        // Non-union types
        assert!(!schema.is_union("Object"));
        assert!(schema.union_types("Object").is_empty());

        // Non-existent type
        assert!(!schema.is_union("NonExistent"));
        assert!(schema.union_types("NonExistent").is_empty());
    }

    #[test]
    fn test_field_lookup() {
        let schema = inline_schema();

        // Non-null scalar: wrapper stripped, not a list.
        let name = schema.field("Query", "name").unwrap();
        assert_eq!(name.type_name, "String");
        assert!(!name.is_list);

        // Non-null list of non-null objects.
        let items = schema.field("Query", "items").unwrap();
        assert_eq!(items.type_name, "Item");
        assert!(items.is_list);

        // Nested lists collapse to the base type with the list flag set.
        let nested = schema.field("Query", "nested").unwrap();
        assert_eq!(nested.type_name, "Item");
        assert!(nested.is_list);

        // Missing field, missing type, and a type with no fields.
        assert!(schema.field("Query", "missing").is_none());
        assert!(schema.field("Missing", "name").is_none());
        assert!(schema.field("DateTime", "name").is_none());
    }

    #[test]
    fn test_field_names() {
        let schema = inline_schema();

        let mut names = schema.field_names("Query");
        names.sort();
        assert_eq!(names, vec!["items", "name", "nested"]);

        assert_eq!(schema.field_names("Item"), vec!["id"]);
        assert!(schema.field_names("DateTime").is_empty());
        assert!(schema.field_names("Missing").is_empty());
    }

    #[test]
    fn test_from_sdl_rejects_invalid_sdl() {
        let err = Schema::from_sdl("type Query {").unwrap_err();
        assert!(err.to_string().starts_with("Failed to parse custom schema"));
    }
}