sui_pg_db/
temp.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::anyhow;
5use anyhow::Context;
6use anyhow::Result;
7use std::fmt::Debug;
8use std::fs::OpenOptions;
9use std::process::ExitStatus;
10use std::{
11    path::{Path, PathBuf},
12    process::{Child, Command},
13    time::{Duration, Instant},
14};
15use tracing::{event_enabled, info, trace};
16use url::Url;
17
18/// A temporary, local postgres database
19pub struct TempDb {
20    database: LocalDatabase,
21
22    // Directory used for the ephemeral database.
23    //
24    // On drop the directory will be cleaned an its contents deleted.
25    //
26    // NOTE: This needs to be the last entry in this struct so that the database is dropped before
27    // and has a chance to gracefully shutdown before the directory is deleted.
28    dir: tempfile::TempDir,
29}
30
31/// Local instance of a `postgres` server.
32///
33/// See <https://www.postgresql.org/docs/16/app-postgres.html> for more info.
34pub struct LocalDatabase {
35    dir: PathBuf,
36    port: u16,
37    url: Url,
38    process: Option<PostgresProcess>,
39}
40
/// Handle to a spawned `postgres` child process together with its data directory.
#[derive(Debug)]
struct PostgresProcess {
    // Data directory of the server; its `stdout` and `stderr` log files are written here too
    dir: PathBuf,
    // The spawned `postgres` process
    inner: Child,
}
46
/// Reasons a health check of the local `postgres` server can fail.
#[derive(Debug)]
enum HealthCheckError {
    /// The server process is not running; carries its exit status when one is known.
    NotRunning(Option<ExitStatus>),
    /// The server process is alive but does not report being ready yet.
    NotReady,
    /// Any other failure, described by a free-form message.
    #[allow(unused)]
    Unknown(String),
}

impl std::fmt::Display for HealthCheckError {
    /// Render a short, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery
            Self::NotReady => f.write_str("Not ready"),
            Self::NotRunning(None) => f.write_str("Not running - no exit status"),

            // Messages carrying a payload
            Self::NotRunning(Some(status)) => {
                write!(f, "Not running - exit status: {}", status)
            }
            Self::Unknown(msg) => write!(f, "Unknown error: {}", msg),
        }
    }
}

impl std::error::Error for HealthCheckError {}
69
70impl TempDb {
71    /// Create and start a new temporary postgres database.
72    ///
73    /// A fresh database will be initialized in a temporary directory that will be cleand up on drop.
74    /// The running `postgres` service will be serving traffic on an available, os-assigned port.
75    pub fn new() -> Result<Self> {
76        let dir = tempfile::TempDir::new()?;
77        let port = get_available_port();
78
79        let database = LocalDatabase::new_initdb(dir.path().to_owned(), port)?;
80
81        Ok(Self { dir, database })
82    }
83
84    pub fn database(&self) -> &LocalDatabase {
85        &self.database
86    }
87
88    pub fn database_mut(&mut self) -> &mut LocalDatabase {
89        &mut self.database
90    }
91
92    pub fn dir(&self) -> &Path {
93        self.dir.path()
94    }
95}
96
97impl LocalDatabase {
98    /// Start a local `postgres` database service.
99    ///
100    /// `dir`: The location of the on-disk postgres database. The database must already exist at
101    ///     the provided path. If you instead want to create a new database see `Self::new_initdb`.
102    ///
103    /// `port`: The port to listen for incoming connection on.
104    pub fn new(dir: PathBuf, port: u16) -> Result<Self> {
105        let url = format!(
106            "postgres://postgres:postgrespw@localhost:{port}/{db_name}",
107            db_name = "postgres"
108        )
109        .parse()
110        .unwrap();
111        let mut db = Self {
112            dir,
113            port,
114            url,
115            process: None,
116        };
117        db.start()?;
118        Ok(db)
119    }
120
121    /// Initialize and start a local `postgres` database service.
122    ///
123    /// Unlike `Self::new`, this will initialize a clean database at the provided path.
124    pub fn new_initdb(dir: PathBuf, port: u16) -> Result<Self> {
125        initdb(&dir)?;
126        Self::new(dir, port)
127    }
128
129    /// Return the url used to connect to the database
130    pub fn url(&self) -> &Url {
131        &self.url
132    }
133
134    fn start(&mut self) -> Result<()> {
135        if self.process.is_none() {
136            self.process = Some(PostgresProcess::start(self.dir.clone(), self.port)?);
137            self.wait_till_ready().with_context(|| {
138                format!(
139                    // may need to add breakpoint/sleep to prevent temp dir from being deleted
140                    "Unable to start postgres, check dir {} for logs",
141                    self.dir.display(),
142                )
143            })?;
144        }
145
146        Ok(())
147    }
148
149    fn health_check(&mut self) -> Result<(), HealthCheckError> {
150        if let Some(p) = &mut self.process {
151            match p.inner.try_wait() {
152                // This would mean the child process has crashed
153                Ok(Some(status)) => Err(HealthCheckError::NotRunning(Some(status))),
154
155                // This is the case where the process is still running
156                Ok(None) => pg_isready(self.port),
157
158                // Some other unknown error
159                Err(e) => Err(HealthCheckError::Unknown(e.to_string())),
160            }
161        } else {
162            Err(HealthCheckError::NotRunning(None))
163        }
164    }
165
166    fn wait_till_ready(&mut self) -> Result<(), HealthCheckError> {
167        let start = Instant::now();
168
169        while start.elapsed() < Duration::from_secs(10) {
170            match self.health_check() {
171                Err(HealthCheckError::NotReady) => {}
172                result => return result,
173            }
174
175            std::thread::sleep(Duration::from_millis(50));
176        }
177
178        Err(HealthCheckError::Unknown(
179            "timeout reached when waiting for service to be ready".to_owned(),
180        ))
181    }
182}
183
184impl PostgresProcess {
185    fn start(dir: PathBuf, port: u16) -> Result<Self> {
186        let child = Command::new("postgres")
187            // Set the data directory to use
188            .arg("-D")
189            .arg(&dir)
190            // Set the port to listen for incoming connections
191            .args(["-p", &port.to_string()])
192            // Disable creating and listening on a UDS
193            .args(["-c", "unix_socket_directories="])
194            // pipe stdout and stderr to files located in the data directory
195            .stdout(
196                OpenOptions::new()
197                    .create(true)
198                    .append(true)
199                    .open(dir.join("stdout"))?,
200            )
201            .stderr(
202                OpenOptions::new()
203                    .create(true)
204                    .append(true)
205                    .open(dir.join("stderr"))?,
206            )
207            .spawn()
208            .context("command not found: postgres")?;
209
210        Ok(Self { dir, inner: child })
211    }
212
213    // https://www.postgresql.org/docs/16/app-pg-ctl.html
214    fn pg_ctl_stop(&mut self) -> Result<()> {
215        let output = Command::new("pg_ctl")
216            .arg("stop")
217            .arg("-D")
218            .arg(&self.dir)
219            .arg("-mfast")
220            .output()
221            .context("command not found: pg_ctl")?;
222
223        if output.status.success() {
224            Ok(())
225        } else {
226            Err(anyhow!("couldn't shut down postgres"))
227        }
228    }
229
230    fn dump_stdout_stderr(&self) -> Result<(String, String)> {
231        let stdout = std::fs::read_to_string(self.dir.join("stdout"))?;
232        let stderr = std::fs::read_to_string(self.dir.join("stderr"))?;
233
234        Ok((stdout, stderr))
235    }
236}
237
238impl Drop for PostgresProcess {
239    // When the Process struct goes out of scope we need to kill the child process
240    fn drop(&mut self) {
241        info!("dropping postgres");
242        // check if the process has already been terminated
243        match self.inner.try_wait() {
244            // The child process has already terminated, perhaps due to a crash
245            Ok(Some(_)) => {}
246
247            // The process is still running so we need to attempt to kill it
248            _ => {
249                if self.pg_ctl_stop().is_err() {
250                    // Couldn't gracefully stop server so we'll just kill it
251                    self.inner.kill().expect("postgres couldn't be killed");
252                }
253                self.inner.wait().unwrap();
254            }
255        }
256
257        // Dump the contents of stdout/stderr if TRACE is enabled
258        if event_enabled!(tracing::Level::TRACE) {
259            if let Ok((stdout, stderr)) = self.dump_stdout_stderr() {
260                trace!("stdout: {stdout}");
261                trace!("stderr: {stderr}");
262            }
263        }
264    }
265}
266
267/// Run the postgres `pg_isready` command to get the status of database
268///
269/// See <https://www.postgresql.org/docs/16/app-pg-isready.html> for more info
270fn pg_isready(port: u16) -> Result<(), HealthCheckError> {
271    let output = Command::new("pg_isready")
272        .arg("--host=localhost")
273        .arg("-p")
274        .arg(port.to_string())
275        .arg("--username=postgres")
276        .output()
277        .map_err(|e| HealthCheckError::Unknown(format!("command not found: pg_ctl: {e}")))?;
278
279    trace!("pg_isready code: {:?}", output.status.code());
280    trace!("pg_isready output: {}", output.stderr.escape_ascii());
281    trace!("pg_isready output: {}", output.stdout.escape_ascii());
282    if output.status.success() {
283        Ok(())
284    } else {
285        Err(HealthCheckError::NotReady)
286    }
287}
288
289/// Run the postgres `initdb` command to initialize a database at the provided path
290///
291/// See <https://www.postgresql.org/docs/16/app-initdb.html> for more info
292fn initdb(dir: &Path) -> Result<()> {
293    let output = Command::new("initdb")
294        .arg("-D")
295        .arg(dir)
296        .arg("--no-instructions")
297        .arg("--username=postgres")
298        .output()
299        .context("command not found: initdb")?;
300
301    if output.status.success() {
302        Ok(())
303    } else {
304        Err(anyhow!(
305            "unable to initialize database: {:?}",
306            String::from_utf8(output.stderr)
307        ))
308    }
309}
310
311/// Return an ephemeral, available port. On unix systems, the port returned will be in the
312/// TIME_WAIT state ensuring that the OS won't hand out this port for some grace period.
313/// Callers should be able to bind to this port given they use SO_REUSEADDR.
314pub fn get_available_port() -> u16 {
315    const MAX_PORT_RETRIES: u32 = 1000;
316
317    for _ in 0..MAX_PORT_RETRIES {
318        if let Ok(port) = get_ephemeral_port() {
319            return port;
320        }
321    }
322
323    panic!("Error: could not find an available port");
324}
325
/// Ask the OS for a free port on the loopback interface and return its number.
///
/// # Errors
///
/// Returns any I/O error raised while binding, connecting to, or accepting on the socket.
fn get_ephemeral_port() -> std::io::Result<u16> {
    use std::net::{Ipv4Addr, SocketAddr, TcpListener, TcpStream};

    // Binding to port 0 lets the OS pick a random available port
    let listener = TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))?;
    let bound = listener.local_addr()?;

    // Open a connection to ourselves and accept it, only to drop both ends right away. This
    // forces the port into the TIME_WAIT state, ensuring that the port will be reserved for
    // some limited amount of time (roughly 60s on some Linux systems)
    let _client = TcpStream::connect(bound)?;
    let _accepted = listener.accept()?;

    Ok(bound.port())
}