1use std::{marker::PhantomData, ops};
5
6use diesel::{
7 QueryResult,
8 pg::Pg,
9 query_builder::{AstPass, QueryFragment, QueryId},
10 serialize::ToSql,
11 sql_types::{HasSqlType, Untyped},
12};
13
14#[derive(Default)]
26pub struct Query<'f> {
27 parts: Vec<Part<'f>>,
28}
29
30enum Part<'f> {
31 Sql(String),
32 Bind(Box<dyn QueryFragment<Pg> + Send + 'f>),
33}
34
/// A single bind value tagged with the SQL type it should be serialized as.
///
/// `ST` is never stored. It is a type-level marker held through `PhantomData` that selects
/// which `ToSql<ST, Pg>` impl serializes `value`.
struct Bind<ST, U> {
    /// Carries the SQL type parameter without storing a value of it.
    _data: PhantomData<ST>,
    /// The Rust value that is sent to the database as the bind parameter.
    value: U,
}
39
40impl<'f> Query<'f> {
41 pub fn new(sql: impl AsRef<str>) -> Self {
43 Self {
44 parts: vec![Part::Sql(sql.as_ref().to_owned())],
45 }
46 }
47
48 pub fn query(mut self, query: Query<'f>) -> Self {
50 self.parts.extend(query.parts);
51 self
52 }
53
54 pub fn sql(mut self, sql: impl AsRef<str>) -> Self {
56 self.parts.push(Part::Sql(sql.as_ref().to_owned()));
57 self
58 }
59
60 pub fn bind<ST, V>(mut self, value: V) -> Self
62 where
63 Pg: HasSqlType<ST>,
64 V: ToSql<ST, Pg> + Send + 'f,
65 ST: Send + 'f,
66 {
67 self.parts.push(Part::Bind(Box::new(Bind {
68 value,
69 _data: PhantomData,
70 })));
71
72 self
73 }
74}
75
76impl QueryFragment<Pg> for Query<'_> {
77 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
78 for part in &self.parts {
79 match part {
80 Part::Sql(sql) => out.push_sql(sql),
81 Part::Bind(bind) => bind.walk_ast(out.reborrow())?,
82 }
83 }
84
85 Ok(())
86 }
87}
88
89impl<ST, U> QueryFragment<Pg> for Bind<ST, U>
90where
91 Pg: HasSqlType<ST>,
92 U: ToSql<ST, Pg>,
93{
94 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
95 out.push_bind_param(&self.value)
96 }
97}
98
99impl QueryId for Query<'_> {
100 type QueryId = ();
101 const HAS_STATIC_QUERY_ID: bool = false;
102}
103
104impl diesel::query_builder::Query for Query<'_> {
105 type SqlType = Untyped;
106}
107
108impl ops::Add for Query<'_> {
109 type Output = Self;
110
111 fn add(self, rhs: Self) -> Self::Output {
112 self.query(rhs)
113 }
114}
115
116impl ops::AddAssign for Query<'_> {
117 fn add_assign(&mut self, rhs: Self) {
118 self.parts.extend(rhs.parts);
119 }
120}
121
#[cfg(test)]
mod tests {
    use diesel::sql_types::{BigInt, Text};

    use super::*;

    /// Renders `query` with diesel's `debug_query`: the SQL text followed by a
    /// `-- binds: [...]` trailer.
    fn render(query: &Query<'_>) -> String {
        diesel::debug_query(query).to_string()
    }

    /// A query without binds renders as its raw text with an empty bind list.
    #[test]
    fn basic_query() {
        let query = Query::new("SELECT 1");
        assert_eq!(render(&query), "SELECT 1 -- binds: []");
    }

    /// Appended text is concatenated verbatim.
    #[test]
    fn query_extended() {
        let query = Query::new("SELECT 1").sql(" WHERE id = 1");
        let expected = "SELECT 1 WHERE id = 1 -- binds: []";
        assert_eq!(render(&query), expected);
    }

    /// Binds become numbered placeholders in order of appearance.
    #[test]
    fn query_with_binds() {
        let query = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND name = ")
            .bind::<Text, _>("foo");

        let expected = "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]";
        assert_eq!(render(&query), expected);
    }

    /// Embedded sub-queries keep their binds, and placeholder numbering continues across
    /// them.
    #[test]
    fn query_embedded() {
        let lower_bound = Query::new("cursor >= ").bind::<BigInt, _>(10);
        let upper_bound = Query::new("cursor <= ").bind::<BigInt, _>(20);
        let query = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND ")
            .query(lower_bound)
            .sql(" AND name = ")
            .bind::<Text, _>("foo")
            .sql(" AND ")
            .query(upper_bound);

        let expected = "SELECT 1 WHERE id = $1 AND cursor >= $2 AND name = $3 AND cursor <= $4 \
                        -- binds: [42, 10, \"foo\", 20]";
        assert_eq!(render(&query), expected);
    }
}