1use crate::database::Connection;
5use crate::errors::IndexerError;
6use crate::handlers::pruner::PrunableTable;
7use clap::Args;
8use diesel::QueryDsl;
9use diesel::migration::{Migration, MigrationSource, MigrationVersion};
10use diesel::pg::Pg;
11use diesel::prelude::QueryableByName;
12use diesel::table;
13use diesel_migrations::{EmbeddedMigrations, embed_migrations};
14use std::collections::{BTreeSet, HashSet};
15use std::time::Duration;
16use strum::IntoEnumIterator;
17use tracing::info;
18
// Maps Diesel's internal `__diesel_schema_migrations` bookkeeping table (the
// table Diesel writes to after applying each migration) so we can query the
// set of applied migration versions directly.
table! {
    __diesel_schema_migrations (version) {
        version -> VarChar,
        run_on -> Timestamp,
    }
}
25
// All SQL migrations under `migrations/pg`, embedded into the binary at
// compile time.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
27
/// CLI / environment configuration for the database connection pool.
///
/// NOTE(review): the literal defaults in the `#[arg(...)]` attributes below
/// duplicate the `DEFAULT_*` constants on the `impl` block — keep them in
/// sync when changing either.
#[derive(Args, Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Maximum number of connections held by the pool.
    #[arg(long, default_value_t = 100)]
    #[arg(env = "DB_POOL_SIZE")]
    pub pool_size: u32,
    /// How long to wait when acquiring a connection (parsed as whole seconds).
    #[arg(long, value_parser = parse_duration, default_value = "30")]
    #[arg(env = "DB_CONNECTION_TIMEOUT")]
    pub connection_timeout: Duration,
    /// Per-statement timeout applied to each connection (parsed as whole seconds).
    #[arg(long, value_parser = parse_duration, default_value = "3600")]
    #[arg(env = "DB_STATEMENT_TIMEOUT")]
    pub statement_timeout: Duration,
}
40
/// `clap` value parser: interprets the argument as a whole number of seconds.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse().map(std::time::Duration::from_secs)
}
45
impl ConnectionPoolConfig {
    // NOTE(review): these must stay in sync with the literal defaults in the
    // `#[arg(default_value...)]` attributes on the struct definition.
    const DEFAULT_POOL_SIZE: u32 = 100;
    const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
    const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;

    /// Derives the per-connection settings from this pool configuration.
    /// Connections are opened read-write (`read_only: false`).
    pub(crate) fn connection_config(&self) -> ConnectionConfig {
        ConnectionConfig {
            statement_timeout: self.statement_timeout,
            read_only: false,
        }
    }

    /// Overrides the maximum number of pooled connections.
    pub fn set_pool_size(&mut self, size: u32) {
        self.pool_size = size;
    }

    /// Overrides the connection-acquisition timeout.
    pub fn set_connection_timeout(&mut self, timeout: Duration) {
        self.connection_timeout = timeout;
    }

    /// Overrides the per-statement timeout.
    pub fn set_statement_timeout(&mut self, timeout: Duration) {
        self.statement_timeout = timeout;
    }
}
70
impl Default for ConnectionPoolConfig {
    /// Mirrors the clap argument defaults so configurations constructed
    /// programmatically (e.g. in tests) match CLI-constructed ones.
    fn default() -> Self {
        Self {
            pool_size: Self::DEFAULT_POOL_SIZE,
            connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
            statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
        }
    }
}
80
/// Settings applied to each individual database connection.
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Maximum time a single SQL statement may run before being cancelled.
    pub statement_timeout: Duration,
    /// Whether the connection should be opened in read-only mode.
    pub read_only: bool,
}
86
87pub async fn check_db_migration_consistency(conn: &mut Connection<'_>) -> Result<(), IndexerError> {
93 info!("Starting compatibility check");
94 let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
95 IndexerError::DbMigrationError(format!(
96 "Failed to fetch local migrations from schema: {err}"
97 ))
98 })?;
99 let local_migrations: Vec<_> = migrations
100 .into_iter()
101 .map(|m| m.name().version().as_owned())
102 .collect();
103 check_db_migration_consistency_impl(conn, local_migrations).await?;
104 info!("Compatibility check passed");
105 Ok(())
106}
107
108async fn check_db_migration_consistency_impl(
109 conn: &mut Connection<'_>,
110 local_migrations: Vec<MigrationVersion<'_>>,
111) -> Result<(), IndexerError> {
112 use diesel_async::RunQueryDsl;
113
114 let applied_migrations: BTreeSet<MigrationVersion<'_>> = BTreeSet::from_iter(
118 __diesel_schema_migrations::table
119 .select(__diesel_schema_migrations::version)
120 .load(conn)
121 .await?,
122 );
123
124 let unapplied_migrations: Vec<_> = local_migrations
126 .into_iter()
127 .filter(|m| !applied_migrations.contains(m))
128 .collect();
129
130 if unapplied_migrations.is_empty() {
131 return Ok(());
132 }
133
134 Err(IndexerError::DbMigrationError(format!(
135 "This binary expected the following migrations to have been run, and they were not: {:?}",
136 unapplied_migrations
137 )))
138}
139
140pub async fn check_prunable_tables_valid(conn: &mut Connection<'_>) -> Result<(), IndexerError> {
142 info!("Starting compatibility check");
143
144 use diesel_async::RunQueryDsl;
145
146 let select_parent_tables = r#"
147 SELECT c.relname AS table_name
148 FROM pg_class c
149 JOIN pg_namespace n ON n.oid = c.relnamespace
150 LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
151 WHERE c.relkind IN ('r', 'p') -- 'r' for regular tables, 'p' for partitioned tables
152 AND n.nspname = 'public'
153 AND (
154 pt.partrelid IS NOT NULL -- This is a partitioned (parent) table
155 OR NOT EXISTS ( -- This is not a partition (child table)
156 SELECT 1
157 FROM pg_inherits i
158 WHERE i.inhrelid = c.oid
159 )
160 );
161 "#;
162
163 #[derive(QueryableByName)]
164 struct TableName {
165 #[diesel(sql_type = diesel::sql_types::Text)]
166 table_name: String,
167 }
168
169 let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
170 .load(conn)
171 .await
172 .map_err(|e| IndexerError::DbMigrationError(format!("Failed to fetch tables: {e}")))?;
173
174 let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
175
176 for key in PrunableTable::iter() {
177 if !parent_tables_from_db.contains(key.as_ref()) {
178 return Err(IndexerError::GenericError(format!(
179 "Invalid retention policy override provided for table {}: does not exist in the database",
180 key
181 )));
182 }
183 }
184
185 info!("Compatibility check passed");
186 Ok(())
187}
188
189pub use setup_postgres::{reset_database, run_migrations};
190
pub mod setup_postgres {
    //! Helpers for (re)initializing a Postgres database: dropping every
    //! object in the `public` schema and re-running the embedded migrations.
    use crate::{database::Connection, db::MIGRATIONS};
    use anyhow::anyhow;
    use diesel_async::RunQueryDsl;
    use tracing::info;

    /// Wipes the `public` schema, then applies all embedded migrations.
    pub async fn reset_database(mut conn: Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");
        clear_database(&mut conn).await?;
        run_migrations(conn).await?;
        info!("Reset database complete.");
        Ok(())
    }

    /// Drops all tables, procedures, and functions from the `public` schema.
    pub async fn clear_database(conn: &mut Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Clearing the database...");
        // Tables first: CASCADE also removes dependent objects (views,
        // foreign keys, indexes).
        let drop_all_tables = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
            LOOP
                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn).await?;
        info!("Dropped all tables.");

        // Stored procedures (prokind = 'p'); the argument type list is needed
        // to name an overloaded procedure unambiguously in DROP.
        let drop_all_procedures = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
                      WHERE ns.nspname = 'public' AND prokind = 'p')
            LOOP
                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn).await?;
        info!("Dropped all procedures.");

        // Plain functions (prokind = 'f'), same overload-disambiguation trick.
        let drop_all_functions = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
            LOOP
                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn).await?;
        info!("Database cleared.");
        Ok(())
    }

    /// Applies any embedded migrations not yet recorded in the database.
    pub async fn run_migrations(conn: Connection<'static>) -> Result<(), anyhow::Error> {
        info!("Running migrations ...");
        conn.run_pending_migrations(MIGRATIONS)
            .await
            .map_err(|e| anyhow!("Failed to run migrations {e}"))?;
        Ok(())
    }
}
257
258#[cfg(test)]
259mod tests {
260 use crate::database::{Connection, ConnectionPool};
261 use crate::db::{
262 ConnectionPoolConfig, MIGRATIONS, check_db_migration_consistency,
263 check_db_migration_consistency_impl, reset_database,
264 };
265 use diesel::migration::{Migration, MigrationSource};
266 use diesel::pg::Pg;
267 use diesel_migrations::MigrationHarness;
268 use sui_pg_db::temp::TempDb;
269
    // Happy path: a freshly reset database contains exactly the migrations
    // this binary embeds, so the consistency check must pass.
    #[tokio::test]
    async fn db_migration_consistency_smoke_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();
        check_db_migration_consistency(&mut pool.get().await.unwrap())
            .await
            .unwrap();
    }
292
    // After reverting the most recent migration, the database no longer
    // contains every embedded migration, so the check must fail; once the
    // pending migrations are re-applied, it must pass again.
    #[tokio::test]
    async fn db_migration_consistency_non_prefix_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();
        let mut connection = pool.get().await.unwrap();

        // `revert_migration` comes from the synchronous MigrationHarness, so
        // wrap a dedicated async connection and run it on a blocking thread.
        let mut sync_connection_wrapper =
            diesel_async::async_connection_wrapper::AsyncConnectionWrapper::<Connection>::from(
                pool.dedicated_connection().await.unwrap(),
            );

        tokio::task::spawn_blocking(move || {
            sync_connection_wrapper
                .revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                .unwrap();
        })
        .await
        .unwrap();
        // One embedded migration is now unapplied => check must error.
        assert!(
            check_db_migration_consistency(&mut connection)
                .await
                .is_err()
        );

        pool.dedicated_connection()
            .await
            .unwrap()
            .run_pending_migrations(MIGRATIONS)
            .await
            .unwrap();
        check_db_migration_consistency(&mut connection)
            .await
            .unwrap();
    }
342
    // The check only requires the local migrations to be a subset of the
    // applied ones: dropping the last local migration must still pass.
    #[tokio::test]
    async fn db_migration_consistency_prefix_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();

        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
        let mut local_migrations: Vec<_> = migrations.iter().map(|m| m.name().version()).collect();
        // Drop the newest local migration; the database still has it applied.
        local_migrations.pop();
        check_db_migration_consistency_impl(&mut pool.get().await.unwrap(), local_migrations)
            .await
            .unwrap();
    }
369
    // Removing a migration from the middle of the local list still leaves a
    // subset of the applied migrations, so the check must pass.
    #[tokio::test]
    async fn db_migration_consistency_subset_test() {
        let database = TempDb::new().unwrap();
        let pool = ConnectionPool::new(
            database.database().url().to_owned(),
            ConnectionPoolConfig {
                pool_size: 2,
                ..Default::default()
            },
        )
        .await
        .unwrap();

        reset_database(pool.dedicated_connection().await.unwrap())
            .await
            .unwrap();

        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
        let mut local_migrations: Vec<_> = migrations.iter().map(|m| m.name().version()).collect();
        // Remove a migration from the middle (not just the tail).
        local_migrations.remove(2);

        check_db_migration_consistency_impl(&mut pool.get().await.unwrap(), local_migrations)
            .await
            .unwrap();
    }
397
    // Sanity check that TempDb spins up a reachable Postgres instance we can
    // open a dedicated connection to and run a trivial query against.
    #[tokio::test]
    async fn temp_db_smoketest() {
        use crate::database::Connection;
        use diesel_async::RunQueryDsl;
        use sui_pg_db::temp::TempDb;

        telemetry_subscribers::init_for_testing();

        let db = TempDb::new().unwrap();
        let url = db.database().url();
        println!("url: {}", url.as_str());
        let mut connection = Connection::dedicated(url).await.unwrap();

        // Any catalog query proves the connection works end to end.
        let resp = diesel::sql_query("SELECT datname FROM pg_database")
            .execute(&mut connection)
            .await
            .unwrap();
        println!("resp: {:?}", resp);
    }
418}