sui_graphql_rpc/data/
pg.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::QueryExecutor;
5use crate::{config::Limits, error::Error, metrics::Metrics};
6use async_trait::async_trait;
7use diesel::{
8    pg::Pg,
9    query_builder::{Query, QueryFragment, QueryId},
10    QueryResult,
11};
12use diesel_async::{methods::LoadQuery, scoped_futures::ScopedBoxFuture};
13use diesel_async::{scoped_futures::ScopedFutureExt, RunQueryDsl};
14use std::fmt;
15use std::time::Instant;
16use sui_indexer::indexer_reader::IndexerReader;
17
18use tracing::error;
19
/// [`QueryExecutor`] that runs read-only transactions against Postgres, using connections
/// checked out from an [`IndexerReader`]'s pool.
#[derive(Clone)]
pub(crate) struct PgExecutor {
    /// Supplies the connection pool that transactions are run on.
    pub inner: IndexerReader,
    /// Service limits; `max_db_query_cost` is the threshold that `EXPLAIN` cost estimates are
    /// logged against.
    pub limits: Limits,
    /// Receives the elapsed time and success/failure of each transaction.
    pub metrics: Metrics,
}
26
27pub(crate) struct PgConnection<'c> {
28    max_cost: u32,
29    conn: &'c mut diesel_async::AsyncPgConnection,
30}
31
/// Borrowed bytes whose [`fmt::Display`] output is a Postgres hex-format `bytea` literal.
pub(crate) struct ByteaLiteral<'bytes>(pub &'bytes [u8]);
33
34impl PgExecutor {
35    pub(crate) fn new(inner: IndexerReader, limits: Limits, metrics: Metrics) -> Self {
36        Self {
37            inner,
38            limits,
39            metrics,
40        }
41    }
42}
43
44#[async_trait]
45impl QueryExecutor for PgExecutor {
46    type Connection = diesel_async::AsyncPgConnection;
47    type Backend = Pg;
48    type DbConnection<'c> = PgConnection<'c>;
49
50    async fn execute<'c, T, U, E>(&self, txn: T) -> Result<U, Error>
51    where
52        T: for<'r> FnOnce(
53                &'r mut Self::DbConnection<'_>,
54            ) -> ScopedBoxFuture<'static, 'r, Result<U, E>>
55            + Send
56            + 'c,
57        E: From<diesel::result::Error> + std::error::Error,
58        T: Send + 'static,
59        U: Send + 'static,
60        E: Send + 'static,
61    {
62        let max_cost = self.limits.max_db_query_cost;
63        let instant = Instant::now();
64        let mut connection = self
65            .inner
66            .pool()
67            .get()
68            .await
69            .map_err(|e| Error::Internal(e.to_string()))?;
70
71        let result = connection
72            .build_transaction()
73            .read_only()
74            .run(|conn| {
75                async move {
76                    let mut connection = PgConnection { max_cost, conn };
77                    txn(&mut connection).await
78                }
79                .scope_boxed()
80            })
81            .await;
82
83        self.metrics
84            .observe_db_data(instant.elapsed(), result.is_ok());
85        if let Err(e) = &result {
86            error!("DB query error: {e:?}");
87        }
88        result.map_err(|e| Error::Internal(e.to_string()))
89    }
90
91    async fn execute_repeatable<'c, T, U, E>(&self, txn: T) -> Result<U, Error>
92    where
93        T: for<'r> FnOnce(
94                &'r mut Self::DbConnection<'_>,
95            ) -> ScopedBoxFuture<'static, 'r, Result<U, E>>
96            + Send
97            + 'c,
98        E: From<diesel::result::Error> + std::error::Error,
99        T: Send + 'static,
100        U: Send + 'static,
101        E: Send + 'static,
102    {
103        let max_cost = self.limits.max_db_query_cost;
104        let instant = Instant::now();
105
106        let mut connection = self
107            .inner
108            .pool()
109            .get()
110            .await
111            .map_err(|e| Error::Internal(e.to_string()))?;
112
113        let result = connection
114            .build_transaction()
115            .read_only()
116            .repeatable_read()
117            .run(|conn| {
118                async move {
119                    //
120                    txn(&mut PgConnection { max_cost, conn }).await
121                }
122                .scope_boxed()
123            })
124            .await;
125
126        self.metrics
127            .observe_db_data(instant.elapsed(), result.is_ok());
128        if let Err(e) = &result {
129            error!("DB query error: {e:?}");
130        }
131        result.map_err(|e| Error::Internal(e.to_string()))
132    }
133}
134
135#[async_trait]
136impl super::DbConnection for PgConnection<'_> {
137    type Connection = diesel_async::AsyncPgConnection;
138    type Backend = Pg;
139
140    async fn result<T, Q, U>(&mut self, query: T) -> QueryResult<U>
141    where
142        T: Fn() -> Q + Send,
143        Q: diesel::query_builder::Query + Send + 'static,
144        Q: LoadQuery<'static, Self::Connection, U>,
145        Q: QueryId + QueryFragment<Self::Backend>,
146        U: Send,
147    {
148        query_cost::log(self.conn, self.max_cost, query()).await;
149        query().get_result(self.conn).await
150    }
151
152    async fn results<T, Q, U>(&mut self, query: T) -> QueryResult<Vec<U>>
153    where
154        T: Fn() -> Q + Send,
155        Q: diesel::query_builder::Query + Send + 'static,
156        Q: LoadQuery<'static, Self::Connection, U>,
157        Q: QueryId + QueryFragment<Self::Backend>,
158        U: Send,
159    {
160        query_cost::log(self.conn, self.max_cost, query()).await;
161        query().get_results(self.conn).await
162    }
163}
164
165impl fmt::Display for ByteaLiteral<'_> {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        write!(f, "'\\x{}'::bytea", hex::encode(self.0))
168    }
169}
170
171pub(crate) fn bytea_literal(slice: &[u8]) -> ByteaLiteral<'_> {
172    ByteaLiteral(slice)
173}
174
175/// Support for calculating estimated query cost using EXPLAIN and then logging it.
176mod query_cost {
177    use super::*;
178
179    use diesel::{query_builder::AstPass, sql_types::Text, QueryResult};
180    use diesel_async::AsyncPgConnection;
181    use serde_json::Value;
182    use tap::{TapFallible, TapOptional};
183    use tracing::{debug, info, warn};
184
185    #[derive(Debug, Clone, Copy, QueryId)]
186    struct Explained<Q> {
187        query: Q,
188    }
189
190    impl<Q: Query> Query for Explained<Q> {
191        type SqlType = Text;
192    }
193
194    impl<Q: QueryFragment<Pg>> QueryFragment<Pg> for Explained<Q> {
195        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
196            out.push_sql("EXPLAIN (FORMAT JSON) ");
197            self.query.walk_ast(out.reborrow())?;
198            Ok(())
199        }
200    }
201
202    /// Run `EXPLAIN` on the `query`, and log the estimated cost.
203    pub(crate) async fn log<Q>(conn: &mut AsyncPgConnection, max_db_query_cost: u32, query: Q)
204    where
205        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<AsyncPgConnection> + Send,
206    {
207        debug!("Estimating: {}", diesel::debug_query(&query).to_string());
208
209        let Some(cost) = explain(conn, query).await else {
210            warn!("Failed to extract cost from EXPLAIN.");
211            return;
212        };
213
214        if cost > max_db_query_cost as f64 {
215            warn!(cost, max_db_query_cost, exceeds = true, "Estimated cost");
216        } else {
217            info!(cost, max_db_query_cost, exceeds = false, "Estimated cost");
218        }
219    }
220
221    pub(crate) async fn explain<Q>(conn: &mut AsyncPgConnection, query: Q) -> Option<f64>
222    where
223        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<AsyncPgConnection> + Send,
224    {
225        let result: String = Explained { query }
226            .get_result(conn)
227            .await
228            .tap_err(|e| warn!("Failed to run EXPLAIN: {e}"))
229            .ok()?;
230
231        let parsed = serde_json::from_str(&result)
232            .tap_err(|e| warn!("Failed to parse EXPLAIN result: {e}"))
233            .ok()?;
234
235        extract_cost(&parsed).tap_none(|| warn!("Failed to extract cost from EXPLAIN"))
236    }
237
238    fn extract_cost(parsed: &Value) -> Option<f64> {
239        parsed.get(0)?.get("Plan")?.get("Total Cost")?.as_f64()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use diesel::QueryDsl;
    use sui_framework::BuiltInFramework;
    use sui_indexer::{
        database::Connection, db::reset_database, models::objects::StoredObject, schema::objects,
        types::IndexedObject,
    };
    use sui_pg_db::temp::TempDb;

    /// `EXPLAIN` should estimate a `LIMIT 1` read of `objects` as cheaper than reading the
    /// whole table.
    #[tokio::test]
    async fn test_query_cost() {
        use objects::dsl;

        // Throwaway database, reset to a clean schema before use.
        let database = TempDb::new().unwrap();
        let setup_conn = Connection::dedicated(database.database().url())
            .await
            .unwrap();
        reset_database(setup_conn).await.unwrap();

        let mut conn = Connection::dedicated(database.database().url())
            .await
            .unwrap();

        // Seed `objects` with the system packages so the planner has rows to estimate over.
        let rows: Vec<StoredObject> = BuiltInFramework::iter_system_packages()
            .map(|pkg| IndexedObject::from_object(1, pkg.genesis_object(), None).into())
            .collect();
        let expect = rows.len();
        let actual = diesel::insert_into(dsl::objects)
            .values(rows)
            .execute(&mut conn)
            .await
            .unwrap();
        assert_eq!(expect, actual, "Failed to write objects");

        // Same projection, with and without a row limit.
        let one_row = dsl::objects.select(dsl::objects.star()).limit(1);
        let every_row = dsl::objects.select(dsl::objects.star());

        let cost_one = query_cost::explain(&mut conn, one_row).await.unwrap();
        let cost_all = query_cost::explain(&mut conn, every_row).await.unwrap();

        assert!(
            cost_one < cost_all,
            "cost_one = {cost_one} >= {cost_all} = cost_all"
        );
    }
}