sui_pg_db/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::ops::Deref;
5use std::ops::DerefMut;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::anyhow;
11use diesel::ConnectionError;
12use diesel::migration::Migration;
13use diesel::migration::MigrationSource;
14use diesel::migration::MigrationVersion;
15use diesel::pg::Pg;
16use diesel_async::RunQueryDsl;
17use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
18use diesel_async::pooled_connection::AsyncDieselConnectionManager;
19use diesel_async::pooled_connection::ManagerConfig;
20use diesel_async::pooled_connection::bb8::Pool;
21use diesel_async::pooled_connection::bb8::PooledConnection;
22use diesel_migrations::EmbeddedMigrations;
23use diesel_migrations::embed_migrations;
24use futures::FutureExt;
25use prometheus::Registry;
26use tracing::info;
27use url::Url;
28
29use crate::tls::AsyncPgConnectionWithId;
30use crate::tls::build_tls_config;
31use crate::tls::establish_tls_connection;
32
33mod metrics;
34mod model;
35pub mod query;
36pub mod schema;
37pub mod store;
38pub mod temp;
39mod tls;
40
41use crate::metrics::PoolMetrics;
42pub use sui_field_count::FieldCount;
43pub use sui_sql_macro::sql;
44
/// This crate's own migrations, embedded at compile time from its `migrations` directory.
/// [merge_migrations] always includes these ahead of any caller-supplied migrations.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
46
// Connection-pool configuration consumed by `Db::for_write` / `Db::for_read`. The struct derives
// `clap::Args` and every field is marked `#[arg(long)]`, so each one is also a long command-line
// option. The pool-size and connection-timeout options take their defaults from
// `DbArgs::default()`.
#[derive(clap::Args, Debug, Clone)]
pub struct DbArgs {
    /// Number of connections to keep in the pool.
    #[arg(long, default_value_t = Self::default().db_connection_pool_size)]
    pub db_connection_pool_size: u32,

    /// Time spent waiting for a connection from the pool to become available, in milliseconds.
    #[arg(long, default_value_t = Self::default().db_connection_timeout_ms)]
    pub db_connection_timeout_ms: u64,

    // `None` leaves the server's statement timeout untouched; otherwise it is applied to every
    // new pooled connection with `SET statement_timeout`.
    #[arg(long)]
    /// Time spent waiting for statements to complete, in milliseconds.
    pub db_statement_timeout_ms: Option<u64>,

    // This field and `tls_ca_cert_path` are passed to `build_tls_config` when the pool is built.
    #[arg(long)]
    /// Enable server certificate verification. By default, this is set to false to match the
    /// default behavior of libpq.
    pub tls_verify_cert: bool,

    #[arg(long)]
    /// Path to a custom CA certificate to use for server certificate verification.
    pub tls_ca_cert_path: Option<PathBuf>,
}
70
/// Handle to a Postgres connection pool, optionally instrumented with pool metrics.
///
/// Cloning a [Db] shares access to the same underlying pool.
#[derive(Clone)]
pub struct Db {
    pool: Pool<AsyncPgConnectionWithId>,
    // `None` until `Db::register_metrics` is called. Held in an `Arc` so that clones made
    // afterwards report to the same metrics.
    pool_metrics: Option<Arc<PoolMetrics>>,
}
76
// Drop guard for a single connection request made through `Db::connect`.
//
// It guarantees that exactly one of `acquired`, `unacquired_error`, or `unacquired_canceled` is
// incremented for every `requested` connection:
// - the first two are recorded explicitly via `CancelGuard::acquired` / `unacquired_error`;
// - if neither runs (for example, the `connect` future is dropped while still waiting), the
//   `Drop` impl records `unacquired_canceled`.
//
// This makes the number of pending requests computable as:
//   pending = requested - (acquired + unacquired_error + unacquired_canceled)
//
// The inner `Option` becomes `None` once the guard has recorded an outcome.
struct CancelGuard<'m>(Option<&'m PoolMetrics>);
82
/// Newtype over the foreign `PooledConnection` type, so that this crate can implement its own
/// traits (such as the `Store` trait) for it. Obtained from [Db::connect], and dereferences to
/// the inner `PooledConnection`.
pub struct Connection<'a>(PooledConnection<'a, AsyncPgConnectionWithId>);
85
86impl DbArgs {
87    pub fn connection_timeout(&self) -> Duration {
88        Duration::from_millis(self.db_connection_timeout_ms)
89    }
90
91    pub fn statement_timeout(&self) -> Option<Duration> {
92        self.db_statement_timeout_ms.map(Duration::from_millis)
93    }
94}
95
96impl Db {
97    /// Construct a new DB connection pool talking to the database at `database_url` that supports
98    /// write and reads. Instances of [Db] can be cloned to share access to the same pool.
99    pub async fn for_write(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
100        Self::new(database_url, config, false).await
101    }
102
103    /// Construct a new DB connection pool talking to the database at `database_url` that defaults
104    /// to read-only transactions. Instances of [Db] can be cloned to share access to the same
105    /// pool.
106    pub async fn for_read(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
107        Self::new(database_url, config, true).await
108    }
109
110    async fn new(database_url: Url, db_args: DbArgs, read_only: bool) -> anyhow::Result<Self> {
111        Ok(Db {
112            pool: pool(database_url, db_args, read_only).await?,
113            pool_metrics: None,
114        })
115    }
116
117    pub fn register_metrics(
118        mut self,
119        prefix: Option<&str>,
120        registry: &Registry,
121    ) -> anyhow::Result<Self> {
122        let pool_metrics = PoolMetrics::new(prefix, registry)?;
123        self.pool_metrics = Some(pool_metrics);
124        Ok(self)
125    }
126
127    /// Retrieves a connection from the pool. Can fail with a timeout if a connection cannot be
128    /// established before the [DbArgs::connection_timeout] has elapsed.
129    pub async fn connect(&self) -> anyhow::Result<Connection<'_>> {
130        if let Some(pool_metrics) = &self.pool_metrics {
131            let guard = CancelGuard::request(pool_metrics);
132            match self.pool.get().await {
133                Ok(c) => {
134                    guard.acquired();
135                    Ok(Connection(c))
136                }
137                Err(e) => {
138                    guard.unacquired_error();
139                    Err(e.into())
140                }
141            }
142        } else {
143            Ok(Connection(self.pool.get().await?))
144        }
145    }
146
147    /// Statistics about the connection pool
148    pub fn state(&self) -> bb8::State {
149        self.pool.state()
150    }
151
152    pub fn pool_metrics(&self) -> Option<&PoolMetrics> {
153        self.pool_metrics.as_ref().map(|p| p.as_ref())
154    }
155
156    async fn clear_database(&self) -> anyhow::Result<()> {
157        info!("Clearing the database...");
158        let mut conn = self.connect().await?;
159        let drop_all_tables = "
160        DO $$ DECLARE
161            r RECORD;
162        BEGIN
163        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
164            LOOP
165                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
166            END LOOP;
167        END $$;";
168        diesel::sql_query(drop_all_tables)
169            .execute(&mut conn)
170            .await?;
171        info!("Dropped all tables.");
172
173        let drop_all_procedures = "
174        DO $$ DECLARE
175            r RECORD;
176        BEGIN
177            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
178                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
179                      WHERE ns.nspname = 'public' AND prokind = 'p')
180            LOOP
181                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
182            END LOOP;
183        END $$;";
184        diesel::sql_query(drop_all_procedures)
185            .execute(&mut conn)
186            .await?;
187        info!("Dropped all procedures.");
188
189        let drop_all_functions = "
190        DO $$ DECLARE
191            r RECORD;
192        BEGIN
193            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
194                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
195                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
196            LOOP
197                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
198            END LOOP;
199        END $$;";
200        diesel::sql_query(drop_all_functions)
201            .execute(&mut conn)
202            .await?;
203        info!("Database cleared.");
204        Ok(())
205    }
206
207    /// Run migrations on the database. Use Diesel's `embed_migrations!` macro to generate the
208    /// `migrations` parameter for your indexer.
209    pub async fn run_migrations(
210        &self,
211        migrations: Option<&'static EmbeddedMigrations>,
212    ) -> anyhow::Result<Vec<MigrationVersion<'static>>> {
213        use diesel_migrations::MigrationHarness;
214
215        let merged_migrations = merge_migrations(migrations);
216
217        info!("Running migrations ...");
218        let conn = self.pool.dedicated_connection().await?;
219        let mut wrapper: AsyncConnectionWrapper<AsyncPgConnectionWithId> =
220            AsyncConnectionWrapper::from(conn);
221
222        let finished_migrations = tokio::task::spawn_blocking(move || {
223            wrapper
224                .run_pending_migrations(merged_migrations)
225                .map(|versions| versions.iter().map(MigrationVersion::as_owned).collect())
226        })
227        .await?
228        .map_err(|e| anyhow!("Failed to run migrations: {:?}", e))?;
229
230        info!("Migrations complete.");
231        Ok(finished_migrations)
232    }
233}
234
235impl<'m> CancelGuard<'m> {
236    fn request(stats: &'m PoolMetrics) -> Self {
237        stats.requested.inc();
238        Self(Some(stats))
239    }
240
241    fn acquired(mut self) {
242        if let Some(m) = self.0.take() {
243            m.acquired.inc()
244        }
245    }
246
247    fn unacquired_error(mut self) {
248        if let Some(m) = self.0.take() {
249            m.unacquired_error.inc()
250        }
251    }
252}
253
254impl Default for DbArgs {
255    fn default() -> Self {
256        Self {
257            db_connection_pool_size: 100,
258            db_connection_timeout_ms: 60_000,
259            db_statement_timeout_ms: None,
260            tls_verify_cert: false,
261            tls_ca_cert_path: None,
262        }
263    }
264}
265
266impl<'m> Drop for CancelGuard<'m> {
267    fn drop(&mut self) {
268        if let Some(m) = self.0.take() {
269            m.unacquired_canceled.inc()
270        }
271    }
272}
273
274/// Drop all tables, and re-run migrations if supplied.
275pub async fn reset_database(
276    database_url: Url,
277    db_args: DbArgs,
278    migrations: Option<&'static EmbeddedMigrations>,
279) -> anyhow::Result<()> {
280    let db = Db::for_write(database_url, db_args).await?;
281    db.clear_database().await?;
282    if let Some(migrations) = migrations {
283        db.run_migrations(Some(migrations)).await?;
284    }
285
286    Ok(())
287}
288
289impl<'a> Deref for Connection<'a> {
290    type Target = PooledConnection<'a, AsyncPgConnectionWithId>;
291
292    fn deref(&self) -> &Self::Target {
293        &self.0
294    }
295}
296
297impl DerefMut for Connection<'_> {
298    fn deref_mut(&mut self) -> &mut Self::Target {
299        &mut self.0
300    }
301}
302
303async fn pool(
304    database_url: Url,
305    args: DbArgs,
306    read_only: bool,
307) -> anyhow::Result<Pool<AsyncPgConnectionWithId>> {
308    let statement_timeout = args.statement_timeout();
309
310    // Build TLS configuration once
311    let tls_config = build_tls_config(args.tls_verify_cert, args.tls_ca_cert_path.clone())?;
312
313    let mut config = ManagerConfig::default();
314
315    config.custom_setup = Box::new(move |url| {
316        let tls_config = tls_config.clone();
317
318        async move {
319            let mut conn = establish_tls_connection(url, tls_config).await?;
320
321            if let Some(timeout) = statement_timeout {
322                diesel::sql_query(format!("SET statement_timeout = {}", timeout.as_millis()))
323                    .execute(&mut conn)
324                    .await
325                    .map_err(ConnectionError::CouldntSetupConfiguration)?;
326            }
327
328            if read_only {
329                diesel::sql_query("SET default_transaction_read_only = 'on'")
330                    .execute(&mut conn)
331                    .await
332                    .map_err(ConnectionError::CouldntSetupConfiguration)?;
333            }
334
335            Ok(conn)
336        }
337        .boxed()
338    });
339
340    let manager = AsyncDieselConnectionManager::new_with_config(database_url.as_str(), config);
341
342    Ok(Pool::builder()
343        .max_size(args.db_connection_pool_size)
344        .connection_timeout(args.connection_timeout())
345        .build(manager)
346        .await?)
347}
348
349/// Returns new migrations derived from the combination of provided migrations and migrations
350/// defined in this crate.
351pub fn merge_migrations(
352    migrations: Option<&'static EmbeddedMigrations>,
353) -> impl MigrationSource<Pg> + Send + Sync + 'static {
354    struct Migrations(Option<&'static EmbeddedMigrations>);
355    impl MigrationSource<Pg> for Migrations {
356        fn migrations(&self) -> diesel::migration::Result<Vec<Box<dyn Migration<Pg>>>> {
357            let mut migrations = MIGRATIONS.migrations()?;
358            if let Some(more_migrations) = self.0 {
359                migrations.extend(more_migrations.migrations()?);
360            }
361            Ok(migrations)
362        }
363    }
364
365    Migrations(migrations)
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::temp::TempDb;
    use anyhow::Error;
    use diesel::prelude::QueryableByName;
    use tokio::spawn;
    use tokio::time::timeout;

    /// Fixture for the pool-metrics tests: a read-only [Db] over a fresh temporary database, with
    /// a single-connection pool (so concurrent requests contend for it) and metrics registered.
    struct MetricTest {
        db: Db,
        // Kept only so they live as long as `db`. Presumably dropping `TempDb` tears down the
        // temporary database.
        _temp_db: TempDb,
        _registry: Registry,
    }

    impl MetricTest {
        /// Set up the fixture, using `db_connection_timeout` as the pool's connection timeout.
        async fn new(db_connection_timeout: Duration) -> Arc<Self> {
            let temp_db = TempDb::new().unwrap();
            let url = temp_db.database().url();
            let db_args = DbArgs {
                db_connection_pool_size: 1,
                db_connection_timeout_ms: db_connection_timeout.as_millis() as u64,
                ..Default::default()
            };
            let registry = Registry::new();
            let db = Db::for_read(url.clone(), db_args)
                .await
                .unwrap()
                .register_metrics(None, &registry)
                .unwrap();
            Arc::new(Self {
                db,
                _temp_db: temp_db,
                _registry: registry,
            })
        }

        fn pool_metrics(&self) -> &PoolMetrics {
            self.db.pool_metrics().unwrap()
        }

        /// Acquire a connection and hold it for `duration` by running `pg_sleep` on it.
        async fn select_sleep(&self, duration: Duration) -> Result<(), Error> {
            let mut conn = self.db.connect().await?;
            let duration_s = duration.as_secs_f64();
            diesel::sql_query(format!("SELECT pg_sleep({duration_s});"))
                .execute(&mut conn)
                .await?;
            Ok(())
        }
    }

    #[tokio::test]
    async fn temp_db_smoketest() {
        telemetry_subscribers::init_for_testing();
        let db = TempDb::new().unwrap();
        let url = db.database().url();

        info!(%url);
        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();

        // Run a simple query to verify the db can properly be queried
        let resp = diesel::sql_query("SELECT datname FROM pg_database")
            .execute(&mut conn)
            .await
            .unwrap();

        info!(?resp);
    }

    #[derive(Debug, QueryableByName)]
    struct CountResult {
        #[diesel(sql_type = diesel::sql_types::BigInt)]
        cnt: i64,
    }

    #[tokio::test]
    async fn test_reset_database_skip_migrations() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();
        diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
            .execute(&mut conn)
            .await
            .unwrap();
        let cnt = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result::<CountResult>(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 1);

        // With no migrations supplied, the reset should leave the database empty.
        reset_database(url.clone(), DbArgs::default(), None)
            .await
            .unwrap();

        let mut conn = db.connect().await.unwrap();
        let cnt: CountResult = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 0);
    }

    #[tokio::test]
    async fn test_read_only() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let writer = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let reader = Db::for_read(url.clone(), DbArgs::default()).await.unwrap();

        {
            // Create a table
            let mut conn = writer.connect().await.unwrap();
            diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            // Try an insert into it using the read-only connection, which should fail
            let mut conn = reader.connect().await.unwrap();
            let result = diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await;
            assert!(result.is_err());
        }

        {
            // Try and select from it using the read-only connection, which should succeed, but
            // return no results.
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 0);
        }

        {
            // Then try to write to it using the write connection, which should succeed
            let mut conn = writer.connect().await.unwrap();
            diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            // Finally, try to read from it using the read-only connection, which should now return
            // results.
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 1);
        }
    }

    #[tokio::test]
    async fn test_statement_timeout() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let reader = Db::for_read(
            url.clone(),
            DbArgs {
                db_statement_timeout_ms: Some(200),
                ..DbArgs::default()
            },
        )
        .await
        .unwrap();

        {
            // A simple query should not timeout
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT 1::BIGINT AS cnt")
                .get_result(&mut conn)
                .await
                .unwrap();

            assert_eq!(cnt.cnt, 1);
        }

        {
            // A query that waits a bit, which should cause a timeout
            let mut conn = reader.connect().await.unwrap();
            diesel::sql_query("SELECT PG_SLEEP(2), 1::BIGINT AS cnt")
                .get_result::<CountResult>(&mut conn)
                .await
                .expect_err("This request should fail because of a timeout");
        }
    }

    #[tokio::test]
    async fn test_unacquired_error() {
        let db_connection_timeout = Duration::from_millis(500);
        let metric_test = MetricTest::new(db_connection_timeout).await;

        let metric_test_clone = metric_test.clone();
        let task1 = spawn(async move {
            metric_test_clone
                .select_sleep(db_connection_timeout + Duration::from_millis(500))
                .await
        });
        // 1st task takes longer than db_connection_timeout so 2nd task times out
        let metric_test_clone = metric_test.clone();
        let task2 = spawn(async move {
            // sleep duration does not matter because it will never execute
            metric_test_clone.select_sleep(Duration::ZERO).await
        });
        assert!(task1.await.unwrap().is_ok());
        assert!(task2.await.unwrap().is_err());

        let PoolMetrics {
            requested,
            acquired,
            unacquired_error,
            unacquired_canceled,
        } = metric_test.pool_metrics();
        assert_eq!(requested.get(), 2);
        assert_eq!(acquired.get(), 1);
        assert_eq!(unacquired_error.get(), 1);
        assert_eq!(unacquired_canceled.get(), 0);
    }

    #[tokio::test]
    async fn test_unacquired_canceled() {
        // The timeouts are ordered task_timeout < sleep_timeout < db_connection_timeout. The 1st
        // task holds the only connection for `sleep_timeout`, so the 2nd task's request is
        // abandoned by `timeout` (counted as canceled) while it is still waiting. This happens
        // before the pool's own connection timeout could turn it into an error.
        let task_timeout = Duration::from_millis(500);
        let sleep_timeout = task_timeout + Duration::from_millis(500);
        let db_connection_timeout = sleep_timeout + Duration::from_millis(500);
        let metric_test = MetricTest::new(db_connection_timeout).await;

        let metric_test_clone = metric_test.clone();
        let task1 = spawn(async move { metric_test_clone.select_sleep(sleep_timeout).await });
        let metric_test_clone = metric_test.clone();
        // 1st task takes longer than task_timeout so 2nd task times out
        let task2 = spawn(async move {
            // sleep duration does not matter because it will never execute
            timeout(task_timeout, metric_test_clone.select_sleep(Duration::ZERO)).await
        });
        assert!(task1.await.unwrap().is_ok());
        assert!(task2.await.unwrap().is_err());

        let PoolMetrics {
            requested,
            acquired,
            unacquired_error,
            unacquired_canceled,
        } = metric_test.pool_metrics();
        assert_eq!(requested.get(), 2);
        assert_eq!(acquired.get(), 1);
        assert_eq!(unacquired_error.get(), 0);
        assert_eq!(unacquired_canceled.get(), 1);
    }
}