sui_pg_db/
temp.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use anyhow::anyhow;
use anyhow::Context;
use anyhow::Result;
use std::fs::OpenOptions;
use std::{
    path::{Path, PathBuf},
    process::{Child, Command},
    time::{Duration, Instant},
};
use tracing::{event_enabled, info, trace};
use url::Url;

/// A temporary, local postgres database.
///
/// Pairs a running [`LocalDatabase`] with the temporary directory that holds its on-disk data.
pub struct TempDb {
    // The postgres server whose data directory is `dir`.
    database: LocalDatabase,

    // Directory used for the ephemeral database.
    //
    // On drop the directory will be cleaned up and its contents deleted.
    //
    // NOTE: This needs to be the last field in this struct: fields are dropped in declaration
    // order, so the database is dropped first and has a chance to gracefully shut down before
    // the directory is deleted.
    dir: tempfile::TempDir,
}

/// Local instance of a `postgres` server.
///
/// See <https://www.postgresql.org/docs/16/app-postgres.html> for more info.
pub struct LocalDatabase {
    // Location of the on-disk postgres data directory (handed to `postgres -D`).
    dir: PathBuf,
    // TCP port the server is asked to listen on.
    port: u16,
    // Connection url pointing at the `postgres` database on `localhost:{port}`.
    url: Url,
    // Handle to the spawned server process; `None` until `start` has spawned it.
    process: Option<PostgresProcess>,
}

/// Handle to a spawned `postgres` child process.
///
/// Dropping the handle stops the server (see the `Drop` impl for this type).
#[derive(Debug)]
struct PostgresProcess {
    // Data directory the server was started with; also holds the `stdout`/`stderr` log files.
    dir: PathBuf,
    // The spawned `postgres` child process.
    inner: Child,
}

/// Reasons a health check of the local `postgres` server can fail.
#[derive(Debug)]
enum HealthCheckError {
    /// No child process is tracked, or the child process has already exited.
    NotRunning,
    /// The process is alive but `pg_isready` did not report success.
    NotReady,
    /// Any other failure, carrying a human-readable description.
    // NOTE(review): the `allow` is presumably needed because the payload is only ever read
    // through the derived `Debug` impl — confirm before removing.
    #[allow(unused)]
    Unknown(String),
}

impl TempDb {
    /// Spin up a brand-new, throwaway postgres instance.
    ///
    /// A fresh database is initialized inside a newly created temporary directory, which is
    /// removed again once the returned value is dropped. The `postgres` server listens on a
    /// free port obtained from the operating system.
    ///
    /// # Errors
    ///
    /// Fails if the temporary directory cannot be created or if the database cannot be
    /// initialized and started.
    pub fn new() -> Result<Self> {
        let tmp = tempfile::TempDir::new()?;
        let port = get_available_port();
        let data_dir = tmp.path().to_path_buf();

        LocalDatabase::new_initdb(data_dir, port).map(|database| Self { database, dir: tmp })
    }

    /// Shared access to the underlying database server.
    pub fn database(&self) -> &LocalDatabase {
        let Self { database, .. } = self;
        database
    }

    /// Exclusive access to the underlying database server.
    pub fn database_mut(&mut self) -> &mut LocalDatabase {
        let Self { database, .. } = self;
        database
    }

    /// Path of the temporary directory backing the database.
    pub fn dir(&self) -> &Path {
        let Self { dir, .. } = self;
        dir.path()
    }
}

impl LocalDatabase {
    /// Start a local `postgres` database service.
    ///
    /// `dir`: The location of the on-disk postgres database. The database must already exist at
    ///     the provided path. If you instead want to create a new database see `Self::new_initdb`.
    ///
    /// `port`: The port to listen for incoming connections on.
    ///
    /// # Errors
    ///
    /// Returns an error if the `postgres` process cannot be spawned or does not report ready.
    pub fn new(dir: PathBuf, port: u16) -> Result<Self> {
        let url = format!(
            "postgres://postgres:postgrespw@localhost:{port}/{db_name}",
            db_name = "postgres"
        )
        .parse()
        .expect("url built from a fixed template and a numeric port is always valid");
        let mut db = Self {
            dir,
            port,
            url,
            process: None,
        };
        db.start()?;
        Ok(db)
    }

    /// Initialize and start a local `postgres` database service.
    ///
    /// Unlike `Self::new`, this will initialize a clean database at the provided path.
    ///
    /// # Errors
    ///
    /// Returns an error if `initdb` fails or if the server cannot be started afterwards.
    pub fn new_initdb(dir: PathBuf, port: u16) -> Result<Self> {
        initdb(&dir)?;
        Self::new(dir, port)
    }

    /// Return the url used to connect to the database
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// Spawn the `postgres` process (unless one is already tracked) and block until it is ready.
    fn start(&mut self) -> Result<()> {
        if self.process.is_none() {
            self.process = Some(PostgresProcess::start(self.dir.clone(), self.port)?);
            self.wait_till_ready()
                .map_err(|e| anyhow!("unable to start postgres: {e:?}"))?;
        }

        Ok(())
    }

    /// Check whether the server process is alive and accepting connections.
    fn health_check(&mut self) -> Result<(), HealthCheckError> {
        if let Some(p) = &mut self.process {
            match p.inner.try_wait() {
                // This would mean the child process has crashed
                Ok(Some(_)) => Err(HealthCheckError::NotRunning),

                // This is the case where the process is still running
                Ok(None) => pg_isready(self.port),

                // Some other unknown error
                Err(e) => Err(HealthCheckError::Unknown(e.to_string())),
            }
        } else {
            Err(HealthCheckError::NotRunning)
        }
    }

    /// Poll `health_check` until the server is ready, a fatal error occurs, or 10s have passed.
    ///
    /// A `NotReady` result is retried after a short sleep. `NotRunning` and `Unknown` are
    /// returned to the caller as-is so the real cause of the failure is not masked by a
    /// generic timeout message.
    fn wait_till_ready(&mut self) -> Result<(), HealthCheckError> {
        const READY_TIMEOUT: Duration = Duration::from_secs(10);
        const POLL_INTERVAL: Duration = Duration::from_millis(50);

        let start = Instant::now();

        while start.elapsed() < READY_TIMEOUT {
            match self.health_check() {
                Ok(()) => return Ok(()),
                // Server is still starting up: wait a little and probe again.
                Err(HealthCheckError::NotReady) => {}
                // Retrying cannot help here, so surface the actual failure.
                Err(fatal @ (HealthCheckError::NotRunning | HealthCheckError::Unknown(_))) => {
                    return Err(fatal);
                }
            }

            std::thread::sleep(POLL_INTERVAL);
        }

        Err(HealthCheckError::Unknown(
            "timeout reached when waiting for service to be ready".to_owned(),
        ))
    }
}

impl PostgresProcess {
    fn start(dir: PathBuf, port: u16) -> Result<Self> {
        let child = Command::new("postgres")
            // Set the data directory to use
            .arg("-D")
            .arg(&dir)
            // Set the port to listen for incoming connections
            .args(["-p", &port.to_string()])
            // Disable creating and listening on a UDS
            .args(["-c", "unix_socket_directories="])
            // pipe stdout and stderr to files located in the data directory
            .stdout(
                OpenOptions::new()
                    .create(true)
                    .append(true)
                    .open(dir.join("stdout"))?,
            )
            .stderr(
                OpenOptions::new()
                    .create(true)
                    .append(true)
                    .open(dir.join("stderr"))?,
            )
            .spawn()
            .context("command not found: postgres")?;

        Ok(Self { dir, inner: child })
    }

    // https://www.postgresql.org/docs/16/app-pg-ctl.html
    fn pg_ctl_stop(&mut self) -> Result<()> {
        let output = Command::new("pg_ctl")
            .arg("stop")
            .arg("-D")
            .arg(&self.dir)
            .arg("-mfast")
            .output()
            .context("command not found: pg_ctl")?;

        if output.status.success() {
            Ok(())
        } else {
            Err(anyhow!("couldn't shut down postgres"))
        }
    }

    fn dump_stdout_stderr(&self) -> Result<(String, String)> {
        let stdout = std::fs::read_to_string(self.dir.join("stdout"))?;
        let stderr = std::fs::read_to_string(self.dir.join("stderr"))?;

        Ok((stdout, stderr))
    }
}

impl Drop for PostgresProcess {
    // When the Process struct goes out of scope we need to kill the child process
    fn drop(&mut self) {
        info!("dropping postgres");
        // check if the process has already been terminated
        match self.inner.try_wait() {
            // The child process has already terminated, perhaps due to a crash
            Ok(Some(_)) => {}

            // The process is still running so we need to attempt to kill it
            _ => {
                if self.pg_ctl_stop().is_err() {
                    // Couldn't gracefully stop server so we'll just kill it
                    self.inner.kill().expect("postgres couldn't be killed");
                }
                self.inner.wait().unwrap();
            }
        }

        // Dump the contents of stdout/stderr if TRACE is enabled
        if event_enabled!(tracing::Level::TRACE) {
            if let Ok((stdout, stderr)) = self.dump_stdout_stderr() {
                trace!("stdout: {stdout}");
                trace!("stderr: {stderr}");
            }
        }
    }
}

/// Run the postgres `pg_isready` command to get the status of database
///
/// Probes `localhost` on the given `port` as user `postgres`. Returns `Ok(())` when the command
/// exits successfully, `HealthCheckError::NotReady` for any other exit status, and
/// `HealthCheckError::Unknown` when the `pg_isready` binary itself cannot be executed.
///
/// See <https://www.postgresql.org/docs/16/app-pg-isready.html> for more info
fn pg_isready(port: u16) -> Result<(), HealthCheckError> {
    let output = Command::new("pg_isready")
        .arg("--host=localhost")
        .arg("-p")
        .arg(port.to_string())
        .arg("--username=postgres")
        .output()
        // Name the binary that actually failed to run so the error is not misleading.
        .map_err(|e| HealthCheckError::Unknown(format!("command not found: pg_isready: {e}")))?;

    trace!("pg_isready code: {:?}", output.status.code());
    trace!("pg_isready output: {}", output.stderr.escape_ascii());
    trace!("pg_isready output: {}", output.stdout.escape_ascii());
    if output.status.success() {
        Ok(())
    } else {
        Err(HealthCheckError::NotReady)
    }
}

/// Run the postgres `initdb` command to initialize a database at the provided path
///
/// See <https://www.postgresql.org/docs/16/app-initdb.html> for more info
fn initdb(dir: &Path) -> Result<()> {
    let output = Command::new("initdb")
        .arg("-D")
        .arg(dir)
        .arg("--no-instructions")
        .arg("--username=postgres")
        .output()
        .context("command not found: initdb")?;

    if output.status.success() {
        Ok(())
    } else {
        Err(anyhow!(
            "unable to initialize database: {:?}",
            String::from_utf8(output.stderr)
        ))
    }
}

/// Find an ephemeral port that is currently free.
///
/// On unix systems the returned port is left in the TIME_WAIT state, so the OS will not hand
/// it out again for some grace period. Callers should still be able to bind to it provided
/// they use SO_REUSEADDR.
///
/// # Panics
///
/// Panics if no port could be obtained after 1000 attempts.
pub fn get_available_port() -> u16 {
    const MAX_PORT_RETRIES: u32 = 1000;

    (0..MAX_PORT_RETRIES)
        .find_map(|_| get_ephemeral_port().ok())
        .unwrap_or_else(|| panic!("Error: could not find an available port"))
}

/// Ask the OS for a free TCP port on `127.0.0.1` and return its number.
///
/// Any I/O error raised while binding, querying the local address, connecting or accepting is
/// returned to the caller.
fn get_ephemeral_port() -> std::io::Result<u16> {
    // Request a random available port from the OS
    let listener = std::net::TcpListener::bind(("127.0.0.1", 0))?;
    let addr = listener.local_addr()?;

    // Create and accept a connection (which we'll promptly drop) in order to force the port
    // into the TIME_WAIT state, ensuring that the port will be reserved for some limited
    // amount of time (roughly 60s on some Linux systems)
    //
    // NOTE(review): locals are dropped in reverse declaration order, so the accepted stream is
    // closed before the connecting one and before the listener. The TIME_WAIT behaviour
    // presumably relies on that ordering — keep these statements in this order unless verified.
    let _sender = std::net::TcpStream::connect(addr)?;
    let _incoming = listener.accept()?;

    Ok(addr.port())
}