sui_pg_db/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::time::Duration;
7
8use anyhow::anyhow;
9use diesel::ConnectionError;
10use diesel::migration::{Migration, MigrationSource, MigrationVersion};
11use diesel::pg::Pg;
12use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
13use diesel_async::pooled_connection::ManagerConfig;
14use diesel_async::{
15    AsyncPgConnection, RunQueryDsl,
16    pooled_connection::{
17        AsyncDieselConnectionManager,
18        bb8::{Pool, PooledConnection},
19    },
20};
21use futures::FutureExt;
22use tracing::info;
23use url::Url;
24
25use tls::{build_tls_config, establish_tls_connection};
26
27mod model;
28mod tls;
29
30pub use sui_field_count::FieldCount;
31pub use sui_sql_macro::sql;
32
33pub mod query;
34pub mod schema;
35pub mod store;
36pub mod temp;
37
38use diesel_migrations::{EmbeddedMigrations, embed_migrations};
39
40pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
41
42#[derive(clap::Args, Debug, Clone)]
43pub struct DbArgs {
44    /// Number of connections to keep in the pool.
45    #[arg(long, default_value_t = Self::default().db_connection_pool_size)]
46    pub db_connection_pool_size: u32,
47
48    /// Time spent waiting for a connection from the pool to become available, in milliseconds.
49    #[arg(long, default_value_t = Self::default().db_connection_timeout_ms)]
50    pub db_connection_timeout_ms: u64,
51
52    #[arg(long)]
53    /// Time spent waiting for statements to complete, in milliseconds.
54    pub db_statement_timeout_ms: Option<u64>,
55
56    #[arg(long)]
57    /// Enable server certificate verification. By default, this is set to false to match the
58    /// default behavior of libpq.
59    pub tls_verify_cert: bool,
60
61    #[arg(long)]
62    /// Path to a custom CA certificate to use for server certificate verification.
63    pub tls_ca_cert_path: Option<PathBuf>,
64}
65
66#[derive(Clone)]
67pub struct Db(Pool<AsyncPgConnection>);
68
69/// Wrapper struct over the remote `PooledConnection` type for dealing with the `Store` trait.
70pub struct Connection<'a>(PooledConnection<'a, AsyncPgConnection>);
71
72impl DbArgs {
73    pub fn connection_timeout(&self) -> Duration {
74        Duration::from_millis(self.db_connection_timeout_ms)
75    }
76
77    pub fn statement_timeout(&self) -> Option<Duration> {
78        self.db_statement_timeout_ms.map(Duration::from_millis)
79    }
80}
81
82impl Db {
83    /// Construct a new DB connection pool talking to the database at `database_url` that supports
84    /// write and reads. Instances of [Db] can be cloned to share access to the same pool.
85    pub async fn for_write(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
86        Ok(Self(pool(database_url, config, false).await?))
87    }
88
89    /// Construct a new DB connection pool talking to the database at `database_url` that defaults
90    /// to read-only transactions. Instances of [Db] can be cloned to share access to the same
91    /// pool.
92    pub async fn for_read(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
93        Ok(Self(pool(database_url, config, true).await?))
94    }
95
96    /// Retrieves a connection from the pool. Can fail with a timeout if a connection cannot be
97    /// established before the [DbArgs::connection_timeout] has elapsed.
98    pub async fn connect(&self) -> anyhow::Result<Connection<'_>> {
99        Ok(Connection(self.0.get().await?))
100    }
101
102    /// Statistics about the connection pool
103    pub fn state(&self) -> bb8::State {
104        self.0.state()
105    }
106
107    async fn clear_database(&self) -> anyhow::Result<()> {
108        info!("Clearing the database...");
109        let mut conn = self.connect().await?;
110        let drop_all_tables = "
111        DO $$ DECLARE
112            r RECORD;
113        BEGIN
114        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
115            LOOP
116                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
117            END LOOP;
118        END $$;";
119        diesel::sql_query(drop_all_tables)
120            .execute(&mut conn)
121            .await?;
122        info!("Dropped all tables.");
123
124        let drop_all_procedures = "
125        DO $$ DECLARE
126            r RECORD;
127        BEGIN
128            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
129                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
130                      WHERE ns.nspname = 'public' AND prokind = 'p')
131            LOOP
132                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
133            END LOOP;
134        END $$;";
135        diesel::sql_query(drop_all_procedures)
136            .execute(&mut conn)
137            .await?;
138        info!("Dropped all procedures.");
139
140        let drop_all_functions = "
141        DO $$ DECLARE
142            r RECORD;
143        BEGIN
144            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
145                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
146                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
147            LOOP
148                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
149            END LOOP;
150        END $$;";
151        diesel::sql_query(drop_all_functions)
152            .execute(&mut conn)
153            .await?;
154        info!("Database cleared.");
155        Ok(())
156    }
157
158    /// Run migrations on the database. Use Diesel's `embed_migrations!` macro to generate the
159    /// `migrations` parameter for your indexer.
160    pub async fn run_migrations(
161        &self,
162        migrations: Option<&'static EmbeddedMigrations>,
163    ) -> anyhow::Result<Vec<MigrationVersion<'static>>> {
164        use diesel_migrations::MigrationHarness;
165
166        let merged_migrations = merge_migrations(migrations);
167
168        info!("Running migrations ...");
169        let conn = self.0.dedicated_connection().await?;
170        let mut wrapper: AsyncConnectionWrapper<AsyncPgConnection> =
171            diesel_async::async_connection_wrapper::AsyncConnectionWrapper::from(conn);
172
173        let finished_migrations = tokio::task::spawn_blocking(move || {
174            wrapper
175                .run_pending_migrations(merged_migrations)
176                .map(|versions| versions.iter().map(MigrationVersion::as_owned).collect())
177        })
178        .await?
179        .map_err(|e| anyhow!("Failed to run migrations: {:?}", e))?;
180
181        info!("Migrations complete.");
182        Ok(finished_migrations)
183    }
184}
185
186impl Default for DbArgs {
187    fn default() -> Self {
188        Self {
189            db_connection_pool_size: 100,
190            db_connection_timeout_ms: 60_000,
191            db_statement_timeout_ms: None,
192            tls_verify_cert: false,
193            tls_ca_cert_path: None,
194        }
195    }
196}
197
198/// Drop all tables, and re-run migrations if supplied.
199pub async fn reset_database(
200    database_url: Url,
201    db_config: DbArgs,
202    migrations: Option<&'static EmbeddedMigrations>,
203) -> anyhow::Result<()> {
204    let db = Db::for_write(database_url, db_config).await?;
205    db.clear_database().await?;
206    if let Some(migrations) = migrations {
207        db.run_migrations(Some(migrations)).await?;
208    }
209
210    Ok(())
211}
212
213impl<'a> Deref for Connection<'a> {
214    type Target = PooledConnection<'a, AsyncPgConnection>;
215
216    fn deref(&self) -> &Self::Target {
217        &self.0
218    }
219}
220
221impl DerefMut for Connection<'_> {
222    fn deref_mut(&mut self) -> &mut Self::Target {
223        &mut self.0
224    }
225}
226
227async fn pool(
228    database_url: Url,
229    args: DbArgs,
230    read_only: bool,
231) -> anyhow::Result<Pool<AsyncPgConnection>> {
232    let statement_timeout = args.statement_timeout();
233
234    // Build TLS configuration once
235    let tls_config = build_tls_config(args.tls_verify_cert, args.tls_ca_cert_path.clone())?;
236
237    let mut config = ManagerConfig::default();
238
239    config.custom_setup = Box::new(move |url| {
240        let tls_config = tls_config.clone();
241
242        async move {
243            let mut conn = establish_tls_connection(url, tls_config).await?;
244
245            if let Some(timeout) = statement_timeout {
246                diesel::sql_query(format!("SET statement_timeout = {}", timeout.as_millis()))
247                    .execute(&mut conn)
248                    .await
249                    .map_err(ConnectionError::CouldntSetupConfiguration)?;
250            }
251
252            if read_only {
253                diesel::sql_query("SET default_transaction_read_only = 'on'")
254                    .execute(&mut conn)
255                    .await
256                    .map_err(ConnectionError::CouldntSetupConfiguration)?;
257            }
258
259            Ok(conn)
260        }
261        .boxed()
262    });
263
264    let manager = AsyncDieselConnectionManager::new_with_config(database_url.as_str(), config);
265
266    Ok(Pool::builder()
267        .max_size(args.db_connection_pool_size)
268        .connection_timeout(args.connection_timeout())
269        .build(manager)
270        .await?)
271}
272
273/// Returns new migrations derived from the combination of provided migrations and migrations
274/// defined in this crate.
275pub fn merge_migrations(
276    migrations: Option<&'static EmbeddedMigrations>,
277) -> impl MigrationSource<Pg> + Send + Sync + 'static {
278    struct Migrations(Option<&'static EmbeddedMigrations>);
279    impl MigrationSource<Pg> for Migrations {
280        fn migrations(&self) -> diesel::migration::Result<Vec<Box<dyn Migration<Pg>>>> {
281            let mut migrations = MIGRATIONS.migrations()?;
282            if let Some(more_migrations) = self.0 {
283                migrations.extend(more_migrations.migrations()?);
284            }
285            Ok(migrations)
286        }
287    }
288
289    Migrations(migrations)
290}
#[cfg(test)]
mod tests {
    use super::*;
    use diesel::prelude::QueryableByName;
    use diesel_async::RunQueryDsl;

    /// Single-row result of the `... AS cnt` queries issued by the tests below.
    #[derive(Debug, QueryableByName)]
    struct CountResult {
        #[diesel(sql_type = diesel::sql_types::BigInt)]
        cnt: i64,
    }

    /// A temporary database accepts a pooled connection and answers a trivial query.
    #[tokio::test]
    async fn temp_db_smoketest() {
        telemetry_subscribers::init_for_testing();
        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();
        info!(%url);

        let writer = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = writer.connect().await.unwrap();

        // Any query would do; this one just proves the database can be queried.
        let list_databases = diesel::sql_query("SELECT datname FROM pg_database");
        let resp = list_databases.execute(&mut conn).await.unwrap();

        info!(?resp);
    }

    /// `reset_database` without migrations drops a table that existed beforehand.
    #[tokio::test]
    async fn test_reset_database_skip_migrations() {
        const COUNT_TEST_TABLES: &str =
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'";

        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();
        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();

        let mut conn_before = db.connect().await.unwrap();
        diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
            .execute(&mut conn_before)
            .await
            .unwrap();
        let tables_before: CountResult = diesel::sql_query(COUNT_TEST_TABLES)
            .get_result(&mut conn_before)
            .await
            .unwrap();
        assert_eq!(tables_before.cnt, 1);

        reset_database(url.clone(), DbArgs::default(), None)
            .await
            .unwrap();

        let mut conn_after = db.connect().await.unwrap();
        let tables_after = diesel::sql_query(COUNT_TEST_TABLES)
            .get_result::<CountResult>(&mut conn_after)
            .await
            .unwrap();
        assert_eq!(tables_after.cnt, 0);
    }

    /// Connections from `Db::for_read` reject writes but can observe data written through a
    /// `Db::for_write` pool.
    #[tokio::test]
    async fn test_read_only() {
        const INSERT_ROW: &str = "INSERT INTO test_table (id) VALUES (1)";
        const COUNT_ROWS: &str = "SELECT COUNT(*) as cnt FROM test_table";

        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();

        let writer = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let reader = Db::for_read(url.clone(), DbArgs::default()).await.unwrap();

        {
            // Set up the table through the writable pool.
            let mut rw = writer.connect().await.unwrap();
            diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
                .execute(&mut rw)
                .await
                .unwrap();
        }

        {
            // Inserting through the read-only pool must be rejected.
            let mut ro = reader.connect().await.unwrap();
            let insert_attempt = diesel::sql_query(INSERT_ROW).execute(&mut ro).await;
            assert!(insert_attempt.is_err());
        }

        {
            // Selecting through the read-only pool works; the table is still empty.
            let mut ro = reader.connect().await.unwrap();
            let rows = diesel::sql_query(COUNT_ROWS)
                .get_result::<CountResult>(&mut ro)
                .await
                .unwrap();
            assert_eq!(rows.cnt, 0);
        }

        {
            // The same insert succeeds through the writable pool.
            let mut rw = writer.connect().await.unwrap();
            diesel::sql_query(INSERT_ROW)
                .execute(&mut rw)
                .await
                .unwrap();
        }

        {
            // The read-only pool now sees the inserted row.
            let mut ro = reader.connect().await.unwrap();
            let rows = diesel::sql_query(COUNT_ROWS)
                .get_result::<CountResult>(&mut ro)
                .await
                .unwrap();
            assert_eq!(rows.cnt, 1);
        }
    }

    /// With a 200ms statement timeout, a quick query succeeds and a two second sleep fails.
    #[tokio::test]
    async fn test_statement_timeout() {
        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();

        let args = DbArgs {
            db_statement_timeout_ms: Some(200),
            ..DbArgs::default()
        };
        let reader = Db::for_read(url.clone(), args).await.unwrap();

        {
            // Completes well inside the timeout.
            let mut conn = reader.connect().await.unwrap();
            let quick = diesel::sql_query("SELECT 1::BIGINT AS cnt")
                .get_result::<CountResult>(&mut conn)
                .await
                .unwrap();

            assert_eq!(quick.cnt, 1);
        }

        {
            // Sleeps for longer than the timeout, so the statement must be cancelled.
            let mut conn = reader.connect().await.unwrap();
            let slow = diesel::sql_query("SELECT PG_SLEEP(2), 1::BIGINT AS cnt")
                .get_result::<CountResult>(&mut conn)
                .await;
            slow.expect_err("This request should fail because of a timeout");
        }
    }
}
448}