sui_graphql_rpc/extensions/
query_limits_checker.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::config::{Limits, ServiceConfig};
5use crate::error::{code, graphql_error, graphql_error_at_pos};
6use crate::metrics::Metrics;
7use async_graphql::extensions::NextParseQuery;
8use async_graphql::extensions::NextRequest;
9use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory};
10use async_graphql::parser::types::{
11    DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
12    Selection,
13};
14use async_graphql::{value, Name, Pos, Positioned, Response, ServerError, ServerResult, Variables};
15use async_graphql_value::Value as GqlValue;
16use async_graphql_value::{ConstValue, Value};
17use async_trait::async_trait;
18use axum::http::HeaderName;
19use serde::Serialize;
20use std::collections::{HashMap, HashSet};
21use std::mem;
22use std::net::SocketAddr;
23use std::sync::{Arc, Mutex};
24use std::time::Instant;
25use sui_graphql_rpc_headers::LIMITS_HEADER;
26use tracing::{error, info};
27use uuid::Uuid;
28
// Names of fields whose presence in a selection set marks the enclosing field as a connection.
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
// Fields whose arguments are all charged against the transaction payload budget.
const DRY_RUN_TX_BLOCK: &str = "dryRunTransactionBlock";
const EXECUTE_TX_BLOCK: &str = "executeTransactionBlock";
// Name prefix that identifies multi-get fields, and the argument holding their list of keys.
const MULTI_GET_PREFIX: &str = "multiGet";
const MULTI_GET_KEYS: &str = "keys";
// Field whose `bytes` and `signature` arguments are charged against the transaction payload
// budget.
const VERIFY_ZKLOGIN: &str = "verifyZkloginSignature";
35
/// The size of the query payload in bytes, as it comes from the request header: `Content-Length`.
///
/// The limits checker reads this from the request context and treats it as the total size of the
/// request when splitting it into transaction payload and query parts.
#[derive(Clone, Copy, Debug)]
pub(crate) struct PayloadSize(pub u64);
39
/// Extension factory for adding checks that the query is within configurable limits.
///
/// Each call to `create` produces a fresh extension instance with no recorded usage.
pub(crate) struct QueryLimitsChecker;
42
/// The extension that performs the limit checks while the query is being parsed.
#[derive(Debug, Default)]
struct QueryLimitsCheckerExt {
    /// Usage recorded by `parse_query` (only when usage reporting was requested), held here until
    /// `request` takes it out and attaches it to the response.
    usage: Mutex<Option<Usage>>,
}
47
/// Only display usage information if this header was in the request.
///
/// The extension looks for a value of this type in the request context to decide whether to
/// attach usage to the response.
pub(crate) struct ShowUsage;
50
/// State for traversing a document to check for limits. Holds on to environments for looking up
/// variables and fragments, limits, and the remainder of the limit that can be used.
struct LimitsTraversal<'a> {
    // Environments for resolving lookups in the document
    fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
    variables: &'a Variables,

    /// Creates errors and reports them to tracing
    reporter: &'a Reporter<'a>,

    /// Raw size of the request
    payload_size: u64,

    /// Variables that are used in transaction executions and dry-runs. If these variables are used
    /// multiple times, the size of their contents should not be double counted.
    tx_variables_used: HashSet<&'a Name>,

    // Remaining budget for the traversal. Each budget starts at the corresponding configured limit
    // and is decremented as the document is checked.
    tx_payload_budget: u32,
    input_budget: u32,
    output_budget: u32,
    // Deepest nesting level observed across the operations checked so far.
    depth_seen: u32,
}
74
/// Builds error messages and reports them to tracing.
struct Reporter<'a> {
    /// Configured limits, quoted in error messages and used to initialise traversal budgets.
    limits: &'a Limits,
    /// Identifier of the query being checked, attached to every log line.
    query_id: &'a Uuid,
    /// Socket address used as the session identifier, attached to every log line.
    session_id: &'a SocketAddr,
}
81
/// Summary of how much of each limit a request consumed. Serialized with camelCase keys when it
/// is attached to the response, and reported to metrics.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
struct Usage {
    /// Query nodes counted by the input limit check.
    input_nodes: u32,
    /// Estimated output nodes counted by the output limit check.
    output_nodes: u32,
    /// Deepest nesting level seen in the query.
    depth: u32,
    /// Number of variables supplied with the request.
    variables: u32,
    /// Number of fragment definitions in the document.
    fragments: u32,
    /// Length of the query string, in bytes.
    query_payload: u32,
}
92
impl ShowUsage {
    /// The name of the request header that enables usage reporting.
    pub(crate) fn name() -> &'static HeaderName {
        &LIMITS_HEADER
    }
}
98
99impl<'a> LimitsTraversal<'a> {
100    fn new(
101        PayloadSize(payload_size): PayloadSize,
102        reporter: &'a Reporter<'a>,
103        fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
104        variables: &'a Variables,
105    ) -> Self {
106        Self {
107            fragments,
108            variables,
109            payload_size,
110            reporter,
111            tx_variables_used: HashSet::new(),
112            tx_payload_budget: reporter.limits.max_tx_payload_size,
113            input_budget: reporter.limits.max_query_nodes,
114            output_budget: reporter.limits.max_output_nodes,
115            depth_seen: 0,
116        }
117    }
118
119    /// Main entrypoint for checking all limits.
120    fn check_document(&mut self, doc: &'a ExecutableDocument) -> ServerResult<()> {
121        // First, check the size of the query inputs. This is done using a non-recursive algorithm in
122        // case the input has too many nodes or is too deep. This allows subsequent checks to be
123        // implemented recursively.
124        for (_name, op) in doc.operations.iter() {
125            self.check_input_limits(op)?;
126        }
127
128        // Then gather inputs to transaction execution and dry-run nodes, and make sure these are
129        // within budget, cumulatively.
130        for (_name, op) in doc.operations.iter() {
131            self.check_tx_payload(op)?;
132        }
133
134        // Next, with the transaction payloads accounted for, ensure the remaining query is within
135        // the size limit.
136        let limits = self.reporter.limits;
137        let tx_payload_size = (limits.max_tx_payload_size - self.tx_payload_budget) as u64;
138        let query_payload_size = self.payload_size - tx_payload_size;
139        if query_payload_size > limits.max_query_payload_size as u64 {
140            let message = format!("Query part too large: {query_payload_size} bytes");
141            return Err(self.reporter.payload_size_error(&message));
142        }
143
144        // Finally, run output node estimation, to check that the output won't contain too many
145        // nodes, in the worst case.
146        for (_name, op) in doc.operations.iter() {
147            self.check_output_limits(op)?;
148        }
149
150        Ok(())
151    }
152
153    /// Test that the operation meets input limits (number of nodes and depth).
154    fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
155        let limits = self.reporter.limits;
156
157        let mut next_level = vec![];
158        let mut curr_level = vec![];
159        let mut depth_budget = limits.max_query_depth;
160
161        next_level.extend(&op.node.selection_set.node.items);
162        while let Some(next) = next_level.first() {
163            if depth_budget == 0 {
164                return Err(self.reporter.graphql_error_at_pos(
165                    code::BAD_USER_INPUT,
166                    format!("Query nesting is over {}", limits.max_query_depth),
167                    next.pos,
168                ));
169            } else {
170                depth_budget -= 1;
171            }
172
173            mem::swap(&mut next_level, &mut curr_level);
174
175            for selection in curr_level.drain(..) {
176                if self.input_budget == 0 {
177                    return Err(self.reporter.graphql_error_at_pos(
178                        code::BAD_USER_INPUT,
179                        format!("Query has over {} nodes", limits.max_query_nodes),
180                        selection.pos,
181                    ));
182                } else {
183                    self.input_budget -= 1;
184                }
185
186                match &selection.node {
187                    Selection::Field(f) => {
188                        next_level.extend(&f.node.selection_set.node.items);
189                    }
190
191                    Selection::InlineFragment(f) => {
192                        next_level.extend(&f.node.selection_set.node.items);
193                    }
194
195                    Selection::FragmentSpread(fs) => {
196                        let name = &fs.node.fragment_name.node;
197                        let def = self
198                            .fragments
199                            .get(name)
200                            .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
201
202                        next_level.extend(&def.node.selection_set.node.items);
203                    }
204                }
205            }
206        }
207
208        self.depth_seen = self.depth_seen.max(limits.max_query_depth - depth_budget);
209        Ok(())
210    }
211
212    /// Test that inputs to `executeTransactionBlock` and `dryRunTransactionBlock` take up less
213    /// space than the service's transaction payload limit, cumulatively.
214    ///
215    /// This check must be done after the input limit check, because it relies on the query depth
216    /// being bounded to protect it from recursing too deeply.
217    fn check_tx_payload(&mut self, op: &'a Positioned<OperationDefinition>) -> ServerResult<()> {
218        for item in &op.node.selection_set.node.items {
219            self.traverse_selection_for_tx_payload(item)?;
220        }
221        Ok(())
222    }
223
224    /// Look for `executeTransactionBlock` and `dryRunTransactionBlock` nodes among the
225    /// query selections, and check their argument sizes are under the service limits.
226    fn traverse_selection_for_tx_payload(
227        &mut self,
228        item: &'a Positioned<Selection>,
229    ) -> ServerResult<()> {
230        match &item.node {
231            Selection::Field(f) => {
232                let name = &f.node.name.node;
233
234                if name == DRY_RUN_TX_BLOCK || name == EXECUTE_TX_BLOCK {
235                    for (_name, value) in &f.node.arguments {
236                        self.check_tx_arg(value)?;
237                    }
238                } else if name == VERIFY_ZKLOGIN {
239                    if let Some(value) = f.node.get_argument("bytes") {
240                        self.check_tx_arg(value)?;
241                    }
242
243                    if let Some(value) = f.node.get_argument("signature") {
244                        self.check_tx_arg(value)?;
245                    }
246                }
247            }
248
249            Selection::InlineFragment(f) => {
250                for selection in &f.node.selection_set.node.items {
251                    self.traverse_selection_for_tx_payload(selection)?;
252                }
253            }
254
255            Selection::FragmentSpread(fs) => {
256                let name = &fs.node.fragment_name.node;
257                let def = self
258                    .fragments
259                    .get(name)
260                    .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
261
262                for selection in &def.node.selection_set.node.items {
263                    self.traverse_selection_for_tx_payload(selection)?;
264                }
265            }
266        }
267        Ok(())
268    }
269
270    /// Deduct the size of the transaction argument's `value` from the transaction payload budget.
271    /// This operation resolves variables and deducts their size from the budget as well, as long
272    /// as they have not already been encountered in some previous transaction payload.
273    ///
274    /// Fails if there is insufficient remaining budget.
275    fn check_tx_arg(&mut self, value: &'a Positioned<Value>) -> ServerResult<()> {
276        use GqlValue as V;
277
278        let mut stack = vec![&value.node];
279        while let Some(value) = stack.pop() {
280            match value {
281                V::Variable(name) => self.check_tx_var(name)?,
282
283                V::String(s) => {
284                    // Pay for the string, plus the quotes around it.
285                    let debit = s.len() + 2;
286                    if debit > self.tx_payload_budget as usize {
287                        return Err(self.tx_payload_size_error());
288                    } else {
289                        // SAFETY: We know that debit <= self.tx_payload_budget, which is a u32, so
290                        // the cast and subtraction are both safe.
291                        self.tx_payload_budget -= debit as u32;
292                    }
293                }
294
295                V::List(vs) => {
296                    // Pay for the opening and closing brackets and every comma up-front so that
297                    // deeply nested lists are not free.
298                    let debit = vs.len().saturating_sub(1) + 2;
299                    if debit > self.tx_payload_budget as usize {
300                        return Err(self.tx_payload_size_error());
301                    } else {
302                        // SAFETY: We know that debit <= self.tx_payload_budget, which is a u32, so
303                        // the cast and subtraction are both safe.
304                        self.tx_payload_budget -= debit as u32;
305                        stack.extend(vs)
306                    }
307                }
308
309                V::Null
310                | V::Number(_)
311                | V::Boolean(_)
312                | V::Binary(_)
313                | V::Enum(_)
314                | V::Object(_) => {
315                    // Transaction payloads cannot be any of these types, so this request is
316                    // destined to fail. Ignore these values for now, so that it can fail later on
317                    // with a more legible error message.
318                    //
319                    // From a limits perspective, it is safe to ignore these values here, because
320                    // they will still be counted as part of the query payload (and so are still
321                    // subject to a limit).
322                }
323            }
324        }
325
326        Ok(())
327    }
328
329    /// Deduct the size of the value that variable `name` resolve to from the transaction payload
330    /// budget, if it has not already been encountered in a previous transaction payload.
331    ///
332    /// Fails if there is insufficient remaining budget.
333    fn check_tx_var(&mut self, name: &'a Name) -> ServerResult<()> {
334        use ConstValue as CV;
335
336        // Already used in a transaction, don't double count.
337        if !self.tx_variables_used.insert(name) {
338            return Ok(());
339        }
340
341        // Can't find the variable, so it can't count towards the transaction payload.
342        let Some(value) = self.variables.get(name) else {
343            return Ok(());
344        };
345
346        let mut stack = vec![value];
347        while let Some(value) = stack.pop() {
348            match &value {
349                CV::String(s) => {
350                    // Pay for the string, plus the quotes around it.
351                    let debit = s.len() + 2;
352                    if debit > self.tx_payload_budget as usize {
353                        return Err(self.tx_payload_size_error());
354                    } else {
355                        // SAFETY: We know that debit <= self.tx_payload_budget, which is a u32, so
356                        // the cast and subtraction are both safe.
357                        self.tx_payload_budget -= debit as u32;
358                    }
359                }
360
361                CV::List(vs) => {
362                    // Pay for the opening and closing brackets and every comma up-front so that
363                    // deeply nested lists are not free.
364                    let debit = vs.len().saturating_sub(1) + 2;
365                    if debit > self.tx_payload_budget as usize {
366                        return Err(self.tx_payload_size_error());
367                    } else {
368                        // SAFETY: We know that debit <= self.tx_payload_budget, which is a u32, so
369                        // the cast and subtraction are both safe.
370                        self.tx_payload_budget -= debit as u32;
371                        stack.extend(vs)
372                    }
373                }
374
375                CV::Null
376                | CV::Number(_)
377                | CV::Boolean(_)
378                | CV::Binary(_)
379                | CV::Enum(_)
380                | CV::Object(_) => {
381                    // As in `check_tx_arg`, these are safe to ignore.
382                }
383            }
384        }
385
386        Ok(())
387    }
388
389    /// Check that the operation's output node estimate will not exceed the service's limit.
390    ///
391    /// This check must be done after the input limit check, because it relies on the query depth
392    /// being bounded to protect it from recursing too deeply.
393    fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
394        for selection in &op.node.selection_set.node.items {
395            self.traverse_selection_for_output(selection, 1, None)?;
396        }
397        Ok(())
398    }
399
400    /// Account for the estimated output size of this selection and its children.
401    ///
402    /// `multiplicity` is the number of times this selection will be output, on account of being
403    /// nested within paginated ancestors.
404    ///
405    /// If this field is inside a connection, but not inside one of its fields, `page_size` is the
406    /// size of the connection's page.
407    fn traverse_selection_for_output(
408        &mut self,
409        selection: &Positioned<Selection>,
410        multiplicity: u32,
411        page_size: Option<u32>,
412    ) -> ServerResult<()> {
413        match &selection.node {
414            Selection::Field(f) => {
415                if multiplicity > self.output_budget {
416                    return Err(self.output_node_error());
417                } else {
418                    self.output_budget -= multiplicity;
419                }
420
421                let name = &f.node.name.node;
422
423                // Handle regular connection fields and multiGet queries
424                let multiplicity = 'm: {
425                    // check if it is a multiGet query and return the number of keys
426                    if let Some(page_size) = self.multi_get_page_size(f)? {
427                        break 'm multiplicity * page_size;
428                    }
429
430                    if !CONNECTION_FIELDS.contains(&name.as_str()) {
431                        break 'm multiplicity;
432                    }
433
434                    let Some(page_size) = page_size else {
435                        break 'm multiplicity;
436                    };
437                    multiplicity
438                        .checked_mul(page_size)
439                        .ok_or_else(|| self.output_node_error())?
440                };
441
442                let page_size = self.connection_page_size(f)?;
443                for selection in &f.node.selection_set.node.items {
444                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
445                }
446            }
447
448            // Just recurse through fragments, because they are inlined into their "call site".
449            Selection::InlineFragment(f) => {
450                for selection in &f.node.selection_set.node.items {
451                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
452                }
453            }
454
455            Selection::FragmentSpread(fs) => {
456                let name = &fs.node.fragment_name.node;
457                let def = self
458                    .fragments
459                    .get(name)
460                    .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
461
462                for selection in &def.node.selection_set.node.items {
463                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
464                }
465            }
466        }
467
468        Ok(())
469    }
470
471    /// If the field `f` is a connection, extract its page size, otherwise return `None`.
472    /// Returns an error if the page size cannot be represented as a `u32`.
473    fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
474        if !self.is_connection(f) {
475            return Ok(None);
476        }
477
478        let first = f.node.get_argument("first");
479        let last = f.node.get_argument("last");
480
481        let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
482            (Some(f), Some(l)) => f.max(l),
483            (Some(p), _) | (_, Some(p)) => p,
484            (None, None) => self.reporter.limits.default_page_size as u64,
485        };
486
487        Ok(Some(
488            page_size.try_into().map_err(|_| self.output_node_error())?,
489        ))
490    }
491
492    // If the field `f` is a multiGet query, extract the number of keys, otherwise return `None`.
493    // Returns an error if the number of keys cannot be represented as a `u32`.
494    fn multi_get_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
495        if !f.node.name.node.starts_with(MULTI_GET_PREFIX) {
496            return Ok(None);
497        }
498
499        let keys = f.node.get_argument(MULTI_GET_KEYS);
500        let Some(page_size) = self.resolve_list_size(keys) else {
501            return Ok(None);
502        };
503
504        Ok(Some(
505            page_size.try_into().map_err(|_| self.output_node_error())?,
506        ))
507    }
508
509    /// Checks if the given field corresponds to a connection based on whether it contains a
510    /// selection for `edges` or `nodes`. That selection could be immediately in that field's
511    /// selection set, or nested within a fragment or inline fragment spread.
512    fn is_connection(&self, f: &Positioned<Field>) -> bool {
513        f.node
514            .selection_set
515            .node
516            .items
517            .iter()
518            .any(|s| self.has_connection_fields(s))
519    }
520
521    /// Look for fields that suggest the container for this selection is a connection. Recurses
522    /// through fragment and inline fragment applications, but does not look recursively through
523    /// fields, as only the fields requested from the immediate parent are relevant.
524    fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
525        match &s.node {
526            Selection::Field(f) => {
527                let name = &f.node.name.node;
528                CONNECTION_FIELDS.contains(&name.as_str())
529            }
530
531            Selection::InlineFragment(f) => f
532                .node
533                .selection_set
534                .node
535                .items
536                .iter()
537                .any(|s| self.has_connection_fields(s)),
538
539            Selection::FragmentSpread(fs) => {
540                let name = &fs.node.fragment_name.node;
541                let Some(def) = self.fragments.get(name) else {
542                    return false;
543                };
544
545                def.node
546                    .selection_set
547                    .node
548                    .items
549                    .iter()
550                    .any(|s| self.has_connection_fields(s))
551            }
552        }
553    }
554
555    /// Translate a GraphQL value into a u64, if possible, resolving variables if necessary.
556    fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
557        match &value?.node {
558            Value::Number(num) => num,
559
560            Value::Variable(var) => {
561                if let ConstValue::Number(num) = self.variables.get(var)? {
562                    num
563                } else {
564                    return None;
565                }
566            }
567
568            _ => return None,
569        }
570        .as_u64()
571    }
572
573    /// Find the size of a list, resolving variables if necessary.
574    fn resolve_list_size(&self, value: Option<&Positioned<Value>>) -> Option<usize> {
575        match &value?.node {
576            Value::List(list) => Some(list.len()),
577            Value::Variable(var) => {
578                if let ConstValue::List(list) = self.variables.get(var)? {
579                    Some(list.len())
580                } else {
581                    None
582                }
583            }
584            _ => None,
585        }
586    }
587
588    /// Error returned if transaction payloads exceed limit. Also sets the transaction payload
589    /// budget to zero to indicate it has been spent (This is done to prevent future checks for
590    /// smaller arguments from succeeding even though a previous larger argument has already
591    /// failed).
592    fn tx_payload_size_error(&mut self) -> ServerError {
593        self.tx_payload_budget = 0;
594        self.reporter
595            .payload_size_error("Transaction payload too large")
596    }
597
598    /// Error returned if output node estimate exceeds limit. Also sets the output budget to zero,
599    /// to indicate that it has been spent (This is done because unlike other budgets, the output
600    /// budget is not decremented one unit at a time, so we can have hit the limit previously but
601    /// still have budget left over).
602    fn output_node_error(&mut self) -> ServerError {
603        self.output_budget = 0;
604        self.reporter.output_node_error()
605    }
606
607    /// Finish the traversal and report its usage.
608    fn finish(self, query_payload: u32) -> Usage {
609        let limits = self.reporter.limits;
610        Usage {
611            input_nodes: limits.max_query_nodes - self.input_budget,
612            output_nodes: limits.max_output_nodes - self.output_budget,
613            depth: self.depth_seen,
614            variables: self.variables.len() as u32,
615            fragments: self.fragments.len() as u32,
616            query_payload,
617        }
618    }
619}
620
621impl<'a> Reporter<'a> {
622    fn new(ctx: &'a ExtensionContext<'a>) -> Self {
623        let cfg: &ServiceConfig = ctx.data_unchecked();
624        Self {
625            limits: &cfg.limits,
626            query_id: ctx.data_unchecked(),
627            session_id: ctx.data_unchecked(),
628        }
629    }
630
631    /// Error returned if a fragment is referred to but not found in the document.
632    fn fragment_not_found_error(&self, name: &Name, pos: Pos) -> ServerError {
633        self.graphql_error_at_pos(
634            code::BAD_USER_INPUT,
635            format!("Fragment {name} referred to but not found in document"),
636            pos,
637        )
638    }
639
640    /// Error returned if output node estimate exceeds limit.
641    fn output_node_error(&self) -> ServerError {
642        self.graphql_error(
643            code::BAD_USER_INPUT,
644            format!(
645                "Estimated output nodes exceeds {}",
646                self.limits.max_output_nodes
647            ),
648        )
649    }
650
651    /// Error returned if the payload size exceeds the limit.
652    fn payload_size_error(&self, message: &str) -> ServerError {
653        self.graphql_error(
654            code::BAD_USER_INPUT,
655            format!(
656                "{message}. Requests are limited to {max_tx_payload} bytes or fewer on transaction \
657                 payloads (all inputs to executeTransactionBlock, dryRunTransactionBlock, or \
658                 verifyZkloginSignature) and the rest of the request (the query part) must be \
659                 {max_query_payload} bytes or fewer.",
660                max_tx_payload = self.limits.max_tx_payload_size,
661                max_query_payload = self.limits.max_query_payload_size,
662            ),
663        )
664    }
665
666    /// Build a GraphQL Server Error and also log it.
667    fn graphql_error(&self, code: &str, message: String) -> ServerError {
668        self.log_error(code, &message);
669        graphql_error(code, message)
670    }
671
672    /// Like `graphql_error` but for an error at a specific position in the query.
673    fn graphql_error_at_pos(&self, code: &str, message: String, pos: Pos) -> ServerError {
674        self.log_error(code, &message);
675        graphql_error_at_pos(code, message, pos)
676    }
677
678    /// Log an error (used before returning an error response.
679    fn log_error(&self, error_code: &str, message: &str) {
680        if error_code == code::INTERNAL_SERVER_ERROR {
681            error!(
682                query_id = %self.query_id,
683                session_id = %self.session_id,
684                error_code,
685                "Internal error while checking limits: {message}",
686            );
687        } else {
688            info!(
689                query_id = %self.query_id,
690                session_id = %self.session_id,
691                error_code,
692                "Limits error: {message}",
693            );
694        }
695    }
696}
697
698impl Usage {
699    fn report(&self, metrics: &Metrics) {
700        metrics
701            .request_metrics
702            .input_nodes
703            .observe(self.input_nodes as f64);
704        metrics
705            .request_metrics
706            .output_nodes
707            .observe(self.output_nodes as f64);
708        metrics
709            .request_metrics
710            .query_depth
711            .observe(self.depth as f64);
712        metrics
713            .request_metrics
714            .query_payload_size
715            .observe(self.query_payload as f64);
716    }
717}
718
719impl ExtensionFactory for QueryLimitsChecker {
720    fn create(&self) -> Arc<dyn Extension> {
721        Arc::new(QueryLimitsCheckerExt {
722            usage: Mutex::new(None),
723        })
724    }
725}
726
727#[async_trait]
728impl Extension for QueryLimitsCheckerExt {
729    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
730        let resp = next.run(ctx).await;
731        let usage = self.usage.lock().unwrap().take();
732        if let Some(usage) = usage {
733            resp.extension("usage", value!(usage))
734        } else {
735            resp
736        }
737    }
738
739    /// Validates the query against the limits set in the service config
740    /// If the limits are hit, the operation terminates early
741    async fn parse_query(
742        &self,
743        ctx: &ExtensionContext<'_>,
744        query: &str,
745        variables: &Variables,
746        next: NextParseQuery<'_>,
747    ) -> ServerResult<ExecutableDocument> {
748        let metrics: &Metrics = ctx.data_unchecked();
749        let payload_size: &PayloadSize = ctx.data_unchecked();
750        let reporter = Reporter::new(ctx);
751
752        let instant = Instant::now();
753
754        // Make sure the request meets a basic size limit before trying to parse it.
755        let max_payload_size = reporter.limits.max_query_payload_size as u64
756            + reporter.limits.max_tx_payload_size as u64;
757
758        if payload_size.0 > max_payload_size {
759            let message = format!("Overall request too large: {} bytes", payload_size.0);
760            return Err(reporter.payload_size_error(&message));
761        }
762
763        // Document layout of the query
764        let doc = next.run(ctx, query, variables).await?;
765
766        // If the query is pure introspection, we don't need to check the limits. Pure introspection
767        // queries are queries that only have one operation with one field and that field is a
768        // `__schema` query
769        if let DocumentOperations::Single(op) = &doc.operations {
770            if let [field] = &op.node.selection_set.node.items[..] {
771                if let Selection::Field(f) = &field.node {
772                    if f.node.name.node == "__schema" {
773                        return Ok(doc);
774                    }
775                }
776            }
777        }
778
779        let mut traversal =
780            LimitsTraversal::new(*payload_size, &reporter, &doc.fragments, variables);
781
782        let res = traversal.check_document(&doc);
783        let usage = traversal.finish(query.len() as u32);
784        metrics.query_validation_latency(instant.elapsed());
785        usage.report(metrics);
786
787        res.map(|()| {
788            if ctx.data_opt::<ShowUsage>().is_some() {
789                *self.usage.lock().unwrap() = Some(usage);
790            }
791
792            doc
793        })
794    }
795}