sui_sql_macro/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{
7    Error, Expr, LitStr, Result, Token, Type,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11};
12
13use crate::{
14    lexer::Lexer,
15    parser::{Format, Parser},
16};
17
18mod lexer;
19mod parser;
20
21/// Rust syntax for `sql!(as T, "format", binds,*)`
22struct SqlInput {
23    return_: Type,
24    format_: LitStr,
25    binds: Punctuated<Expr, Token![,]>,
26}
27
28/// Rust syntax for `query!("format", binds,*)`.
29struct QueryInput {
30    format_: LitStr,
31    binds: Punctuated<Expr, Token![,]>,
32}
33
34impl Parse for SqlInput {
35    fn parse(input: ParseStream) -> Result<Self> {
36        input.parse::<Token![as]>()?;
37        let return_ = input.parse()?;
38        input.parse::<Token![,]>()?;
39        let format_ = input.parse()?;
40
41        if input.is_empty() {
42            return Ok(Self {
43                return_,
44                format_,
45                binds: Punctuated::new(),
46            });
47        }
48
49        input.parse::<Token![,]>()?;
50        let binds = Punctuated::parse_terminated(input)?;
51
52        Ok(Self {
53            return_,
54            format_,
55            binds,
56        })
57    }
58}
59
60impl Parse for QueryInput {
61    fn parse(input: ParseStream) -> Result<Self> {
62        let format_ = input.parse()?;
63
64        if input.is_empty() {
65            return Ok(Self {
66                format_,
67                binds: Punctuated::new(),
68            });
69        }
70
71        input.parse::<Token![,]>()?;
72        let binds = Punctuated::parse_terminated(input)?;
73
74        Ok(Self { format_, binds })
75    }
76}
77
78/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
79/// describe the SQL snippet with the following syntax:
80///
81/// ```rust,ignore
82/// sql!(as T, "format", binds,*)
83/// ```
84///
85/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
86/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
87/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
88/// Bound values follow in the order matching their binders in the string:
89///
90/// ```rust,ignore
91/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
92/// ```
93///
94/// The above macro invocation will generate the following code:
95///
96/// ```rust,ignore
97/// sql::<Bool>("")
98///    .bind::<BigInt, _>(5)
99///    .sql(" <= foo AND foo < ")
100///    .bind::<BigInt, _>(10)
101///    .sql("")
102/// ```
103#[proc_macro]
104pub fn sql(input: TokenStream) -> TokenStream {
105    let SqlInput {
106        return_,
107        format_,
108        binds,
109        ..
110    } = parse_macro_input!(input as SqlInput);
111
112    let format_str = format_.value();
113    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
114    let Format { head, tail } = match Parser::new(&lexemes).format() {
115        Ok(format) => format,
116        Err(err) => {
117            return Error::new(format_.span(), err).into_compile_error().into();
118        }
119    };
120
121    let mut tokens = quote! {
122        ::diesel::dsl::sql::<#return_>(#head)
123    };
124
125    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
126        tokens.extend(if let Some(ty) = ty {
127            quote! {
128                .bind::<::diesel::sql_types::#ty, _>(#expr)
129                .sql(#suffix)
130            }
131        } else {
132            // No type was provided for the bind parameter, so we use `Untyped` which will report
133            // an error because it doesn't implement `SqlType`.
134            quote! {
135                .bind::<::diesel::sql_types::Untyped, _>(#expr)
136                .sql(#suffix)
137            }
138        });
139    }
140
141    tokens.into()
142}
143
144/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
145/// SQL query, defined by a format string and binds with the following syntax:
146///
147/// ```rust,ignore
148/// query!("format", binds,*)
149/// ```
150///
151/// The format string introduces binders with curly braces. An empty binder interpolates another
152/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
153/// that will be bound into the query, given a string which must correspond to a type in the
154/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
155/// their binders in the string:
156///
157/// ```rust,ignore
158/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
159/// ```
160///
161/// The above macro invocation will generate the following SQL query:
162///
163/// ```sql
164/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
165/// ```
166#[proc_macro]
167pub fn query(input: TokenStream) -> TokenStream {
168    let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
169
170    let format_str = format_.value();
171    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
172    let Format { head, tail } = match Parser::new(&lexemes).format() {
173        Ok(format) => format,
174        Err(err) => {
175            return Error::new(format_.span(), err).into_compile_error().into();
176        }
177    };
178
179    let mut tokens = quote! {
180        ::sui_pg_db::query::Query::new(#head)
181    };
182
183    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
184        tokens.extend(if let Some(ty) = ty {
185            // If there is a type, this interpolation is for a bind.
186            quote! {
187                .bind::<::diesel::sql_types::#ty, _>(#expr)
188                .sql(#suffix)
189            }
190        } else {
191            // Otherwise, we are interpolating another query.
192            quote! {
193                .query(#expr)
194                .sql(#suffix)
195            }
196        });
197    }
198
199    tokens.into()
200}