1use anyhow::anyhow;
5use anyhow::Context;
6use anyhow::Result;
7use std::fmt::Debug;
8use std::fs::OpenOptions;
9use std::process::ExitStatus;
10use std::{
11 path::{Path, PathBuf},
12 process::{Child, Command},
13 time::{Duration, Instant},
14};
15use tracing::{event_enabled, info, trace};
16use url::Url;
17
/// A throwaway postgres cluster living in a temporary directory.
///
/// [`TempDb::new`] creates the directory, runs `initdb` in it and starts a
/// server on a free local port.
pub struct TempDb {
    /// The running database.
    ///
    /// Declared before `dir` on purpose: struct fields are dropped in
    /// declaration order, so the postgres process (owned inside
    /// `LocalDatabase`) is shut down before the temporary directory it uses is
    /// removed.
    database: LocalDatabase,

    /// Owns the temporary directory that serves as the postgres data directory
    /// (and holds the server's `stdout`/`stderr` log files). `tempfile::TempDir`
    /// deletes it when dropped.
    // NOTE(review): must remain after `database` — see the drop-order note above.
    dir: tempfile::TempDir,
}
30
31pub struct LocalDatabase {
35 dir: PathBuf,
36 port: u16,
37 url: Url,
38 process: Option<PostgresProcess>,
39}
40
/// A spawned `postgres` server process together with its data directory.
///
/// Dropping it shuts the server down (see its `Drop` impl).
#[derive(Debug)]
struct PostgresProcess {
    /// Data directory the server was started with (`postgres -D`). The server's
    /// stdout and stderr are appended to files named `stdout` and `stderr` here.
    dir: PathBuf,
    /// Handle to the child `postgres` process.
    inner: Child,
}
46
/// Why a health check of the local postgres server did not succeed.
#[derive(Debug)]
enum HealthCheckError {
    /// The server process is not running: it exited with the given status, or
    /// (`None`) it was never started.
    NotRunning(Option<ExitStatus>),
    /// The process is alive but `pg_isready` does not report it as ready yet;
    /// callers may retry.
    NotReady,
    /// Any other failure (I/O errors, timeouts, unusable `pg_isready` runs),
    /// with a human-readable description.
    Unknown(String),
}

impl std::fmt::Display for HealthCheckError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotRunning(Some(status)) => write!(f, "Not running - exit status: {status}"),
            Self::NotRunning(None) => f.write_str("Not running - no exit status"),
            Self::NotReady => f.write_str("Not ready"),
            Self::Unknown(msg) => write!(f, "Unknown error: {msg}"),
        }
    }
}

impl std::error::Error for HealthCheckError {}
69
70impl TempDb {
71 pub fn new() -> Result<Self> {
76 let dir = tempfile::TempDir::new()?;
77 let port = get_available_port();
78
79 let database = LocalDatabase::new_initdb(dir.path().to_owned(), port)?;
80
81 Ok(Self { dir, database })
82 }
83
84 pub fn database(&self) -> &LocalDatabase {
85 &self.database
86 }
87
88 pub fn database_mut(&mut self) -> &mut LocalDatabase {
89 &mut self.database
90 }
91
92 pub fn dir(&self) -> &Path {
93 self.dir.path()
94 }
95}
96
97impl LocalDatabase {
98 pub fn new(dir: PathBuf, port: u16) -> Result<Self> {
105 let url = format!(
106 "postgres://postgres:postgrespw@localhost:{port}/{db_name}",
107 db_name = "postgres"
108 )
109 .parse()
110 .unwrap();
111 let mut db = Self {
112 dir,
113 port,
114 url,
115 process: None,
116 };
117 db.start()?;
118 Ok(db)
119 }
120
121 pub fn new_initdb(dir: PathBuf, port: u16) -> Result<Self> {
125 initdb(&dir)?;
126 Self::new(dir, port)
127 }
128
129 pub fn url(&self) -> &Url {
131 &self.url
132 }
133
134 fn start(&mut self) -> Result<()> {
135 if self.process.is_none() {
136 self.process = Some(PostgresProcess::start(self.dir.clone(), self.port)?);
137 self.wait_till_ready().with_context(|| {
138 format!(
139 "Unable to start postgres, check dir {} for logs",
141 self.dir.display(),
142 )
143 })?;
144 }
145
146 Ok(())
147 }
148
149 fn health_check(&mut self) -> Result<(), HealthCheckError> {
150 if let Some(p) = &mut self.process {
151 match p.inner.try_wait() {
152 Ok(Some(status)) => Err(HealthCheckError::NotRunning(Some(status))),
154
155 Ok(None) => pg_isready(self.port),
157
158 Err(e) => Err(HealthCheckError::Unknown(e.to_string())),
160 }
161 } else {
162 Err(HealthCheckError::NotRunning(None))
163 }
164 }
165
166 fn wait_till_ready(&mut self) -> Result<(), HealthCheckError> {
167 let start = Instant::now();
168
169 while start.elapsed() < Duration::from_secs(10) {
170 match self.health_check() {
171 Err(HealthCheckError::NotReady) => {}
172 result => return result,
173 }
174
175 std::thread::sleep(Duration::from_millis(50));
176 }
177
178 Err(HealthCheckError::Unknown(
179 "timeout reached when waiting for service to be ready".to_owned(),
180 ))
181 }
182}
183
184impl PostgresProcess {
185 fn start(dir: PathBuf, port: u16) -> Result<Self> {
186 let child = Command::new("postgres")
187 .arg("-D")
189 .arg(&dir)
190 .args(["-p", &port.to_string()])
192 .args(["-c", "unix_socket_directories="])
194 .stdout(
196 OpenOptions::new()
197 .create(true)
198 .append(true)
199 .open(dir.join("stdout"))?,
200 )
201 .stderr(
202 OpenOptions::new()
203 .create(true)
204 .append(true)
205 .open(dir.join("stderr"))?,
206 )
207 .spawn()
208 .context("command not found: postgres")?;
209
210 Ok(Self { dir, inner: child })
211 }
212
213 fn pg_ctl_stop(&mut self) -> Result<()> {
215 let output = Command::new("pg_ctl")
216 .arg("stop")
217 .arg("-D")
218 .arg(&self.dir)
219 .arg("-mfast")
220 .output()
221 .context("command not found: pg_ctl")?;
222
223 if output.status.success() {
224 Ok(())
225 } else {
226 Err(anyhow!("couldn't shut down postgres"))
227 }
228 }
229
230 fn dump_stdout_stderr(&self) -> Result<(String, String)> {
231 let stdout = std::fs::read_to_string(self.dir.join("stdout"))?;
232 let stderr = std::fs::read_to_string(self.dir.join("stderr"))?;
233
234 Ok((stdout, stderr))
235 }
236}
237
238impl Drop for PostgresProcess {
239 fn drop(&mut self) {
241 info!("dropping postgres");
242 match self.inner.try_wait() {
244 Ok(Some(_)) => {}
246
247 _ => {
249 if self.pg_ctl_stop().is_err() {
250 self.inner.kill().expect("postgres couldn't be killed");
252 }
253 self.inner.wait().unwrap();
254 }
255 }
256
257 if event_enabled!(tracing::Level::TRACE) {
259 if let Ok((stdout, stderr)) = self.dump_stdout_stderr() {
260 trace!("stdout: {stdout}");
261 trace!("stderr: {stderr}");
262 }
263 }
264 }
265}
266
267fn pg_isready(port: u16) -> Result<(), HealthCheckError> {
271 let output = Command::new("pg_isready")
272 .arg("--host=localhost")
273 .arg("-p")
274 .arg(port.to_string())
275 .arg("--username=postgres")
276 .output()
277 .map_err(|e| HealthCheckError::Unknown(format!("command not found: pg_ctl: {e}")))?;
278
279 trace!("pg_isready code: {:?}", output.status.code());
280 trace!("pg_isready output: {}", output.stderr.escape_ascii());
281 trace!("pg_isready output: {}", output.stdout.escape_ascii());
282 if output.status.success() {
283 Ok(())
284 } else {
285 Err(HealthCheckError::NotReady)
286 }
287}
288
289fn initdb(dir: &Path) -> Result<()> {
293 let output = Command::new("initdb")
294 .arg("-D")
295 .arg(dir)
296 .arg("--no-instructions")
297 .arg("--username=postgres")
298 .output()
299 .context("command not found: initdb")?;
300
301 if output.status.success() {
302 Ok(())
303 } else {
304 Err(anyhow!(
305 "unable to initialize database: {:?}",
306 String::from_utf8(output.stderr)
307 ))
308 }
309}
310
311pub fn get_available_port() -> u16 {
315 const MAX_PORT_RETRIES: u32 = 1000;
316
317 for _ in 0..MAX_PORT_RETRIES {
318 if let Ok(port) = get_ephemeral_port() {
319 return port;
320 }
321 }
322
323 panic!("Error: could not find an available port");
324}
325
/// Asks the OS for a free TCP port on 127.0.0.1 and returns its number.
///
/// Binding to port 0 lets the OS pick the port. The function then opens a
/// loopback connection to that port and accepts it before all sockets are
/// closed.
///
/// # Errors
/// Returns any I/O error from binding, querying the local address,
/// connecting, or accepting.
fn get_ephemeral_port() -> std::io::Result<u16> {
    // Port 0: let the OS choose a currently unused port.
    let listener = std::net::TcpListener::bind(("127.0.0.1", 0))?;
    let addr = listener.local_addr()?;

    // NOTE(review): presumably the "reserve an ephemeral port" trick. Closing an
    // established connection leaves the port in TIME_WAIT, which discourages
    // the OS from handing it out again as an ephemeral port while a server can
    // still bind it explicitly. Confirm intent. The locals are dropped in
    // reverse declaration order at the end of the function (the accepted socket
    // first, then the client, then the listener), so do not reorder them or
    // drop them early.
    let _sender = std::net::TcpStream::connect(addr)?;
    let _incoming = listener.accept()?;

    Ok(addr.port())
}