1use crate::config::{Limits, ServiceConfig};
5use crate::error::{code, graphql_error, graphql_error_at_pos};
6use crate::metrics::Metrics;
7use async_graphql::extensions::NextParseQuery;
8use async_graphql::extensions::NextRequest;
9use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory};
10use async_graphql::parser::types::{
11 DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
12 Selection,
13};
14use async_graphql::{value, Name, Pos, Positioned, Response, ServerError, ServerResult, Variables};
15use async_graphql_value::Value as GqlValue;
16use async_graphql_value::{ConstValue, Value};
17use async_trait::async_trait;
18use axum::http::HeaderName;
19use serde::Serialize;
20use std::collections::{HashMap, HashSet};
21use std::mem;
22use std::net::SocketAddr;
23use std::sync::{Arc, Mutex};
24use std::time::Instant;
25use sui_graphql_rpc_headers::LIMITS_HEADER;
26use tracing::{error, info};
27use uuid::Uuid;
28
/// Field names that mark a connection. A field with either of these in its selection set is
/// treated as a connection, and output estimation multiplies the multiplicity of these fields by
/// the enclosing connection's page size.
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
/// Field whose arguments are all counted towards the transaction payload budget.
const DRY_RUN_TX_BLOCK: &str = "dryRunTransactionBlock";
/// Field whose arguments are all counted towards the transaction payload budget.
const EXECUTE_TX_BLOCK: &str = "executeTransactionBlock";
/// Fields whose names start with this prefix have their output multiplicity scaled by the number
/// of keys they are passed.
const MULTI_GET_PREFIX: &str = "multiGet";
/// Name of the argument holding the list of keys for `multiGet*` fields.
const MULTI_GET_KEYS: &str = "keys";
/// Field whose `bytes` and `signature` arguments are counted towards the transaction payload
/// budget.
const VERIFY_ZKLOGIN: &str = "verifyZkloginSignature";
35
/// Size of the incoming request payload, in bytes.
///
/// Read from the extension context's data by `parse_query`, and compared against the sum of the
/// query and transaction payload limits. NOTE(review): presumably inserted into the context by
/// the HTTP layer from the request body size; confirm at the call site.
#[derive(Clone, Copy, Debug)]
pub(crate) struct PayloadSize(pub u64);
39
/// Extension factory for the query limits checker: creates a fresh `QueryLimitsCheckerExt` for
/// each request.
pub(crate) struct QueryLimitsChecker;
42
/// Per-request extension that checks queries against the configured limits.
#[derive(Debug, Default)]
struct QueryLimitsCheckerExt {
    /// Usage recorded by `parse_query` (only when `ShowUsage` is present in the context data),
    /// taken by `request` to attach to the response.
    usage: Mutex<Option<Usage>>,
}
47
/// Marker type: when present in the request's context data, the limits checker attaches the
/// request's `Usage` to the response as the `usage` extension.
pub(crate) struct ShowUsage;
50
/// State for checking one parsed request against the configured limits.
///
/// The `*_budget` fields start at their configured maximums and count down as the document is
/// traversed.
struct LimitsTraversal<'a> {
    /// Fragment definitions from the document, used to expand fragment spreads.
    fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
    /// The request's variables, used to resolve arguments given as variables.
    variables: &'a Variables,

    /// Builds (and logs) limit errors.
    reporter: &'a Reporter<'a>,

    /// Size of the whole request payload in bytes, taken from `PayloadSize`.
    payload_size: u64,

    /// Variables already charged against the transaction payload budget, so that a variable used
    /// by several transaction arguments is only counted once.
    tx_variables_used: HashSet<&'a Name>,

    /// Remaining transaction payload budget, in bytes.
    tx_payload_budget: u32,
    /// Remaining number of input nodes.
    input_budget: u32,
    /// Remaining number of estimated output nodes.
    output_budget: u32,
    /// Deepest nesting level seen across the operations checked so far.
    depth_seen: u32,
}
74
/// Everything needed to build and log limit errors for one request.
struct Reporter<'a> {
    /// The limits being enforced.
    limits: &'a Limits,
    /// Identifier of the query, included in log lines.
    query_id: &'a Uuid,
    /// Included in log lines. NOTE(review): presumably the client's socket address; confirm.
    session_id: &'a SocketAddr,
}
81
/// Resources a request consumed, as measured by the limits checker.
///
/// Reported to metrics, and serialized (with camelCase keys) into the response's `usage`
/// extension when `ShowUsage` is present. Field order determines the serialized key order.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
struct Usage {
    /// Input nodes charged (selections, counted once per fragment spread expansion).
    input_nodes: u32,
    /// Estimated output nodes charged.
    output_nodes: u32,
    /// Deepest nesting level of any operation.
    depth: u32,
    /// Number of variables supplied with the request.
    variables: u32,
    /// Number of fragment definitions in the document.
    fragments: u32,
    /// Length of the query text, in bytes.
    query_payload: u32,
}
92
impl ShowUsage {
    /// Name of the header associated with usage reporting (`LIMITS_HEADER`).
    /// NOTE(review): presumably the request header that causes `ShowUsage` to be inserted into
    /// the context data; confirm where it is read.
    pub(crate) fn name() -> &'static HeaderName {
        &LIMITS_HEADER
    }
}
98
99impl<'a> LimitsTraversal<'a> {
100 fn new(
101 PayloadSize(payload_size): PayloadSize,
102 reporter: &'a Reporter<'a>,
103 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
104 variables: &'a Variables,
105 ) -> Self {
106 Self {
107 fragments,
108 variables,
109 payload_size,
110 reporter,
111 tx_variables_used: HashSet::new(),
112 tx_payload_budget: reporter.limits.max_tx_payload_size,
113 input_budget: reporter.limits.max_query_nodes,
114 output_budget: reporter.limits.max_output_nodes,
115 depth_seen: 0,
116 }
117 }
118
119 fn check_document(&mut self, doc: &'a ExecutableDocument) -> ServerResult<()> {
121 for (_name, op) in doc.operations.iter() {
125 self.check_input_limits(op)?;
126 }
127
128 for (_name, op) in doc.operations.iter() {
131 self.check_tx_payload(op)?;
132 }
133
134 let limits = self.reporter.limits;
137 let tx_payload_size = (limits.max_tx_payload_size - self.tx_payload_budget) as u64;
138 let query_payload_size = self.payload_size - tx_payload_size;
139 if query_payload_size > limits.max_query_payload_size as u64 {
140 let message = format!("Query part too large: {query_payload_size} bytes");
141 return Err(self.reporter.payload_size_error(&message));
142 }
143
144 for (_name, op) in doc.operations.iter() {
147 self.check_output_limits(op)?;
148 }
149
150 Ok(())
151 }
152
153 fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
155 let limits = self.reporter.limits;
156
157 let mut next_level = vec![];
158 let mut curr_level = vec![];
159 let mut depth_budget = limits.max_query_depth;
160
161 next_level.extend(&op.node.selection_set.node.items);
162 while let Some(next) = next_level.first() {
163 if depth_budget == 0 {
164 return Err(self.reporter.graphql_error_at_pos(
165 code::BAD_USER_INPUT,
166 format!("Query nesting is over {}", limits.max_query_depth),
167 next.pos,
168 ));
169 } else {
170 depth_budget -= 1;
171 }
172
173 mem::swap(&mut next_level, &mut curr_level);
174
175 for selection in curr_level.drain(..) {
176 if self.input_budget == 0 {
177 return Err(self.reporter.graphql_error_at_pos(
178 code::BAD_USER_INPUT,
179 format!("Query has over {} nodes", limits.max_query_nodes),
180 selection.pos,
181 ));
182 } else {
183 self.input_budget -= 1;
184 }
185
186 match &selection.node {
187 Selection::Field(f) => {
188 next_level.extend(&f.node.selection_set.node.items);
189 }
190
191 Selection::InlineFragment(f) => {
192 next_level.extend(&f.node.selection_set.node.items);
193 }
194
195 Selection::FragmentSpread(fs) => {
196 let name = &fs.node.fragment_name.node;
197 let def = self
198 .fragments
199 .get(name)
200 .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
201
202 next_level.extend(&def.node.selection_set.node.items);
203 }
204 }
205 }
206 }
207
208 self.depth_seen = self.depth_seen.max(limits.max_query_depth - depth_budget);
209 Ok(())
210 }
211
212 fn check_tx_payload(&mut self, op: &'a Positioned<OperationDefinition>) -> ServerResult<()> {
218 for item in &op.node.selection_set.node.items {
219 self.traverse_selection_for_tx_payload(item)?;
220 }
221 Ok(())
222 }
223
224 fn traverse_selection_for_tx_payload(
227 &mut self,
228 item: &'a Positioned<Selection>,
229 ) -> ServerResult<()> {
230 match &item.node {
231 Selection::Field(f) => {
232 let name = &f.node.name.node;
233
234 if name == DRY_RUN_TX_BLOCK || name == EXECUTE_TX_BLOCK {
235 for (_name, value) in &f.node.arguments {
236 self.check_tx_arg(value)?;
237 }
238 } else if name == VERIFY_ZKLOGIN {
239 if let Some(value) = f.node.get_argument("bytes") {
240 self.check_tx_arg(value)?;
241 }
242
243 if let Some(value) = f.node.get_argument("signature") {
244 self.check_tx_arg(value)?;
245 }
246 }
247 }
248
249 Selection::InlineFragment(f) => {
250 for selection in &f.node.selection_set.node.items {
251 self.traverse_selection_for_tx_payload(selection)?;
252 }
253 }
254
255 Selection::FragmentSpread(fs) => {
256 let name = &fs.node.fragment_name.node;
257 let def = self
258 .fragments
259 .get(name)
260 .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
261
262 for selection in &def.node.selection_set.node.items {
263 self.traverse_selection_for_tx_payload(selection)?;
264 }
265 }
266 }
267 Ok(())
268 }
269
270 fn check_tx_arg(&mut self, value: &'a Positioned<Value>) -> ServerResult<()> {
276 use GqlValue as V;
277
278 let mut stack = vec![&value.node];
279 while let Some(value) = stack.pop() {
280 match value {
281 V::Variable(name) => self.check_tx_var(name)?,
282
283 V::String(s) => {
284 let debit = s.len() + 2;
286 if debit > self.tx_payload_budget as usize {
287 return Err(self.tx_payload_size_error());
288 } else {
289 self.tx_payload_budget -= debit as u32;
292 }
293 }
294
295 V::List(vs) => {
296 let debit = vs.len().saturating_sub(1) + 2;
299 if debit > self.tx_payload_budget as usize {
300 return Err(self.tx_payload_size_error());
301 } else {
302 self.tx_payload_budget -= debit as u32;
305 stack.extend(vs)
306 }
307 }
308
309 V::Null
310 | V::Number(_)
311 | V::Boolean(_)
312 | V::Binary(_)
313 | V::Enum(_)
314 | V::Object(_) => {
315 }
323 }
324 }
325
326 Ok(())
327 }
328
329 fn check_tx_var(&mut self, name: &'a Name) -> ServerResult<()> {
334 use ConstValue as CV;
335
336 if !self.tx_variables_used.insert(name) {
338 return Ok(());
339 }
340
341 let Some(value) = self.variables.get(name) else {
343 return Ok(());
344 };
345
346 let mut stack = vec![value];
347 while let Some(value) = stack.pop() {
348 match &value {
349 CV::String(s) => {
350 let debit = s.len() + 2;
352 if debit > self.tx_payload_budget as usize {
353 return Err(self.tx_payload_size_error());
354 } else {
355 self.tx_payload_budget -= debit as u32;
358 }
359 }
360
361 CV::List(vs) => {
362 let debit = vs.len().saturating_sub(1) + 2;
365 if debit > self.tx_payload_budget as usize {
366 return Err(self.tx_payload_size_error());
367 } else {
368 self.tx_payload_budget -= debit as u32;
371 stack.extend(vs)
372 }
373 }
374
375 CV::Null
376 | CV::Number(_)
377 | CV::Boolean(_)
378 | CV::Binary(_)
379 | CV::Enum(_)
380 | CV::Object(_) => {
381 }
383 }
384 }
385
386 Ok(())
387 }
388
389 fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
394 for selection in &op.node.selection_set.node.items {
395 self.traverse_selection_for_output(selection, 1, None)?;
396 }
397 Ok(())
398 }
399
400 fn traverse_selection_for_output(
408 &mut self,
409 selection: &Positioned<Selection>,
410 multiplicity: u32,
411 page_size: Option<u32>,
412 ) -> ServerResult<()> {
413 match &selection.node {
414 Selection::Field(f) => {
415 if multiplicity > self.output_budget {
416 return Err(self.output_node_error());
417 } else {
418 self.output_budget -= multiplicity;
419 }
420
421 let name = &f.node.name.node;
422
423 let multiplicity = 'm: {
425 if let Some(page_size) = self.multi_get_page_size(f)? {
427 break 'm multiplicity * page_size;
428 }
429
430 if !CONNECTION_FIELDS.contains(&name.as_str()) {
431 break 'm multiplicity;
432 }
433
434 let Some(page_size) = page_size else {
435 break 'm multiplicity;
436 };
437 multiplicity
438 .checked_mul(page_size)
439 .ok_or_else(|| self.output_node_error())?
440 };
441
442 let page_size = self.connection_page_size(f)?;
443 for selection in &f.node.selection_set.node.items {
444 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
445 }
446 }
447
448 Selection::InlineFragment(f) => {
450 for selection in &f.node.selection_set.node.items {
451 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
452 }
453 }
454
455 Selection::FragmentSpread(fs) => {
456 let name = &fs.node.fragment_name.node;
457 let def = self
458 .fragments
459 .get(name)
460 .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
461
462 for selection in &def.node.selection_set.node.items {
463 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
464 }
465 }
466 }
467
468 Ok(())
469 }
470
471 fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
474 if !self.is_connection(f) {
475 return Ok(None);
476 }
477
478 let first = f.node.get_argument("first");
479 let last = f.node.get_argument("last");
480
481 let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
482 (Some(f), Some(l)) => f.max(l),
483 (Some(p), _) | (_, Some(p)) => p,
484 (None, None) => self.reporter.limits.default_page_size as u64,
485 };
486
487 Ok(Some(
488 page_size.try_into().map_err(|_| self.output_node_error())?,
489 ))
490 }
491
492 fn multi_get_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
495 if !f.node.name.node.starts_with(MULTI_GET_PREFIX) {
496 return Ok(None);
497 }
498
499 let keys = f.node.get_argument(MULTI_GET_KEYS);
500 let Some(page_size) = self.resolve_list_size(keys) else {
501 return Ok(None);
502 };
503
504 Ok(Some(
505 page_size.try_into().map_err(|_| self.output_node_error())?,
506 ))
507 }
508
509 fn is_connection(&self, f: &Positioned<Field>) -> bool {
513 f.node
514 .selection_set
515 .node
516 .items
517 .iter()
518 .any(|s| self.has_connection_fields(s))
519 }
520
521 fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
525 match &s.node {
526 Selection::Field(f) => {
527 let name = &f.node.name.node;
528 CONNECTION_FIELDS.contains(&name.as_str())
529 }
530
531 Selection::InlineFragment(f) => f
532 .node
533 .selection_set
534 .node
535 .items
536 .iter()
537 .any(|s| self.has_connection_fields(s)),
538
539 Selection::FragmentSpread(fs) => {
540 let name = &fs.node.fragment_name.node;
541 let Some(def) = self.fragments.get(name) else {
542 return false;
543 };
544
545 def.node
546 .selection_set
547 .node
548 .items
549 .iter()
550 .any(|s| self.has_connection_fields(s))
551 }
552 }
553 }
554
555 fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
557 match &value?.node {
558 Value::Number(num) => num,
559
560 Value::Variable(var) => {
561 if let ConstValue::Number(num) = self.variables.get(var)? {
562 num
563 } else {
564 return None;
565 }
566 }
567
568 _ => return None,
569 }
570 .as_u64()
571 }
572
573 fn resolve_list_size(&self, value: Option<&Positioned<Value>>) -> Option<usize> {
575 match &value?.node {
576 Value::List(list) => Some(list.len()),
577 Value::Variable(var) => {
578 if let ConstValue::List(list) = self.variables.get(var)? {
579 Some(list.len())
580 } else {
581 None
582 }
583 }
584 _ => None,
585 }
586 }
587
588 fn tx_payload_size_error(&mut self) -> ServerError {
593 self.tx_payload_budget = 0;
594 self.reporter
595 .payload_size_error("Transaction payload too large")
596 }
597
598 fn output_node_error(&mut self) -> ServerError {
603 self.output_budget = 0;
604 self.reporter.output_node_error()
605 }
606
607 fn finish(self, query_payload: u32) -> Usage {
609 let limits = self.reporter.limits;
610 Usage {
611 input_nodes: limits.max_query_nodes - self.input_budget,
612 output_nodes: limits.max_output_nodes - self.output_budget,
613 depth: self.depth_seen,
614 variables: self.variables.len() as u32,
615 fragments: self.fragments.len() as u32,
616 query_payload,
617 }
618 }
619}
620
621impl<'a> Reporter<'a> {
622 fn new(ctx: &'a ExtensionContext<'a>) -> Self {
623 let cfg: &ServiceConfig = ctx.data_unchecked();
624 Self {
625 limits: &cfg.limits,
626 query_id: ctx.data_unchecked(),
627 session_id: ctx.data_unchecked(),
628 }
629 }
630
631 fn fragment_not_found_error(&self, name: &Name, pos: Pos) -> ServerError {
633 self.graphql_error_at_pos(
634 code::BAD_USER_INPUT,
635 format!("Fragment {name} referred to but not found in document"),
636 pos,
637 )
638 }
639
640 fn output_node_error(&self) -> ServerError {
642 self.graphql_error(
643 code::BAD_USER_INPUT,
644 format!(
645 "Estimated output nodes exceeds {}",
646 self.limits.max_output_nodes
647 ),
648 )
649 }
650
651 fn payload_size_error(&self, message: &str) -> ServerError {
653 self.graphql_error(
654 code::BAD_USER_INPUT,
655 format!(
656 "{message}. Requests are limited to {max_tx_payload} bytes or fewer on transaction \
657 payloads (all inputs to executeTransactionBlock, dryRunTransactionBlock, or \
658 verifyZkloginSignature) and the rest of the request (the query part) must be \
659 {max_query_payload} bytes or fewer.",
660 max_tx_payload = self.limits.max_tx_payload_size,
661 max_query_payload = self.limits.max_query_payload_size,
662 ),
663 )
664 }
665
666 fn graphql_error(&self, code: &str, message: String) -> ServerError {
668 self.log_error(code, &message);
669 graphql_error(code, message)
670 }
671
672 fn graphql_error_at_pos(&self, code: &str, message: String, pos: Pos) -> ServerError {
674 self.log_error(code, &message);
675 graphql_error_at_pos(code, message, pos)
676 }
677
678 fn log_error(&self, error_code: &str, message: &str) {
680 if error_code == code::INTERNAL_SERVER_ERROR {
681 error!(
682 query_id = %self.query_id,
683 session_id = %self.session_id,
684 error_code,
685 "Internal error while checking limits: {message}",
686 );
687 } else {
688 info!(
689 query_id = %self.query_id,
690 session_id = %self.session_id,
691 error_code,
692 "Limits error: {message}",
693 );
694 }
695 }
696}
697
698impl Usage {
699 fn report(&self, metrics: &Metrics) {
700 metrics
701 .request_metrics
702 .input_nodes
703 .observe(self.input_nodes as f64);
704 metrics
705 .request_metrics
706 .output_nodes
707 .observe(self.output_nodes as f64);
708 metrics
709 .request_metrics
710 .query_depth
711 .observe(self.depth as f64);
712 metrics
713 .request_metrics
714 .query_payload_size
715 .observe(self.query_payload as f64);
716 }
717}
718
719impl ExtensionFactory for QueryLimitsChecker {
720 fn create(&self) -> Arc<dyn Extension> {
721 Arc::new(QueryLimitsCheckerExt {
722 usage: Mutex::new(None),
723 })
724 }
725}
726
727#[async_trait]
728impl Extension for QueryLimitsCheckerExt {
729 async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
730 let resp = next.run(ctx).await;
731 let usage = self.usage.lock().unwrap().take();
732 if let Some(usage) = usage {
733 resp.extension("usage", value!(usage))
734 } else {
735 resp
736 }
737 }
738
739 async fn parse_query(
742 &self,
743 ctx: &ExtensionContext<'_>,
744 query: &str,
745 variables: &Variables,
746 next: NextParseQuery<'_>,
747 ) -> ServerResult<ExecutableDocument> {
748 let metrics: &Metrics = ctx.data_unchecked();
749 let payload_size: &PayloadSize = ctx.data_unchecked();
750 let reporter = Reporter::new(ctx);
751
752 let instant = Instant::now();
753
754 let max_payload_size = reporter.limits.max_query_payload_size as u64
756 + reporter.limits.max_tx_payload_size as u64;
757
758 if payload_size.0 > max_payload_size {
759 let message = format!("Overall request too large: {} bytes", payload_size.0);
760 return Err(reporter.payload_size_error(&message));
761 }
762
763 let doc = next.run(ctx, query, variables).await?;
765
766 if let DocumentOperations::Single(op) = &doc.operations {
770 if let [field] = &op.node.selection_set.node.items[..] {
771 if let Selection::Field(f) = &field.node {
772 if f.node.name.node == "__schema" {
773 return Ok(doc);
774 }
775 }
776 }
777 }
778
779 let mut traversal =
780 LimitsTraversal::new(*payload_size, &reporter, &doc.fragments, variables);
781
782 let res = traversal.check_document(&doc);
783 let usage = traversal.finish(query.len() as u32);
784 metrics.query_validation_latency(instant.elapsed());
785 usage.report(metrics);
786
787 res.map(|()| {
788 if ctx.data_opt::<ShowUsage>().is_some() {
789 *self.usage.lock().unwrap() = Some(usage);
790 }
791
792 doc
793 })
794 }
795}