1use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::time::Duration;
7
8use anyhow::anyhow;
9use diesel::ConnectionError;
10use diesel::migration::{Migration, MigrationSource, MigrationVersion};
11use diesel::pg::Pg;
12use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
13use diesel_async::pooled_connection::ManagerConfig;
14use diesel_async::{
15 AsyncPgConnection, RunQueryDsl,
16 pooled_connection::{
17 AsyncDieselConnectionManager,
18 bb8::{Pool, PooledConnection},
19 },
20};
21use futures::FutureExt;
22use tracing::info;
23use url::Url;
24
25use tls::{build_tls_config, establish_tls_connection};
26
27mod model;
28mod tls;
29
30pub use sui_field_count::FieldCount;
31pub use sui_sql_macro::sql;
32
33pub mod query;
34pub mod schema;
35pub mod store;
36pub mod temp;
37
38use diesel_migrations::{EmbeddedMigrations, embed_migrations};
39
/// Migrations embedded at compile time from this crate's `migrations` directory. They are
/// always applied first by [`merge_migrations`], ahead of any caller-supplied migrations.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
41
/// Configuration for building a [`Db`] connection pool to a Postgres database.
#[derive(clap::Args, Debug, Clone)]
pub struct DbArgs {
    /// Maximum number of connections the pool may hold.
    #[arg(long, default_value_t = Self::default().db_connection_pool_size)]
    pub db_connection_pool_size: u32,

    /// How long, in milliseconds, to wait for a connection from the pool before giving up.
    #[arg(long, default_value_t = Self::default().db_connection_timeout_ms)]
    pub db_connection_timeout_ms: u64,

    /// If set, every new connection runs `SET statement_timeout` with this many milliseconds.
    #[arg(long)]
    pub db_statement_timeout_ms: Option<u64>,

    /// Certificate-verification flag passed to the TLS configuration.
    // NOTE(review): presumably enables verification of the server's TLS certificate —
    // confirm against `tls::build_tls_config`.
    #[arg(long)]
    pub tls_verify_cert: bool,

    /// Optional path to a CA certificate, passed to the TLS configuration.
    // NOTE(review): presumably used as the trust root when `tls_verify_cert` is set —
    // confirm against `tls::build_tls_config`.
    #[arg(long)]
    pub tls_ca_cert_path: Option<PathBuf>,
}
65
/// A pool of connections to a Postgres database, built by [`Db::for_write`] or
/// [`Db::for_read`].
#[derive(Clone)]
pub struct Db(Pool<AsyncPgConnection>);
68
/// A connection checked out of a [`Db`] pool (see [`Db::connect`]); dereferences to the
/// underlying [`PooledConnection`].
pub struct Connection<'a>(PooledConnection<'a, AsyncPgConnection>);
71
72impl DbArgs {
73 pub fn connection_timeout(&self) -> Duration {
74 Duration::from_millis(self.db_connection_timeout_ms)
75 }
76
77 pub fn statement_timeout(&self) -> Option<Duration> {
78 self.db_statement_timeout_ms.map(Duration::from_millis)
79 }
80}
81
82impl Db {
83 pub async fn for_write(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
86 Ok(Self(pool(database_url, config, false).await?))
87 }
88
89 pub async fn for_read(database_url: Url, config: DbArgs) -> anyhow::Result<Self> {
93 Ok(Self(pool(database_url, config, true).await?))
94 }
95
96 pub async fn connect(&self) -> anyhow::Result<Connection<'_>> {
99 Ok(Connection(self.0.get().await?))
100 }
101
102 pub fn state(&self) -> bb8::State {
104 self.0.state()
105 }
106
107 async fn clear_database(&self) -> anyhow::Result<()> {
108 info!("Clearing the database...");
109 let mut conn = self.connect().await?;
110 let drop_all_tables = "
111 DO $$ DECLARE
112 r RECORD;
113 BEGIN
114 FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
115 LOOP
116 EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
117 END LOOP;
118 END $$;";
119 diesel::sql_query(drop_all_tables)
120 .execute(&mut conn)
121 .await?;
122 info!("Dropped all tables.");
123
124 let drop_all_procedures = "
125 DO $$ DECLARE
126 r RECORD;
127 BEGIN
128 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
129 FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
130 WHERE ns.nspname = 'public' AND prokind = 'p')
131 LOOP
132 EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
133 END LOOP;
134 END $$;";
135 diesel::sql_query(drop_all_procedures)
136 .execute(&mut conn)
137 .await?;
138 info!("Dropped all procedures.");
139
140 let drop_all_functions = "
141 DO $$ DECLARE
142 r RECORD;
143 BEGIN
144 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
145 FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
146 WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
147 LOOP
148 EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
149 END LOOP;
150 END $$;";
151 diesel::sql_query(drop_all_functions)
152 .execute(&mut conn)
153 .await?;
154 info!("Database cleared.");
155 Ok(())
156 }
157
158 pub async fn run_migrations(
161 &self,
162 migrations: Option<&'static EmbeddedMigrations>,
163 ) -> anyhow::Result<Vec<MigrationVersion<'static>>> {
164 use diesel_migrations::MigrationHarness;
165
166 let merged_migrations = merge_migrations(migrations);
167
168 info!("Running migrations ...");
169 let conn = self.0.dedicated_connection().await?;
170 let mut wrapper: AsyncConnectionWrapper<AsyncPgConnection> =
171 diesel_async::async_connection_wrapper::AsyncConnectionWrapper::from(conn);
172
173 let finished_migrations = tokio::task::spawn_blocking(move || {
174 wrapper
175 .run_pending_migrations(merged_migrations)
176 .map(|versions| versions.iter().map(MigrationVersion::as_owned).collect())
177 })
178 .await?
179 .map_err(|e| anyhow!("Failed to run migrations: {:?}", e))?;
180
181 info!("Migrations complete.");
182 Ok(finished_migrations)
183 }
184}
185
186impl Default for DbArgs {
187 fn default() -> Self {
188 Self {
189 db_connection_pool_size: 100,
190 db_connection_timeout_ms: 60_000,
191 db_statement_timeout_ms: None,
192 tls_verify_cert: false,
193 tls_ca_cert_path: None,
194 }
195 }
196}
197
198pub async fn reset_database(
200 database_url: Url,
201 db_config: DbArgs,
202 migrations: Option<&'static EmbeddedMigrations>,
203) -> anyhow::Result<()> {
204 let db = Db::for_write(database_url, db_config).await?;
205 db.clear_database().await?;
206 if let Some(migrations) = migrations {
207 db.run_migrations(Some(migrations)).await?;
208 }
209
210 Ok(())
211}
212
213impl<'a> Deref for Connection<'a> {
214 type Target = PooledConnection<'a, AsyncPgConnection>;
215
216 fn deref(&self) -> &Self::Target {
217 &self.0
218 }
219}
220
221impl DerefMut for Connection<'_> {
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227async fn pool(
228 database_url: Url,
229 args: DbArgs,
230 read_only: bool,
231) -> anyhow::Result<Pool<AsyncPgConnection>> {
232 let statement_timeout = args.statement_timeout();
233
234 let tls_config = build_tls_config(args.tls_verify_cert, args.tls_ca_cert_path.clone())?;
236
237 let mut config = ManagerConfig::default();
238
239 config.custom_setup = Box::new(move |url| {
240 let tls_config = tls_config.clone();
241
242 async move {
243 let mut conn = establish_tls_connection(url, tls_config).await?;
244
245 if let Some(timeout) = statement_timeout {
246 diesel::sql_query(format!("SET statement_timeout = {}", timeout.as_millis()))
247 .execute(&mut conn)
248 .await
249 .map_err(ConnectionError::CouldntSetupConfiguration)?;
250 }
251
252 if read_only {
253 diesel::sql_query("SET default_transaction_read_only = 'on'")
254 .execute(&mut conn)
255 .await
256 .map_err(ConnectionError::CouldntSetupConfiguration)?;
257 }
258
259 Ok(conn)
260 }
261 .boxed()
262 });
263
264 let manager = AsyncDieselConnectionManager::new_with_config(database_url.as_str(), config);
265
266 Ok(Pool::builder()
267 .max_size(args.db_connection_pool_size)
268 .connection_timeout(args.connection_timeout())
269 .build(manager)
270 .await?)
271}
272
273pub fn merge_migrations(
276 migrations: Option<&'static EmbeddedMigrations>,
277) -> impl MigrationSource<Pg> + Send + Sync + 'static {
278 struct Migrations(Option<&'static EmbeddedMigrations>);
279 impl MigrationSource<Pg> for Migrations {
280 fn migrations(&self) -> diesel::migration::Result<Vec<Box<dyn Migration<Pg>>>> {
281 let mut migrations = MIGRATIONS.migrations()?;
282 if let Some(more_migrations) = self.0 {
283 migrations.extend(more_migrations.migrations()?);
284 }
285 Ok(migrations)
286 }
287 }
288
289 Migrations(migrations)
290}
#[cfg(test)]
mod tests {
    use super::*;
    use diesel::prelude::QueryableByName;
    use diesel_async::RunQueryDsl;

    /// Smoke test: a write pool can be built against a fresh temporary database and can run a
    /// query.
    #[tokio::test]
    async fn temp_db_smoketest() {
        telemetry_subscribers::init_for_testing();
        let db = temp::TempDb::new().unwrap();
        let url = db.database().url();

        info!(%url);
        // Shadowing (not dropping) keeps the `TempDb` alive until the end of the test.
        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();

        // `execute` yields a count rather than the rows themselves; only that count is logged.
        let resp = diesel::sql_query("SELECT datname FROM pg_database")
            .execute(&mut conn)
            .await
            .unwrap();

        info!(?resp);
    }

    /// Row shape for queries that select a single BIGINT column aliased as `cnt`.
    #[derive(Debug, QueryableByName)]
    struct CountResult {
        #[diesel(sql_type = diesel::sql_types::BigInt)]
        cnt: i64,
    }

    /// `reset_database` with `None` migrations drops existing tables and leaves the schema
    /// empty.
    #[tokio::test]
    async fn test_reset_database_skip_migrations() {
        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();

        let db = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let mut conn = db.connect().await.unwrap();
        diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
            .execute(&mut conn)
            .await
            .unwrap();
        // Precondition: the table exists before the reset.
        let cnt = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result::<CountResult>(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 1);

        // `reset_database` builds its own pool; `db` stays usable afterwards.
        reset_database(url.clone(), DbArgs::default(), None)
            .await
            .unwrap();

        let mut conn = db.connect().await.unwrap();
        let cnt: CountResult = diesel::sql_query(
            "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_name = 'test_table'",
        )
        .get_result(&mut conn)
        .await
        .unwrap();
        assert_eq!(cnt.cnt, 0);
    }

    /// A read pool rejects writes but observes data written through a write pool.
    #[tokio::test]
    async fn test_read_only() {
        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();

        let writer = Db::for_write(url.clone(), DbArgs::default()).await.unwrap();
        let reader = Db::for_read(url.clone(), DbArgs::default()).await.unwrap();

        // Each scope below drops its connection when it ends.
        {
            let mut conn = writer.connect().await.unwrap();
            // Schema setup goes through the writer.
            diesel::sql_query("CREATE TABLE test_table (id INTEGER PRIMARY KEY)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            let mut conn = reader.connect().await.unwrap();
            // Writes through the reader must fail.
            let result = diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await;
            assert!(result.is_err());
        }

        {
            let mut conn = reader.connect().await.unwrap();
            // The failed insert left no row behind.
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 0);
        }

        {
            let mut conn = writer.connect().await.unwrap();
            // The same insert succeeds through the writer.
            diesel::sql_query("INSERT INTO test_table (id) VALUES (1)")
                .execute(&mut conn)
                .await
                .unwrap();
        }

        {
            let mut conn = reader.connect().await.unwrap();
            // The reader can see the row the writer inserted.
            let cnt: CountResult = diesel::sql_query("SELECT COUNT(*) as cnt FROM test_table")
                .get_result(&mut conn)
                .await
                .unwrap();
            assert_eq!(cnt.cnt, 1);
        }
    }

    /// With a 200ms statement timeout, a quick query succeeds while a 2s `PG_SLEEP` is aborted.
    #[tokio::test]
    async fn test_statement_timeout() {
        let temp_db = temp::TempDb::new().unwrap();
        let url = temp_db.database().url();

        let reader = Db::for_read(
            url.clone(),
            DbArgs {
                db_statement_timeout_ms: Some(200),
                ..DbArgs::default()
            },
        )
        .await
        .unwrap();

        {
            let mut conn = reader.connect().await.unwrap();
            // Well within the timeout.
            let cnt: CountResult = diesel::sql_query("SELECT 1::BIGINT AS cnt")
                .get_result(&mut conn)
                .await
                .unwrap();

            assert_eq!(cnt.cnt, 1);
        }

        {
            let mut conn = reader.connect().await.unwrap();
            // Sleeps 2s, ten times the configured 200ms timeout.
            diesel::sql_query("SELECT PG_SLEEP(2), 1::BIGINT AS cnt")
                .get_result::<CountResult>(&mut conn)
                .await
                .expect_err("This request should fail because of a timeout");
        }
    }
}