sui_pg_db/
query.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{marker::PhantomData, ops};
5
6use diesel::{
7    QueryResult,
8    pg::Pg,
9    query_builder::{AstPass, QueryFragment, QueryId},
10    serialize::ToSql,
11    sql_types::{HasSqlType, Untyped},
12};
13
14/// A full SQL query constructed from snippets of raw SQL and bindings.
15///
16/// This abstraction is similar to [`diesel::query_builder::BoxedSqlQuery`], with the following
17/// differences:
18///
19/// - Binds are specified inline, and the abstraction deals with inserting a bind parameter into
20///   the SQL output (similar to how [`diesel::dsl::sql`] works).
21///
22/// - It is possible to embed one query into another.
23///
24/// - Queries can be built using the [`sui_sql_macro::query!`] macro, using format strings.
25#[derive(Default)]
26pub struct Query<'f> {
27    parts: Vec<Part<'f>>,
28}
29
30enum Part<'f> {
31    Sql(String),
32    Bind(Box<dyn QueryFragment<Pg> + Send + 'f>),
33}
34
/// A value paired with the SQL type `ST` it should be serialized as.
///
/// `ST` exists only at the type level, so it is carried by a zero-sized marker.
struct Bind<ST, U> {
    _data: PhantomData<ST>,
    value: U,
}
39
40impl<'f> Query<'f> {
41    /// Construct a new query starting with the `sql` snippet.
42    pub fn new(sql: impl AsRef<str>) -> Self {
43        Self {
44            parts: vec![Part::Sql(sql.as_ref().to_owned())],
45        }
46    }
47
48    /// Append `query` at the end of `self`.
49    pub fn query(mut self, query: Query<'f>) -> Self {
50        self.parts.extend(query.parts);
51        self
52    }
53
54    /// Add a raw `sql` snippet to the end of the query.
55    pub fn sql(mut self, sql: impl AsRef<str>) -> Self {
56        self.parts.push(Part::Sql(sql.as_ref().to_owned()));
57        self
58    }
59
60    /// Embed `value` into the query as a bind parameter, at the end of the query.
61    pub fn bind<ST, V>(mut self, value: V) -> Self
62    where
63        Pg: HasSqlType<ST>,
64        V: ToSql<ST, Pg> + Send + 'f,
65        ST: Send + 'f,
66    {
67        self.parts.push(Part::Bind(Box::new(Bind {
68            value,
69            _data: PhantomData,
70        })));
71
72        self
73    }
74}
75
76impl QueryFragment<Pg> for Query<'_> {
77    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
78        for part in &self.parts {
79            match part {
80                Part::Sql(sql) => out.push_sql(sql),
81                Part::Bind(bind) => bind.walk_ast(out.reborrow())?,
82            }
83        }
84
85        Ok(())
86    }
87}
88
89impl<ST, U> QueryFragment<Pg> for Bind<ST, U>
90where
91    Pg: HasSqlType<ST>,
92    U: ToSql<ST, Pg>,
93{
94    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
95        out.push_bind_param(&self.value)
96    }
97}
98
99impl QueryId for Query<'_> {
100    type QueryId = ();
101    const HAS_STATIC_QUERY_ID: bool = false;
102}
103
104impl diesel::query_builder::Query for Query<'_> {
105    type SqlType = Untyped;
106}
107
108impl ops::Add for Query<'_> {
109    type Output = Self;
110
111    fn add(self, rhs: Self) -> Self::Output {
112        self.query(rhs)
113    }
114}
115
116impl ops::AddAssign for Query<'_> {
117    fn add_assign(&mut self, rhs: Self) {
118        self.parts.extend(rhs.parts);
119    }
120}
121
#[cfg(test)]
mod tests {
    use diesel::sql_types::{BigInt, Text};

    use super::*;

    /// Render `query` through diesel's debug formatter: the SQL text followed by a
    /// `-- binds: [...]` trailer.
    fn render(query: &Query<'_>) -> String {
        diesel::debug_query::<Pg, _>(query).to_string()
    }

    /// SQL with no bind parameters renders verbatim with an empty bind list.
    #[test]
    fn basic_query() {
        let query = Query::new("SELECT 1");
        assert_eq!(render(&query), "SELECT 1 -- binds: []");
    }

    /// Extra raw SQL is appended directly after the initial snippet.
    #[test]
    fn query_extended() {
        let query = Query::new("SELECT 1").sql(" WHERE id = 1");
        assert_eq!(render(&query), "SELECT 1 WHERE id = 1 -- binds: []");
    }

    /// Each bind becomes a numbered placeholder, and its value is listed in order.
    #[test]
    fn query_with_binds() {
        let query = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND name = ")
            .bind::<Text, _>("foo");

        let expected = "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]";
        assert_eq!(render(&query), expected);
    }

    /// Binds inside embedded sub-queries are numbered in the order they appear in the output.
    #[test]
    fn query_embedded() {
        let lower = Query::new("cursor >= ").bind::<BigInt, _>(10);
        let upper = Query::new("cursor <= ").bind::<BigInt, _>(20);

        let query = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND ")
            .query(lower)
            .sql(" AND name = ")
            .bind::<Text, _>("foo")
            .sql(" AND ")
            .query(upper);

        let expected = concat!(
            "SELECT 1 WHERE id = $1 AND cursor >= $2 AND name = $3 AND cursor <= $4",
            " -- binds: [42, 10, \"foo\", 20]",
        );
        assert_eq!(render(&query), expected);
    }
}