1use std::fmt::Debug;
5use std::fs::OpenOptions;
6use std::path::Path;
7use std::path::PathBuf;
8use std::process::Child;
9use std::process::Command;
10use std::process::ExitStatus;
11use std::time::Duration;
12use std::time::Instant;
13
14use anyhow::Context;
15use anyhow::Result;
16use anyhow::anyhow;
17use tracing::event_enabled;
18use tracing::info;
19use tracing::trace;
20use url::Url;
21
22pub struct TempDb {
24 database: LocalDatabase,
25
26 dir: tempfile::TempDir,
33}
34
35pub struct LocalDatabase {
39 dir: PathBuf,
40 port: u16,
41 url: Url,
42 process: Option<PostgresProcess>,
43}
44
/// A spawned `postgres` child process together with its data directory.
///
/// Dropping this value stops the server (see the `Drop` impl).
#[derive(Debug)]
struct PostgresProcess {
    /// Data directory the server runs from; used for `pg_ctl stop -D` and to
    /// locate the `stdout`/`stderr` log files.
    dir: PathBuf,
    /// Handle to the running `postgres` child.
    inner: Child,
}
50
/// Reasons a health check of the local Postgres server can fail.
#[derive(Debug)]
enum HealthCheckError {
    /// No server process is running. Carries the exit status when the child has
    /// already exited, or `None` when no process was started at all.
    NotRunning(Option<ExitStatus>),
    /// The process is alive but `pg_isready` did not report success yet.
    NotReady,
    /// Any other failure (process status unavailable, helper binary could not be
    /// run, startup timeout), described by a free-form message.
    // NOTE(review): the reason for this `allow` is not visible in this file —
    // the payload is read by the `Display` impl; confirm it is still needed.
    #[allow(unused)]
    Unknown(String),
}
58
59impl std::fmt::Display for HealthCheckError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 HealthCheckError::NotRunning(Some(status)) => {
63 write!(f, "Not running - exit status: {}", status)
64 }
65 HealthCheckError::NotRunning(None) => write!(f, "Not running - no exit status"),
66 HealthCheckError::NotReady => write!(f, "Not ready"),
67 HealthCheckError::Unknown(msg) => write!(f, "Unknown error: {}", msg),
68 }
69 }
70}
71
72impl std::error::Error for HealthCheckError {}
73
74impl TempDb {
75 pub fn new() -> Result<Self> {
80 let dir = tempfile::TempDir::new()?;
81 let port = get_available_port();
82
83 let database = LocalDatabase::new_initdb(dir.path().to_owned(), port)?;
84
85 Ok(Self { dir, database })
86 }
87
88 pub fn database(&self) -> &LocalDatabase {
89 &self.database
90 }
91
92 pub fn database_mut(&mut self) -> &mut LocalDatabase {
93 &mut self.database
94 }
95
96 pub fn dir(&self) -> &Path {
97 self.dir.path()
98 }
99}
100
101impl LocalDatabase {
102 pub fn new(dir: PathBuf, port: u16) -> Result<Self> {
109 let url = format!(
110 "postgres://postgres:postgrespw@localhost:{port}/{db_name}",
111 db_name = "postgres"
112 )
113 .parse()
114 .unwrap();
115 let mut db = Self {
116 dir,
117 port,
118 url,
119 process: None,
120 };
121 db.start()?;
122 Ok(db)
123 }
124
125 pub fn new_initdb(dir: PathBuf, port: u16) -> Result<Self> {
129 initdb(&dir)?;
130 Self::new(dir, port)
131 }
132
133 pub fn url(&self) -> &Url {
135 &self.url
136 }
137
138 fn start(&mut self) -> Result<()> {
139 if self.process.is_none() {
140 self.process = Some(PostgresProcess::start(self.dir.clone(), self.port)?);
141 self.wait_till_ready().with_context(|| {
142 format!(
143 "Unable to start postgres, check dir {} for logs",
145 self.dir.display(),
146 )
147 })?;
148 }
149
150 Ok(())
151 }
152
153 fn health_check(&mut self) -> Result<(), HealthCheckError> {
154 if let Some(p) = &mut self.process {
155 match p.inner.try_wait() {
156 Ok(Some(status)) => Err(HealthCheckError::NotRunning(Some(status))),
158
159 Ok(None) => pg_isready(self.port),
161
162 Err(e) => Err(HealthCheckError::Unknown(e.to_string())),
164 }
165 } else {
166 Err(HealthCheckError::NotRunning(None))
167 }
168 }
169
170 fn wait_till_ready(&mut self) -> Result<(), HealthCheckError> {
171 let start = Instant::now();
172
173 while start.elapsed() < Duration::from_secs(10) {
174 match self.health_check() {
175 Err(HealthCheckError::NotReady) => {}
176 result => return result,
177 }
178
179 std::thread::sleep(Duration::from_millis(50));
180 }
181
182 Err(HealthCheckError::Unknown(
183 "timeout reached when waiting for service to be ready".to_owned(),
184 ))
185 }
186}
187
188impl PostgresProcess {
189 fn start(dir: PathBuf, port: u16) -> Result<Self> {
190 let child = Command::new("postgres")
191 .arg("-D")
193 .arg(&dir)
194 .args(["-p", &port.to_string()])
196 .args(["-c", "unix_socket_directories="])
198 .stdout(
200 OpenOptions::new()
201 .create(true)
202 .append(true)
203 .open(dir.join("stdout"))?,
204 )
205 .stderr(
206 OpenOptions::new()
207 .create(true)
208 .append(true)
209 .open(dir.join("stderr"))?,
210 )
211 .spawn()
212 .context("command not found: postgres")?;
213
214 Ok(Self { dir, inner: child })
215 }
216
217 fn pg_ctl_stop(&mut self) -> Result<()> {
219 let output = Command::new("pg_ctl")
220 .arg("stop")
221 .arg("-D")
222 .arg(&self.dir)
223 .arg("-mfast")
224 .output()
225 .context("command not found: pg_ctl")?;
226
227 if output.status.success() {
228 Ok(())
229 } else {
230 Err(anyhow!("couldn't shut down postgres"))
231 }
232 }
233
234 fn dump_stdout_stderr(&self) -> Result<(String, String)> {
235 let stdout = std::fs::read_to_string(self.dir.join("stdout"))?;
236 let stderr = std::fs::read_to_string(self.dir.join("stderr"))?;
237
238 Ok((stdout, stderr))
239 }
240}
241
242impl Drop for PostgresProcess {
243 fn drop(&mut self) {
245 info!("dropping postgres");
246 match self.inner.try_wait() {
248 Ok(Some(_)) => {}
250
251 _ => {
253 if self.pg_ctl_stop().is_err() {
254 self.inner.kill().expect("postgres couldn't be killed");
256 }
257 self.inner.wait().unwrap();
258 }
259 }
260
261 if event_enabled!(tracing::Level::TRACE)
263 && let Ok((stdout, stderr)) = self.dump_stdout_stderr()
264 {
265 trace!("stdout: {stdout}");
266 trace!("stderr: {stderr}");
267 }
268 }
269}
270
271fn pg_isready(port: u16) -> Result<(), HealthCheckError> {
275 let output = Command::new("pg_isready")
276 .arg("--host=localhost")
277 .arg("-p")
278 .arg(port.to_string())
279 .arg("--username=postgres")
280 .output()
281 .map_err(|e| HealthCheckError::Unknown(format!("command not found: pg_ctl: {e}")))?;
282
283 trace!("pg_isready code: {:?}", output.status.code());
284 trace!("pg_isready output: {}", output.stderr.escape_ascii());
285 trace!("pg_isready output: {}", output.stdout.escape_ascii());
286 if output.status.success() {
287 Ok(())
288 } else {
289 Err(HealthCheckError::NotReady)
290 }
291}
292
293fn initdb(dir: &Path) -> Result<()> {
297 let output = Command::new("initdb")
298 .arg("-D")
299 .arg(dir)
300 .arg("--no-instructions")
301 .arg("--username=postgres")
302 .output()
303 .context("command not found: initdb")?;
304
305 if output.status.success() {
306 Ok(())
307 } else {
308 Err(anyhow!(
309 "unable to initialize database: {:?}",
310 String::from_utf8(output.stderr)
311 ))
312 }
313}
314
315pub fn get_available_port() -> u16 {
319 const MAX_PORT_RETRIES: u32 = 1000;
320
321 for _ in 0..MAX_PORT_RETRIES {
322 if let Ok(port) = get_ephemeral_port() {
323 return port;
324 }
325 }
326
327 panic!("Error: could not find an available port");
328}
329
/// Asks the OS for a free TCP port on the loopback interface.
///
/// Binds a listener to `127.0.0.1` with port 0 so the OS chooses the port,
/// then opens and accepts one connection to it before returning the port
/// number. The listener and both ends of the connection are closed on return.
///
/// # Errors
///
/// Returns the underlying I/O error if binding, connecting or accepting fails.
fn get_ephemeral_port() -> std::io::Result<u16> {
    let listener = std::net::TcpListener::bind((std::net::Ipv4Addr::LOCALHOST, 0))?;
    let bound = listener.local_addr()?;

    // NOTE(review): the connect/accept round-trip is presumably there so the
    // OS does not immediately hand the same port out again — confirm.
    // Declaration order is kept so the accepted socket is closed first.
    let _client = std::net::TcpStream::connect(bound)?;
    let _accepted = listener.accept()?;

    Ok(bound.port())
}