use std::{marker::PhantomData, ops};
use diesel::{
pg::Pg,
query_builder::{AstPass, QueryFragment, QueryId},
serialize::ToSql,
sql_types::{HasSqlType, Untyped},
QueryResult,
};
/// A raw PostgreSQL query assembled at runtime from SQL text fragments and bind parameters.
///
/// Build one with [`Query::new`] and the chaining methods [`Query::sql`], [`Query::bind`] and
/// [`Query::query`], or concatenate whole queries with `+` / `+=`. Fragments are emitted in
/// insertion order when diesel walks the query.
///
/// Every builder method consumes `self` and returns the extended query. Dropping that return
/// value silently loses the SQL, so the type is `#[must_use]`.
#[derive(Default)]
#[must_use = "Query builder methods return the extended query; dropping it discards the SQL"]
pub struct Query<'f> {
    /// Fragments in the order they are rendered.
    parts: Vec<Part<'f>>,
}

/// One fragment of a [`Query`].
enum Part<'f> {
    /// Raw SQL text, pushed into the statement verbatim (no escaping).
    Sql(String),
    /// A type-erased bind parameter (in practice a boxed `Bind` created by `Query::bind`).
    /// It renders as a positional placeholder such as `$1`.
    Bind(Box<dyn QueryFragment<Pg> + Send + 'f>),
}

/// A single bind parameter: `value` serialized as the diesel SQL type `ST`.
struct Bind<ST, U> {
    /// The Rust value handed to diesel's `ToSql<ST, Pg>`.
    value: U,
    /// Records the SQL type `ST` at the type level without storing a value of it.
    _data: PhantomData<ST>,
}
impl<'f> Query<'f> {
pub fn new(sql: impl AsRef<str>) -> Self {
Self {
parts: vec![Part::Sql(sql.as_ref().to_owned())],
}
}
pub fn query(mut self, query: Query<'f>) -> Self {
self.parts.extend(query.parts);
self
}
pub fn sql(mut self, sql: impl AsRef<str>) -> Self {
self.parts.push(Part::Sql(sql.as_ref().to_owned()));
self
}
pub fn bind<ST, V>(mut self, value: V) -> Self
where
Pg: HasSqlType<ST>,
V: ToSql<ST, Pg> + Send + 'f,
ST: Send + 'f,
{
self.parts.push(Part::Bind(Box::new(Bind {
value,
_data: PhantomData,
})));
self
}
}
impl QueryFragment<Pg> for Query<'_> {
    /// Walks the fragments in insertion order.
    ///
    /// SQL text is pushed as-is, and each bind delegates to its own `walk_ast`. The walk stops
    /// and returns the error from the first bind that fails.
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
        self.parts.iter().try_for_each(|part| match part {
            Part::Sql(text) => {
                out.push_sql(text);
                Ok(())
            }
            Part::Bind(param) => param.walk_ast(out.reborrow()),
        })
    }
}
/// Renders a single bind parameter by handing `value` to diesel as SQL type `ST`.
impl<ST, U> QueryFragment<Pg> for Bind<ST, U>
where
    U: ToSql<ST, Pg>,
    Pg: HasSqlType<ST>,
{
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
        let value: &'b U = &self.value;
        out.push_bind_param(value)
    }
}
/// Opts out of a static query id.
///
/// The SQL is assembled at runtime from `parts`, so the Rust type alone does not identify
/// the statement.
impl QueryId for Query<'_> {
    const HAS_STATIC_QUERY_ID: bool = false;
    type QueryId = ();
}
/// Lets diesel treat [`Query`] as a complete query whose result columns are `Untyped`.
///
/// NOTE(review): with `Untyped`, rows presumably have to be loaded by column name
/// (diesel's `QueryableByName`). Confirm against the callers.
impl<'f> diesel::query_builder::Query for Query<'f> {
    type SqlType = Untyped;
}
/// `lhs + rhs` concatenates two queries: the fragments of `rhs` follow those of `lhs`.
/// This is the same as `lhs.query(rhs)`.
impl ops::Add for Query<'_> {
    type Output = Self;

    fn add(mut self, rhs: Self) -> Self::Output {
        self.parts.extend(rhs.parts);
        self
    }
}
/// `lhs += rhs` moves every fragment of `rhs` onto the end of `lhs`, in place.
impl ops::AddAssign for Query<'_> {
    fn add_assign(&mut self, mut rhs: Self) {
        self.parts.append(&mut rhs.parts);
    }
}
#[cfg(test)]
mod tests {
    use diesel::sql_types::{BigInt, Text};

    use super::*;

    /// Renders `q` with diesel's debug formatter: the SQL, then `-- binds: [...]`.
    fn render(q: &Query<'_>) -> String {
        diesel::debug_query::<Pg, _>(q).to_string()
    }

    #[test]
    fn basic_query() {
        let q = Query::new("SELECT 1");
        assert_eq!(render(&q), "SELECT 1 -- binds: []");
    }

    #[test]
    fn query_extended() {
        let q = Query::new("SELECT 1").sql(" WHERE id = 1");
        assert_eq!(render(&q), "SELECT 1 WHERE id = 1 -- binds: []");
    }

    #[test]
    fn query_with_binds() {
        let q = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND name = ")
            .bind::<Text, _>("foo");
        assert_eq!(
            render(&q),
            "SELECT 1 WHERE id = $1 AND name = $2 -- binds: [42, \"foo\"]"
        );
    }

    #[test]
    fn query_embedded() {
        let r = Query::new("cursor >= ").bind::<BigInt, _>(10);
        let s = Query::new("cursor <= ").bind::<BigInt, _>(20);
        let q = Query::new("SELECT 1 WHERE ")
            .sql("id = ")
            .bind::<BigInt, _>(42)
            .sql(" AND ")
            .query(r)
            .sql(" AND name = ")
            .bind::<Text, _>("foo")
            .sql(" AND ")
            .query(s);
        assert_eq!(
            render(&q),
            "SELECT 1 WHERE id = $1 AND cursor >= $2 AND name = $3 AND cursor <= $4 \
             -- binds: [42, 10, \"foo\", 20]"
        );
    }

    /// `+` and `+=` must concatenate fragments in order and keep bind numbering continuous.
    #[test]
    fn operators_concatenate() {
        let lhs = Query::new("SELECT ").bind::<BigInt, _>(1_i64);
        let rhs = Query::new(" + ").bind::<BigInt, _>(2_i64);
        assert_eq!(render(&(lhs + rhs)), "SELECT $1 + $2 -- binds: [1, 2]");

        let mut q = Query::new("SELECT ").bind::<BigInt, _>(1_i64);
        q += Query::new(" + ").bind::<BigInt, _>(2_i64);
        assert_eq!(render(&q), "SELECT $1 + $2 -- binds: [1, 2]");
    }
}