1use super::QueryExecutor;
5use crate::{config::Limits, error::Error, metrics::Metrics};
6use async_trait::async_trait;
7use diesel::{
8 pg::Pg,
9 query_builder::{Query, QueryFragment, QueryId},
10 QueryResult,
11};
12use diesel_async::{methods::LoadQuery, scoped_futures::ScopedBoxFuture};
13use diesel_async::{scoped_futures::ScopedFutureExt, RunQueryDsl};
14use std::fmt;
15use std::time::Instant;
16use sui_indexer::indexer_reader::IndexerReader;
17
18use tracing::error;
19
20#[derive(Clone)]
21pub(crate) struct PgExecutor {
22 pub inner: IndexerReader,
23 pub limits: Limits,
24 pub metrics: Metrics,
25}
26
27pub(crate) struct PgConnection<'c> {
28 max_cost: u32,
29 conn: &'c mut diesel_async::AsyncPgConnection,
30}
31
/// Borrowed bytes that are rendered (via its `Display` impl) as a Postgres `bytea` literal.
pub(crate) struct ByteaLiteral<'a>(pub &'a [u8]);
33
34impl PgExecutor {
35 pub(crate) fn new(inner: IndexerReader, limits: Limits, metrics: Metrics) -> Self {
36 Self {
37 inner,
38 limits,
39 metrics,
40 }
41 }
42}
43
44#[async_trait]
45impl QueryExecutor for PgExecutor {
46 type Connection = diesel_async::AsyncPgConnection;
47 type Backend = Pg;
48 type DbConnection<'c> = PgConnection<'c>;
49
50 async fn execute<'c, T, U, E>(&self, txn: T) -> Result<U, Error>
51 where
52 T: for<'r> FnOnce(
53 &'r mut Self::DbConnection<'_>,
54 ) -> ScopedBoxFuture<'static, 'r, Result<U, E>>
55 + Send
56 + 'c,
57 E: From<diesel::result::Error> + std::error::Error,
58 T: Send + 'static,
59 U: Send + 'static,
60 E: Send + 'static,
61 {
62 let max_cost = self.limits.max_db_query_cost;
63 let instant = Instant::now();
64 let mut connection = self
65 .inner
66 .pool()
67 .get()
68 .await
69 .map_err(|e| Error::Internal(e.to_string()))?;
70
71 let result = connection
72 .build_transaction()
73 .read_only()
74 .run(|conn| {
75 async move {
76 let mut connection = PgConnection { max_cost, conn };
77 txn(&mut connection).await
78 }
79 .scope_boxed()
80 })
81 .await;
82
83 self.metrics
84 .observe_db_data(instant.elapsed(), result.is_ok());
85 if let Err(e) = &result {
86 error!("DB query error: {e:?}");
87 }
88 result.map_err(|e| Error::Internal(e.to_string()))
89 }
90
91 async fn execute_repeatable<'c, T, U, E>(&self, txn: T) -> Result<U, Error>
92 where
93 T: for<'r> FnOnce(
94 &'r mut Self::DbConnection<'_>,
95 ) -> ScopedBoxFuture<'static, 'r, Result<U, E>>
96 + Send
97 + 'c,
98 E: From<diesel::result::Error> + std::error::Error,
99 T: Send + 'static,
100 U: Send + 'static,
101 E: Send + 'static,
102 {
103 let max_cost = self.limits.max_db_query_cost;
104 let instant = Instant::now();
105
106 let mut connection = self
107 .inner
108 .pool()
109 .get()
110 .await
111 .map_err(|e| Error::Internal(e.to_string()))?;
112
113 let result = connection
114 .build_transaction()
115 .read_only()
116 .repeatable_read()
117 .run(|conn| {
118 async move {
119 txn(&mut PgConnection { max_cost, conn }).await
121 }
122 .scope_boxed()
123 })
124 .await;
125
126 self.metrics
127 .observe_db_data(instant.elapsed(), result.is_ok());
128 if let Err(e) = &result {
129 error!("DB query error: {e:?}");
130 }
131 result.map_err(|e| Error::Internal(e.to_string()))
132 }
133}
134
135#[async_trait]
136impl super::DbConnection for PgConnection<'_> {
137 type Connection = diesel_async::AsyncPgConnection;
138 type Backend = Pg;
139
140 async fn result<T, Q, U>(&mut self, query: T) -> QueryResult<U>
141 where
142 T: Fn() -> Q + Send,
143 Q: diesel::query_builder::Query + Send + 'static,
144 Q: LoadQuery<'static, Self::Connection, U>,
145 Q: QueryId + QueryFragment<Self::Backend>,
146 U: Send,
147 {
148 query_cost::log(self.conn, self.max_cost, query()).await;
149 query().get_result(self.conn).await
150 }
151
152 async fn results<T, Q, U>(&mut self, query: T) -> QueryResult<Vec<U>>
153 where
154 T: Fn() -> Q + Send,
155 Q: diesel::query_builder::Query + Send + 'static,
156 Q: LoadQuery<'static, Self::Connection, U>,
157 Q: QueryId + QueryFragment<Self::Backend>,
158 U: Send,
159 {
160 query_cost::log(self.conn, self.max_cost, query()).await;
161 query().get_results(self.conn).await
162 }
163}
164
165impl fmt::Display for ByteaLiteral<'_> {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 write!(f, "'\\x{}'::bytea", hex::encode(self.0))
168 }
169}
170
171pub(crate) fn bytea_literal(slice: &[u8]) -> ByteaLiteral<'_> {
172 ByteaLiteral(slice)
173}
174
175mod query_cost {
177 use super::*;
178
179 use diesel::{query_builder::AstPass, sql_types::Text, QueryResult};
180 use diesel_async::AsyncPgConnection;
181 use serde_json::Value;
182 use tap::{TapFallible, TapOptional};
183 use tracing::{debug, info, warn};
184
185 #[derive(Debug, Clone, Copy, QueryId)]
186 struct Explained<Q> {
187 query: Q,
188 }
189
190 impl<Q: Query> Query for Explained<Q> {
191 type SqlType = Text;
192 }
193
194 impl<Q: QueryFragment<Pg>> QueryFragment<Pg> for Explained<Q> {
195 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
196 out.push_sql("EXPLAIN (FORMAT JSON) ");
197 self.query.walk_ast(out.reborrow())?;
198 Ok(())
199 }
200 }
201
202 pub(crate) async fn log<Q>(conn: &mut AsyncPgConnection, max_db_query_cost: u32, query: Q)
204 where
205 Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<AsyncPgConnection> + Send,
206 {
207 debug!("Estimating: {}", diesel::debug_query(&query).to_string());
208
209 let Some(cost) = explain(conn, query).await else {
210 warn!("Failed to extract cost from EXPLAIN.");
211 return;
212 };
213
214 if cost > max_db_query_cost as f64 {
215 warn!(cost, max_db_query_cost, exceeds = true, "Estimated cost");
216 } else {
217 info!(cost, max_db_query_cost, exceeds = false, "Estimated cost");
218 }
219 }
220
221 pub(crate) async fn explain<Q>(conn: &mut AsyncPgConnection, query: Q) -> Option<f64>
222 where
223 Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<AsyncPgConnection> + Send,
224 {
225 let result: String = Explained { query }
226 .get_result(conn)
227 .await
228 .tap_err(|e| warn!("Failed to run EXPLAIN: {e}"))
229 .ok()?;
230
231 let parsed = serde_json::from_str(&result)
232 .tap_err(|e| warn!("Failed to parse EXPLAIN result: {e}"))
233 .ok()?;
234
235 extract_cost(&parsed).tap_none(|| warn!("Failed to extract cost from EXPLAIN"))
236 }
237
238 fn extract_cost(parsed: &Value) -> Option<f64> {
239 parsed.get(0)?.get("Plan")?.get("Total Cost")?.as_f64()
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use diesel::QueryDsl;
    use sui_framework::BuiltInFramework;
    use sui_indexer::{
        database::Connection, db::reset_database, models::objects::StoredObject, schema::objects,
        types::IndexedObject,
    };
    use sui_pg_db::temp::TempDb;

    /// Seeds a fresh database with the system packages, then checks that the planner estimates
    /// a `LIMIT 1` scan of `objects` as strictly cheaper than the same scan without a limit.
    #[tokio::test]
    async fn test_query_cost() {
        use objects::dsl;

        let db = TempDb::new().unwrap();

        // Start from a clean schema, using a connection dedicated to the reset.
        let setup_conn = Connection::dedicated(db.database().url()).await.unwrap();
        reset_database(setup_conn).await.unwrap();

        let mut conn = Connection::dedicated(db.database().url()).await.unwrap();

        let rows: Vec<StoredObject> = BuiltInFramework::iter_system_packages()
            .map(|pkg| IndexedObject::from_object(1, pkg.genesis_object(), None).into())
            .collect();
        let expected = rows.len();

        let inserted = diesel::insert_into(dsl::objects)
            .values(rows)
            .execute(&mut conn)
            .await
            .unwrap();
        assert_eq!(expected, inserted, "Failed to write objects");

        let limited = dsl::objects.select(dsl::objects.star()).limit(1);
        let unbounded = dsl::objects.select(dsl::objects.star());

        let cost_one = query_cost::explain(&mut conn, limited).await.unwrap();
        let cost_all = query_cost::explain(&mut conn, unbounded).await.unwrap();

        assert!(
            cost_one < cost_all,
            "cost_one = {cost_one} >= {cost_all} = cost_all"
        );
    }
}