sui_pg_db/
query.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::marker::PhantomData;
5use std::ops;
6
7use diesel::QueryResult;
8use diesel::pg::Pg;
9use diesel::query_builder::AstPass;
10use diesel::query_builder::QueryFragment;
11use diesel::query_builder::QueryId;
12use diesel::serialize::ToSql;
13use diesel::sql_types::HasSqlType;
14use diesel::sql_types::Untyped;
15
16/// A full SQL query constructed from snippets of raw SQL and bindings.
17///
18/// This abstraction is similar to [`diesel::query_builder::BoxedSqlQuery`], with the following
19/// differences:
20///
21/// - Binds are specified inline, and the abstraction deals with inserting a bind parameter into
22///   the SQL output (similar to how [`diesel::dsl::sql`] works).
23///
24/// - It is possible to embed one query into another.
25///
26/// - Queries can be built using the [`sui_sql_macro::query!`] macro, using format strings.
27#[derive(Default)]
28pub struct Query<'f> {
29    parts: Vec<Part<'f>>,
30}
31
32enum Part<'f> {
33    Sql(String),
34    Bind(Box<dyn QueryFragment<Pg> + Send + 'f>),
35}
36
/// A value to be sent as a bind parameter, serialized with SQL type `ST`.
///
/// `ST` is only needed at the type level to select the `ToSql` implementation, so it is held
/// as [`PhantomData`] rather than as a value.
struct Bind<ST, U> {
    /// The Rust value that gets serialized.
    value: U,
    _data: PhantomData<ST>,
}
41
42impl<'f> Query<'f> {
43    /// Construct a new query starting with the `sql` snippet.
44    pub fn new(sql: impl AsRef<str>) -> Self {
45        Self {
46            parts: vec![Part::Sql(sql.as_ref().to_owned())],
47        }
48    }
49
50    /// Append `query` at the end of `self`.
51    pub fn query(mut self, query: Query<'f>) -> Self {
52        self.parts.extend(query.parts);
53        self
54    }
55
56    /// Add a raw `sql` snippet to the end of the query.
57    pub fn sql(mut self, sql: impl AsRef<str>) -> Self {
58        self.parts.push(Part::Sql(sql.as_ref().to_owned()));
59        self
60    }
61
62    /// Embed `value` into the query as a bind parameter, at the end of the query.
63    pub fn bind<ST, V>(mut self, value: V) -> Self
64    where
65        Pg: HasSqlType<ST>,
66        V: ToSql<ST, Pg> + Send + 'f,
67        ST: Send + 'f,
68    {
69        self.parts.push(Part::Bind(Box::new(Bind {
70            value,
71            _data: PhantomData,
72        })));
73
74        self
75    }
76}
77
78impl QueryFragment<Pg> for Query<'_> {
79    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
80        for part in &self.parts {
81            match part {
82                Part::Sql(sql) => out.push_sql(sql),
83                Part::Bind(bind) => bind.walk_ast(out.reborrow())?,
84            }
85        }
86
87        Ok(())
88    }
89}
90
91impl<ST, U> QueryFragment<Pg> for Bind<ST, U>
92where
93    Pg: HasSqlType<ST>,
94    U: ToSql<ST, Pg>,
95{
96    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
97        out.push_bind_param(&self.value)
98    }
99}
100
101impl QueryId for Query<'_> {
102    type QueryId = ();
103    const HAS_STATIC_QUERY_ID: bool = false;
104}
105
106impl diesel::query_builder::Query for Query<'_> {
107    type SqlType = Untyped;
108}
109
110impl ops::Add for Query<'_> {
111    type Output = Self;
112
113    fn add(self, rhs: Self) -> Self::Output {
114        self.query(rhs)
115    }
116}
117
118impl ops::AddAssign for Query<'_> {
119    fn add_assign(&mut self, rhs: Self) {
120        self.parts.extend(rhs.parts);
121    }
122}
123
#[cfg(test)]
mod tests {
    use diesel::sql_types::BigInt;
    use diesel::sql_types::Text;

    use super::*;

    /// A query without any binds.
    #[test]
    fn basic_query() {
        let q = Query::new("SELECT 1");
        assert_eq!(diesel::debug_query(&q).to_string(), "SELECT 1 -- binds: []");
    }

    /// The default query has no parts, so it renders to empty SQL with no binds.
    #[test]
    fn default_query() {
        let q = Query::default();
        assert_eq!(diesel::debug_query(&q).to_string(), " -- binds: []");
    }

    /// A query that has been extended with more SQL.
    #[test]
    fn query_extended() {
        let q = Query::new("SELECT 1").sql(" WHERE id = 1");
        assert_eq!(
            diesel::debug_query(&q).to_string(),
            "SELECT 1 WHERE id = 1 -- binds: []"
        );
    }

    /// A query that has some binds.
    #[test]
    fn query_with_binds() {
        let q = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND name = ")
            .bind::<Text, _>("foo");

        assert_eq!(
            diesel::debug_query(&q).to_string(),
            "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]"
        );
    }

    /// A query that has other queries embedded into it.
    #[test]
    fn query_embedded() {
        let r = Query::new("cursor >= ").bind::<BigInt, _>(10);
        let s = Query::new("cursor <= ").bind::<BigInt, _>(20);
        let q = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND ")
            .query(r)
            .sql(" AND name = ")
            .bind::<Text, _>("foo")
            .sql(" AND ")
            .query(s);

        assert_eq!(
            diesel::debug_query(&q).to_string(),
            "SELECT 1 WHERE id = $1 AND cursor >= $2 AND name = $3 AND cursor <= $4 \
            -- binds: [42, 10, \"foo\", 20]"
        );
    }

    /// Queries combined with the `+` operator keep their binds in order.
    #[test]
    fn query_add() {
        let q = Query::new("SELECT 1 WHERE id = ").bind::<BigInt, _>(42)
            + Query::new(" AND name = ").bind::<Text, _>("foo");

        assert_eq!(
            diesel::debug_query(&q).to_string(),
            "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]"
        );
    }

    /// Queries combined with the `+=` operator keep their binds in order.
    #[test]
    fn query_add_assign() {
        let mut q = Query::new("SELECT 1 WHERE id = ").bind::<BigInt, _>(42);
        q += Query::new(" AND name = ").bind::<Text, _>("foo");

        assert_eq!(
            diesel::debug_query(&q).to_string(),
            "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]"
        );
    }
}
184}