sui_indexer/
db.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::database::Connection;
5use crate::errors::IndexerError;
6use crate::handlers::pruner::PrunableTable;
7use clap::Args;
8use diesel::QueryDsl;
9use diesel::migration::{Migration, MigrationSource, MigrationVersion};
10use diesel::pg::Pg;
11use diesel::prelude::QueryableByName;
12use diesel::table;
13use diesel_migrations::{EmbeddedMigrations, embed_migrations};
14use std::collections::{BTreeSet, HashSet};
15use std::time::Duration;
16use strum::IntoEnumIterator;
17use tracing::info;
18
// Minimal read-only mapping of diesel's bookkeeping table of applied migrations.
// We declare it ourselves so we can SELECT from it without going through diesel's
// migration helpers, which would implicitly create the table if it doesn't exist
// (see `check_db_migration_consistency_impl`).
table! {
    __diesel_schema_migrations (version) {
        version -> VarChar,
        run_on -> Timestamp,
    }
}

// All Postgres migration scripts, embedded into the binary from `migrations/pg` at compile time.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
27
28#[derive(Args, Debug, Clone)]
29pub struct ConnectionPoolConfig {
30    #[arg(long, default_value_t = 100)]
31    #[arg(env = "DB_POOL_SIZE")]
32    pub pool_size: u32,
33    #[arg(long, value_parser = parse_duration, default_value = "30")]
34    #[arg(env = "DB_CONNECTION_TIMEOUT")]
35    pub connection_timeout: Duration,
36    #[arg(long, value_parser = parse_duration, default_value = "3600")]
37    #[arg(env = "DB_STATEMENT_TIMEOUT")]
38    pub statement_timeout: Duration,
39}
40
/// clap value parser: interprets the argument as a whole number of seconds.
/// Returns the underlying `ParseIntError` if the argument is not a valid `u64`.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
45
impl ConnectionPoolConfig {
    // Default maximum number of pooled connections.
    const DEFAULT_POOL_SIZE: u32 = 100;
    // Default connection-acquisition timeout, in seconds.
    const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
    // Default per-statement timeout, in seconds.
    const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;

    /// Derives the per-connection settings from this pool config.
    /// Connections are always created read-write here (`read_only: false`).
    pub(crate) fn connection_config(&self) -> ConnectionConfig {
        ConnectionConfig {
            statement_timeout: self.statement_timeout,
            read_only: false,
        }
    }

    /// Overrides the maximum pool size.
    pub fn set_pool_size(&mut self, size: u32) {
        self.pool_size = size;
    }

    /// Overrides the connection-acquisition timeout.
    pub fn set_connection_timeout(&mut self, timeout: Duration) {
        self.connection_timeout = timeout;
    }

    /// Overrides the per-statement timeout.
    pub fn set_statement_timeout(&mut self, timeout: Duration) {
        self.statement_timeout = timeout;
    }
}
70
71impl Default for ConnectionPoolConfig {
72    fn default() -> Self {
73        Self {
74            pool_size: Self::DEFAULT_POOL_SIZE,
75            connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
76            statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
77        }
78    }
79}
80
/// Per-connection settings produced by `ConnectionPoolConfig::connection_config`.
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    // Per-statement timeout — presumably forwarded to Postgres `statement_timeout`
    // by the connection setup code; confirm in database.rs.
    pub statement_timeout: Duration,
    // Whether the connection should be opened read-only.
    pub read_only: bool,
}
86
/// Checks that every migration script embedded in this binary has already been applied
/// to the database (i.e. local migrations are a subset of the DB's migration records).
/// This allows us to run migration scripts against a DB at any time, without worrying
/// about breaking existing readers.
/// We do however need to make sure that whenever we deploy a new version of either reader
/// or writer, we first run migration scripts, so that there are never more local scripts
/// than records in the DB.
pub async fn check_db_migration_consistency(conn: &mut Connection<'_>) -> Result<(), IndexerError> {
    info!("Starting compatibility check");
    // Collect the versions of the migrations embedded in this binary.
    let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
        IndexerError::DbMigrationError(format!(
            "Failed to fetch local migrations from schema: {err}"
        ))
    })?;
    let local_migrations: Vec<_> = migrations
        .into_iter()
        .map(|m| m.name().version().as_owned())
        .collect();
    check_db_migration_consistency_impl(conn, local_migrations).await?;
    info!("Compatibility check passed");
    Ok(())
}
107
108async fn check_db_migration_consistency_impl(
109    conn: &mut Connection<'_>,
110    local_migrations: Vec<MigrationVersion<'_>>,
111) -> Result<(), IndexerError> {
112    use diesel_async::RunQueryDsl;
113
114    // Unfortunately we cannot call applied_migrations() directly on the connection,
115    // since it implicitly creates the __diesel_schema_migrations table if it doesn't exist,
116    // which is a write operation that we don't want to do in this function.
117    let applied_migrations: BTreeSet<MigrationVersion<'_>> = BTreeSet::from_iter(
118        __diesel_schema_migrations::table
119            .select(__diesel_schema_migrations::version)
120            .load(conn)
121            .await?,
122    );
123
124    // We check that the local migrations is a subset of the applied migrations.
125    let unapplied_migrations: Vec<_> = local_migrations
126        .into_iter()
127        .filter(|m| !applied_migrations.contains(m))
128        .collect();
129
130    if unapplied_migrations.is_empty() {
131        return Ok(());
132    }
133
134    Err(IndexerError::DbMigrationError(format!(
135        "This binary expected the following migrations to have been run, and they were not: {:?}",
136        unapplied_migrations
137    )))
138}
139
140/// Check that prunable tables exist in the database.
141pub async fn check_prunable_tables_valid(conn: &mut Connection<'_>) -> Result<(), IndexerError> {
142    info!("Starting compatibility check");
143
144    use diesel_async::RunQueryDsl;
145
146    let select_parent_tables = r#"
147    SELECT c.relname AS table_name
148    FROM pg_class c
149    JOIN pg_namespace n ON n.oid = c.relnamespace
150    LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
151    WHERE c.relkind IN ('r', 'p')  -- 'r' for regular tables, 'p' for partitioned tables
152        AND n.nspname = 'public'
153        AND (
154            pt.partrelid IS NOT NULL  -- This is a partitioned (parent) table
155            OR NOT EXISTS (  -- This is not a partition (child table)
156                SELECT 1
157                FROM pg_inherits i
158                WHERE i.inhrelid = c.oid
159            )
160        );
161    "#;
162
163    #[derive(QueryableByName)]
164    struct TableName {
165        #[diesel(sql_type = diesel::sql_types::Text)]
166        table_name: String,
167    }
168
169    let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
170        .load(conn)
171        .await
172        .map_err(|e| IndexerError::DbMigrationError(format!("Failed to fetch tables: {e}")))?;
173
174    let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
175
176    for key in PrunableTable::iter() {
177        if !parent_tables_from_db.contains(key.as_ref()) {
178            return Err(IndexerError::GenericError(format!(
179                "Invalid retention policy override provided for table {}: does not exist in the database",
180                key
181            )));
182        }
183    }
184
185    info!("Compatibility check passed");
186    Ok(())
187}
188
189pub use setup_postgres::{reset_database, run_migrations};
190
pub mod setup_postgres {
    use crate::{database::Connection, db::MIGRATIONS};
    use anyhow::anyhow;
    use diesel_async::RunQueryDsl;
    use tracing::info;

    /// Destructive: drops every object in the `public` schema via `clear_database`,
    /// then re-applies all embedded migrations, leaving a freshly-migrated database.
    pub async fn reset_database(mut conn: Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");
        clear_database(&mut conn).await?;
        run_migrations(conn).await?;
        info!("Reset database complete.");
        Ok(())
    }

    /// Destructive: drops all tables, procedures, and functions in the `public` schema.
    pub async fn clear_database(conn: &mut Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Clearing the database...");
        // Drop every table; CASCADE also removes dependent objects (views, FKs, ...).
        let drop_all_tables = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
            LOOP
                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn).await?;
        info!("Dropped all tables.");

        // Drop stored procedures (pg_proc rows with prokind = 'p'); the argument-type
        // list is needed to name each overload unambiguously.
        let drop_all_procedures = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
                      WHERE ns.nspname = 'public' AND prokind = 'p')
            LOOP
                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn).await?;
        info!("Dropped all procedures.");

        // Drop plain functions (prokind = 'f'), again qualified by argument types.
        let drop_all_functions = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
            LOOP
                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn).await?;
        info!("Database cleared.");
        Ok(())
    }

    /// Applies any embedded migrations not yet recorded in the database.
    pub async fn run_migrations(conn: Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Running migrations ...");
        conn.run_pending_migrations(MIGRATIONS)
            .await
            .map_err(|e| anyhow!("Failed to run migrations {e}"))?;
        Ok(())
    }
}
257
#[cfg(test)]
mod tests {
    use crate::database::{Connection, ConnectionPool};
    use crate::db::{
        ConnectionPoolConfig, MIGRATIONS, check_db_migration_consistency,
        check_db_migration_consistency_impl, reset_database,
    };
    use diesel::migration::{Migration, MigrationSource};
    use diesel::pg::Pg;
    use diesel_migrations::MigrationHarness;
    use sui_pg_db::temp::TempDb;

    // Check that the migration records in the database created from the local schema
    // pass the consistency check.
    #[tokio::test]
    async fn db_migration_consistency_smoke_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();
        check_db_migration_consistency(&mut pool.get().await.unwrap())
            .await
            .unwrap();
    }

    // Revert the last applied migration so the binary knows one migration the DB does
    // not have; the consistency check must fail until migrations are re-run.
    #[tokio::test]
    async fn db_migration_consistency_non_prefix_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();
        let mut connection = pool.get().await.unwrap();

        // MigrationHarness is a sync API, so wrap an async connection and run the
        // revert on a blocking task.
        let mut sync_connection_wrapper =
            diesel_async::async_connection_wrapper::AsyncConnectionWrapper::<Connection>::from(
                pool.dedicated_connection().await.unwrap(),
            );

        tokio::task::spawn_blocking(move || {
            sync_connection_wrapper
                .revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                .unwrap();
        })
        .await
        .unwrap();
        // Local migrations now contain one record more than the applied migrations.
        // This must fail the consistency check: the reverted migration is no longer
        // recorded in the DB.
        assert!(
            check_db_migration_consistency(&mut connection)
                .await
                .is_err()
        );

        pool.dedicated_connection()
            .await
            .unwrap()
            .run_pending_migrations(MIGRATIONS)
            .await
            .unwrap();
        // After running pending migrations they should be consistent.
        check_db_migration_consistency(&mut connection)
            .await
            .unwrap();
    }

    // Dropping the *last* local migration leaves the local list a strict subset of
    // the applied migrations, which is allowed.
    #[tokio::test]
    async fn db_migration_consistency_prefix_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();

        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
        let mut local_migrations: Vec<_> = migrations.iter().map(|m| m.name().version()).collect();
        local_migrations.pop();
        // Local migrations is one record less than the applied migrations.
        // This should pass the consistency check since local is still a subset
        // of what the DB has applied.
        check_db_migration_consistency_impl(&mut pool.get().await.unwrap(), local_migrations)
            .await
            .unwrap();
    }

    // Dropping a migration from the *middle* of the local list also keeps it a
    // subset of the applied migrations, which is allowed as well.
    #[tokio::test]
    async fn db_migration_consistency_subset_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();

        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
        let mut local_migrations: Vec<_> = migrations.iter().map(|m| m.name().version()).collect();
        local_migrations.remove(2);

        // Local migrations are missing one record compared to the applied migrations, which should
        // still be okay.
        check_db_migration_consistency_impl(&mut pool.get().await.unwrap(), local_migrations)
            .await
            .unwrap();
    }

    // Sanity-check that a TempDb comes up and accepts a trivial query.
    #[tokio::test]
    async fn temp_db_smoketest() {
        use crate::database::Connection;
        use diesel_async::RunQueryDsl;
        use sui_pg_db::temp::TempDb;

        telemetry_subscribers::init_for_testing();

        let db = TempDb::new().unwrap();
        let url = db.database().url();
        println!("url: {}", url.as_str());
        let mut connection = Connection::dedicated(url).await.unwrap();

        // Run a simple query to verify the db can properly be queried
        let resp = diesel::sql_query("SELECT datname FROM pg_database")
            .execute(&mut connection)
            .await
            .unwrap();
        println!("resp: {:?}", resp);
    }
}