1use std::ops::Deref;
5use std::ops::DerefMut;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::anyhow;
11use diesel::ConnectionError;
12use diesel::migration::Migration;
13use diesel::migration::MigrationSource;
14use diesel::migration::MigrationVersion;
15use diesel::pg::Pg;
16use diesel_async::RunQueryDsl;
17use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
18use diesel_async::pooled_connection::AsyncDieselConnectionManager;
19use diesel_async::pooled_connection::ManagerConfig;
20use diesel_async::pooled_connection::bb8::Pool;
21use diesel_async::pooled_connection::bb8::PooledConnection;
22use diesel_migrations::EmbeddedMigrations;
23use diesel_migrations::embed_migrations;
24use futures::FutureExt;
25use prometheus::Registry;
26use tracing::info;
27use url::Url;
28
29use crate::tls::AsyncPgConnectionWithId;
30use crate::tls::build_tls_config;
31use crate::tls::establish_tls_connection;
32
33mod metrics;
34mod model;
35pub mod query;
36pub mod schema;
37pub mod store;
38pub mod temp;
39mod tls;
40
41use crate::metrics::PoolMetrics;
42pub use sui_field_count::FieldCount;
43pub use sui_sql_macro::sql;
44
/// Migrations bundled with this crate, embedded from its `migrations` directory at compile time.
///
/// `merge_migrations` always applies these first, ahead of any caller-supplied migrations.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
46
// Connection-pool and TLS settings used to build a `Db`, exposed as command-line flags.
// (Plain `//` on purpose: a `///` here could become clap `about` text when flattened.)
#[derive(clap::Args, Debug, Clone)]
pub struct DbArgs {
    /// Maximum number of connections the pool may hold.
    #[arg(long, default_value_t = Self::default().db_connection_pool_size)]
    pub db_connection_pool_size: u32,

    /// Time, in milliseconds, to wait for a connection from the pool before failing.
    #[arg(long, default_value_t = Self::default().db_connection_timeout_ms)]
    pub db_connection_timeout_ms: u64,

    /// Statement timeout, in milliseconds, set on every new connection. When omitted, no
    /// timeout is set and the server's default applies.
    #[arg(long)]
    pub db_statement_timeout_ms: Option<u64>,

    /// Verify the database server's TLS certificate.
    // NOTE(review): exact semantics are defined by `build_tls_config` in `crate::tls` — confirm.
    #[arg(long)]
    pub tls_verify_cert: bool,

    /// Path to a CA certificate to use for the TLS connection.
    // NOTE(review): forwarded to `build_tls_config`; presumably used as an extra trust root.
    #[arg(long)]
    pub tls_ca_cert_path: Option<PathBuf>,
}
70
/// A pool of Postgres connections, optionally instrumented with [`PoolMetrics`].
///
/// Build one with `Db::for_write` or `Db::for_read`, then check connections out with
/// `Db::connect`.
#[derive(Clone)]
pub struct Db {
    // Connections are configured by the custom setup in `pool` (TLS, statement timeout,
    // and optionally read-only transactions).
    pool: Pool<AsyncPgConnectionWithId>,
    // Set by `register_metrics`; while `None`, `connect` records no metrics.
    pool_metrics: Option<Arc<PoolMetrics>>,
}
76
/// Records the outcome of one connection request against [`PoolMetrics`].
///
/// The guard holds `Some(metrics)` while the request is outstanding. `acquired` and
/// `unacquired_error` consume the guard and clear it. If the guard is dropped while still
/// holding metrics (e.g. the `connect` future was dropped mid-wait), `Drop` counts the request
/// as canceled.
struct CancelGuard<'m>(Option<&'m PoolMetrics>);
82
/// A connection checked out of a [`Db`] pool.
///
/// It dereferences to the underlying pooled connection, so `&mut conn` can be passed directly
/// to diesel query methods.
pub struct Connection<'a>(PooledConnection<'a, AsyncPgConnectionWithId>);
85
86impl DbArgs {
87 pub fn connection_timeout(&self) -> Duration {
88 Duration::from_millis(self.db_connection_timeout_ms)
89 }
90
91 pub fn statement_timeout(&self) -> Option<Duration> {
92 self.db_statement_timeout_ms.map(Duration::from_millis)
93 }
94}
95
96impl Db {
97 pub async fn for_write(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
100 Self::new(database_url, config, false).await
101 }
102
103 pub async fn for_read(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
107 Self::new(database_url, config, true).await
108 }
109
110 async fn new(database_url: Url, db_args: DbArgs, read_only: bool) -> anyhow::Result<Self> {
111 Ok(Db {
112 pool: pool(database_url, db_args, read_only).await?,
113 pool_metrics: None,
114 })
115 }
116
117 pub fn register_metrics(
118 mut self,
119 prefix: Option<&str>,
120 registry: &Registry,
121 ) -> anyhow::Result<Self> {
122 let pool_metrics = PoolMetrics::new(prefix, registry)?;
123 self.pool_metrics = Some(pool_metrics);
124 Ok(self)
125 }
126
127 pub async fn connect(&self) -> anyhow::Result<Connection<'_>> {
130 if let Some(pool_metrics) = &self.pool_metrics {
131 let guard = CancelGuard::request(pool_metrics);
132 match self.pool.get().await {
133 Ok(c) => {
134 guard.acquired();
135 Ok(Connection(c))
136 }
137 Err(e) => {
138 guard.unacquired_error();
139 Err(e.into())
140 }
141 }
142 } else {
143 Ok(Connection(self.pool.get().await?))
144 }
145 }
146
147 pub fn state(&self) -> bb8::State {
149 self.pool.state()
150 }
151
152 pub fn pool_metrics(&self) -> Option<&PoolMetrics> {
153 self.pool_metrics.as_ref().map(|p| p.as_ref())
154 }
155
156 async fn clear_database(&self) -> anyhow::Result<()> {
157 info!("Clearing the database...");
158 let mut conn = self.connect().await?;
159 let drop_all_tables = "
160 DO $$ DECLARE
161 r RECORD;
162 BEGIN
163 FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
164 LOOP
165 EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
166 END LOOP;
167 END $$;";
168 diesel::sql_query(drop_all_tables)
169 .execute(&mut conn)
170 .await?;
171 info!("Dropped all tables.");
172
173 let drop_all_procedures = "
174 DO $$ DECLARE
175 r RECORD;
176 BEGIN
177 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
178 FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
179 WHERE ns.nspname = 'public' AND prokind = 'p')
180 LOOP
181 EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
182 END LOOP;
183 END $$;";
184 diesel::sql_query(drop_all_procedures)
185 .execute(&mut conn)
186 .await?;
187 info!("Dropped all procedures.");
188
189 let drop_all_functions = "
190 DO $$ DECLARE
191 r RECORD;
192 BEGIN
193 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
194 FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
195 WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
196 LOOP
197 EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
198 END LOOP;
199 END $$;";
200 diesel::sql_query(drop_all_functions)
201 .execute(&mut conn)
202 .await?;
203 info!("Database cleared.");
204 Ok(())
205 }
206
207 pub async fn run_migrations(
210 &self,
211 migrations: Option<&'static EmbeddedMigrations>,
212 ) -> anyhow::Result<Vec<MigrationVersion<'static>>> {
213 use diesel_migrations::MigrationHarness;
214
215 let merged_migrations = merge_migrations(migrations);
216
217 info!("Running migrations ...");
218 let conn = self.pool.dedicated_connection().await?;
219 let mut wrapper: AsyncConnectionWrapper<AsyncPgConnectionWithId> =
220 AsyncConnectionWrapper::from(conn);
221
222 let finished_migrations = tokio::task::spawn_blocking(move || {
223 wrapper
224 .run_pending_migrations(merged_migrations)
225 .map(|versions| versions.iter().map(MigrationVersion::as_owned).collect())
226 })
227 .await?
228 .map_err(|e| anyhow!("Failed to run migrations: {:?}", e))?;
229
230 info!("Migrations complete.");
231 Ok(finished_migrations)
232 }
233}
234
235impl<'m> CancelGuard<'m> {
236 fn request(stats: &'m PoolMetrics) -> Self {
237 stats.requested.inc();
238 Self(Some(stats))
239 }
240
241 fn acquired(mut self) {
242 if let Some(m) = self.0.take() {
243 m.acquired.inc()
244 }
245 }
246
247 fn unacquired_error(mut self) {
248 if let Some(m) = self.0.take() {
249 m.unacquired_error.inc()
250 }
251 }
252}
253
254impl Default for DbArgs {
255 fn default() -> Self {
256 Self {
257 db_connection_pool_size: 100,
258 db_connection_timeout_ms: 60_000,
259 db_statement_timeout_ms: None,
260 tls_verify_cert: false,
261 tls_ca_cert_path: None,
262 }
263 }
264}
265
266impl<'m> Drop for CancelGuard<'m> {
267 fn drop(&mut self) {
268 if let Some(m) = self.0.take() {
269 m.unacquired_canceled.inc()
270 }
271 }
272}
273
274pub async fn reset_database(
276 database_url: Url,
277 db_args: DbArgs,
278 migrations: Option<&'static EmbeddedMigrations>,
279) -> anyhow::Result<()> {
280 let db = Db::for_write(database_url, db_args).await?;
281 db.clear_database().await?;
282 if let Some(migrations) = migrations {
283 db.run_migrations(Some(migrations)).await?;
284 }
285
286 Ok(())
287}
288
289impl<'a> Deref for Connection<'a> {
290 type Target = PooledConnection<'a, AsyncPgConnectionWithId>;
291
292 fn deref(&self) -> &Self::Target {
293 &self.0
294 }
295}
296
297impl DerefMut for Connection<'_> {
298 fn deref_mut(&mut self) -> &mut Self::Target {
299 &mut self.0
300 }
301}
302
303async fn pool(
304 database_url: Url,
305 args: DbArgs,
306 read_only: bool,
307) -> anyhow::Result<Pool<AsyncPgConnectionWithId>> {
308 let statement_timeout = args.statement_timeout();
309
310 let tls_config = build_tls_config(args.tls_verify_cert, args.tls_ca_cert_path.clone())?;
312
313 let mut config = ManagerConfig::default();
314
315 config.custom_setup = Box::new(move |url| {
316 let tls_config = tls_config.clone();
317
318 async move {
319 let mut conn = establish_tls_connection(url, tls_config).await?;
320
321 if let Some(timeout) = statement_timeout {
322 diesel::sql_query(format!("SET statement_timeout = {}", timeout.as_millis()))
323 .execute(&mut conn)
324 .await
325 .map_err(ConnectionError::CouldntSetupConfiguration)?;
326 }
327
328 if read_only {
329 diesel::sql_query("SET default_transaction_read_only = 'on'")
330 .execute(&mut conn)
331 .await
332 .map_err(ConnectionError::CouldntSetupConfiguration)?;
333 }
334
335 Ok(conn)
336 }
337 .boxed()
338 });
339
340 let manager = AsyncDieselConnectionManager::new_with_config(database_url.as_str(), config);
341
342 Ok(Pool::builder()
343 .max_size(args.db_connection_pool_size)
344 .connection_timeout(args.connection_timeout())
345 .build(manager)
346 .await?)
347}
348
349pub fn merge_migrations(
352 migrations: Option<&'static EmbeddedMigrations>,
353) -> impl MigrationSource<Pg> + Send + Sync + 'static {
354 struct Migrations(Option<&'static EmbeddedMigrations>);
355 impl MigrationSource<Pg> for Migrations {
356 fn migrations(&self) -> diesel::migration::Result<Vec<Box<dyn Migration<Pg>>>> {
357 let mut migrations = MIGRATIONS.migrations()?;
358 if let Some(more_migrations) = self.0 {
359 migrations.extend(more_migrations.migrations()?);
360 }
361 Ok(migrations)
362 }
363 }
364
365 Migrations(migrations)
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::temp::TempDb;
    use anyhow::Error;
    use diesel::prelude::QueryableByName;
    use tokio::spawn;
    use tokio::time::timeout;

    /// A single-connection, metrics-instrumented `Db` backed by a throwaway database.
    struct MetricTest {
        db: Db,
        // Held only so the temporary database stays alive for the test
        // (presumably torn down when dropped).
        _temp_db: TempDb,
        // Held so the registry the metrics were registered in outlives the test.
        _registry: Registry,
    }

    impl MetricTest {
        /// Create a read-only `Db` with a pool of exactly one connection, so a second
        /// concurrent `connect` must wait (or time out after `db_connection_timeout`).
        async fn new(db_connection_timeout: Duration) -> Arc<Self> {
            let temp_db = TempDb::new().unwrap();
            let url = temp_db.database().url();
            let db_args = DbArgs {
                db_connection_pool_size: 1,
                db_connection_timeout_ms: db_connection_timeout.as_millis() as u64,
                ..Default::default()
            };
            let registry = Registry::new();
            let db = Db::for_read(url.clone(), db_args)
                .await
                .unwrap()
                .register_metrics(None, &registry)
                .unwrap();
            Arc::new(Self {
                db,
                _temp_db: temp_db,
                _registry: registry,
            })
        }

        fn pool_metrics(&self) -> &PoolMetrics {
            self.db.pool_metrics().unwrap()
        }

        /// Check out a connection and hold it for `duration` via `pg_sleep`.
        async fn select_sleep(&self, duration: Duration) -> Result<(), Error> {
            let mut conn = self.db.connect().await?;
            let duration_s = duration.as_secs_f64();
            diesel::sql_query(format!("SELECT pg_sleep({duration_s});"))
                .execute(&mut conn)
                .await?;
            Ok(())
        }
    }

    #[tokio::test]
    async fn temp_db_smoketest() {
        telemetry_subscribers::init_for_testing();
        let db = TempDb::new().unwrap();
        let url = db.database().url();

        info!(%url);
        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();

        let resp = diesel::sql_query("SELECT datname FROM pg_database")
            .execute(&mut conn)
            .await
            .unwrap();

        info!(?resp);
    }

    #[derive(Debug, QueryableByName)]
    struct CountResult {
        #[diesel(sql_type = diesel::sql_types::BigInt)]
        cnt: i64,
    }

    #[tokio::test]
    async fn test_reset_database_skip_migrations() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();
        diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
            .execute(&mut conn)
            .await
            .unwrap();
        let cnt = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result::<CountResult>(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 1);

        // With `None`, the reset must drop the table and apply no migrations at all.
        reset_database(url.clone(), DbArgs::default(), None)
            .await
            .unwrap();

        let mut conn = db.connect().await.unwrap();
        let cnt: CountResult = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 0);
    }

    #[tokio::test]
    async fn test_read_only() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let writer = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let reader = Db::for_read(url.clone(), DbArgs::default()).await.unwrap();

        {
            // The writer can create tables.
            let mut conn = writer.connect().await.unwrap();
            diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            // The reader cannot write.
            let mut conn = reader.connect().await.unwrap();
            let result = diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await;
            assert!(result.is_err());
        }

        {
            // The reader can read, and the failed insert left nothing behind.
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 0);
        }

        {
            // The writer can write.
            let mut conn = writer.connect().await.unwrap();
            diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            // The reader sees the writer's changes.
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 1);
        }
    }

    #[tokio::test]
    async fn test_statement_timeout() {
        let temp_db = TempDb::new().unwrap();
        let url = temp_db.database().url();

        let reader = Db::for_read(
            url.clone(),
            DbArgs {
                db_statement_timeout_ms: Some(200),
                ..DbArgs::default()
            },
        )
        .await
        .unwrap();

        {
            // A fast query completes within the timeout.
            let mut conn = reader.connect().await.unwrap();
            let cnt: CountResult = diesel::sql_query("SELECT 1::BIGINT AS cnt")
                .get_result(&mut conn)
                .await
                .unwrap();

            assert_eq!(cnt.cnt, 1);
        }

        {
            // A 2s sleep exceeds the 200ms statement timeout.
            let mut conn = reader.connect().await.unwrap();
            diesel::sql_query("SELECT PG_SLEEP(2), 1::BIGINT AS cnt")
                .get_result::<CountResult>(&mut conn)
                .await
                .expect_err("This request should fail because of a timeout");
        }
    }

    #[tokio::test]
    async fn test_unacquired_error() {
        // Task 1 holds the only connection for longer than the pool's checkout timeout, so
        // task 2's `connect` fails with a pool error.
        let db_connection_timeout = Duration::from_millis(500);
        let metric_test = MetricTest::new(db_connection_timeout).await;

        let metric_test_clone = metric_test.clone();
        let task1 = spawn(async move {
            metric_test_clone
                .select_sleep(db_connection_timeout + Duration::from_millis(500))
                .await
        });
        let metric_test_clone = metric_test.clone();
        let task2 = spawn(async move { metric_test_clone.select_sleep(Duration::ZERO).await });
        assert!(task1.await.unwrap().is_ok());
        assert!(task2.await.unwrap().is_err());

        let PoolMetrics {
            requested,
            acquired,
            unacquired_error,
            unacquired_canceled,
        } = metric_test.pool_metrics();
        assert_eq!(requested.get(), 2);
        assert_eq!(acquired.get(), 1);
        assert_eq!(unacquired_error.get(), 1);
        assert_eq!(unacquired_canceled.get(), 0);
    }

    #[tokio::test]
    async fn test_unacquired_canceled() {
        // Timeline: task 2 is abandoned at `task_timeout`. At that point task 1 still holds the
        // only connection (it sleeps for `sleep_timeout`), and the pool itself would only give
        // up later, at `db_connection_timeout`.
        let task_timeout = Duration::from_millis(500);
        let sleep_timeout = task_timeout + Duration::from_millis(500);
        let db_connection_timeout = sleep_timeout + Duration::from_millis(500);
        let metric_test = MetricTest::new(db_connection_timeout).await;

        let metric_test_clone = metric_test.clone();
        let task1 = spawn(async move { metric_test_clone.select_sleep(sleep_timeout).await });
        let metric_test_clone = metric_test.clone();
        let task2 = spawn(async move {
            // Dropping the pending `connect` future on timeout must count as a cancellation.
            timeout(task_timeout, metric_test_clone.select_sleep(Duration::ZERO)).await
        });
        assert!(task1.await.unwrap().is_ok());
        assert!(task2.await.unwrap().is_err());

        let PoolMetrics {
            requested,
            acquired,
            unacquired_error,
            unacquired_canceled,
        } = metric_test.pool_metrics();
        assert_eq!(requested.get(), 2);
        assert_eq!(acquired.get(), 1);
        assert_eq!(unacquired_error.get(), 0);
        assert_eq!(unacquired_canceled.get(), 1);
    }
}