sui_sql_macro/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5
6use quote::quote;
7use syn::Error;
8use syn::Expr;
9use syn::LitStr;
10use syn::Result;
11use syn::Token;
12use syn::Type;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17
18use crate::lexer::Lexer;
19use crate::parser::Format;
20use crate::parser::Parser;
21
22mod lexer;
23mod parser;
24
25/// Rust syntax for `sql!(as T, "format", binds,*)`
26struct SqlInput {
27    return_: Type,
28    format_: LitStr,
29    binds: Punctuated<Expr, Token![,]>,
30}
31
32/// Rust syntax for `query!("format", binds,*)`.
33struct QueryInput {
34    format_: LitStr,
35    binds: Punctuated<Expr, Token![,]>,
36}
37
38impl Parse for SqlInput {
39    fn parse(input: ParseStream) -> Result<Self> {
40        input.parse::<Token![as]>()?;
41        let return_ = input.parse()?;
42        input.parse::<Token![,]>()?;
43        let format_ = input.parse()?;
44
45        if input.is_empty() {
46            return Ok(Self {
47                return_,
48                format_,
49                binds: Punctuated::new(),
50            });
51        }
52
53        input.parse::<Token![,]>()?;
54        let binds = Punctuated::parse_terminated(input)?;
55
56        Ok(Self {
57            return_,
58            format_,
59            binds,
60        })
61    }
62}
63
64impl Parse for QueryInput {
65    fn parse(input: ParseStream) -> Result<Self> {
66        let format_ = input.parse()?;
67
68        if input.is_empty() {
69            return Ok(Self {
70                format_,
71                binds: Punctuated::new(),
72            });
73        }
74
75        input.parse::<Token![,]>()?;
76        let binds = Punctuated::parse_terminated(input)?;
77
78        Ok(Self { format_, binds })
79    }
80}
81
82/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
83/// describe the SQL snippet with the following syntax:
84///
85/// ```rust,ignore
86/// sql!(as T, "format", binds,*)
87/// ```
88///
89/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
90/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
91/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
92/// Bound values follow in the order matching their binders in the string:
93///
94/// ```rust,ignore
95/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
96/// ```
97///
98/// The above macro invocation will generate the following code:
99///
100/// ```rust,ignore
101/// sql::<Bool>("")
102///    .bind::<BigInt, _>(5)
103///    .sql(" <= foo AND foo < ")
104///    .bind::<BigInt, _>(10)
105///    .sql("")
106/// ```
107#[proc_macro]
108pub fn sql(input: TokenStream) -> TokenStream {
109    let SqlInput {
110        return_,
111        format_,
112        binds,
113        ..
114    } = parse_macro_input!(input as SqlInput);
115
116    let format_str = format_.value();
117    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
118    let Format { head, tail } = match Parser::new(&lexemes).format() {
119        Ok(format) => format,
120        Err(err) => {
121            return Error::new(format_.span(), err).into_compile_error().into();
122        }
123    };
124
125    let mut tokens = quote! {
126        ::diesel::dsl::sql::<#return_>(#head)
127    };
128
129    // Intentional zip: proc-macro crate, zip_debug_eq not applicable at compile time
130    #[allow(clippy::disallowed_methods)]
131    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
132        tokens.extend(if let Some(ty) = ty {
133            quote! {
134                .bind::<::diesel::sql_types::#ty, _>(#expr)
135                .sql(#suffix)
136            }
137        } else {
138            // No type was provided for the bind parameter, so we use `Untyped` which will report
139            // an error because it doesn't implement `SqlType`.
140            quote! {
141                .bind::<::diesel::sql_types::Untyped, _>(#expr)
142                .sql(#suffix)
143            }
144        });
145    }
146
147    tokens.into()
148}
149
150/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
151/// SQL query, defined by a format string and binds with the following syntax:
152///
153/// ```rust,ignore
154/// query!("format", binds,*)
155/// ```
156///
157/// The format string introduces binders with curly braces. An empty binder interpolates another
158/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
159/// that will be bound into the query, given a string which must correspond to a type in the
160/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
161/// their binders in the string:
162///
163/// ```rust,ignore
164/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
165/// ```
166///
167/// The above macro invocation will generate the following SQL query:
168///
169/// ```sql
170/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
171/// ```
172#[proc_macro]
173pub fn query(input: TokenStream) -> TokenStream {
174    let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
175
176    let format_str = format_.value();
177    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
178    let Format { head, tail } = match Parser::new(&lexemes).format() {
179        Ok(format) => format,
180        Err(err) => {
181            return Error::new(format_.span(), err).into_compile_error().into();
182        }
183    };
184
185    let mut tokens = quote! {
186        ::sui_pg_db::query::Query::new(#head)
187    };
188
189    // Intentional zip: proc-macro crate, zip_debug_eq not applicable at compile time
190    #[allow(clippy::disallowed_methods)]
191    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
192        tokens.extend(if let Some(ty) = ty {
193            // If there is a type, this interpolation is for a bind.
194            quote! {
195                .bind::<::diesel::sql_types::#ty, _>(#expr)
196                .sql(#suffix)
197            }
198        } else {
199            // Otherwise, we are interpolating another query.
200            quote! {
201                .query(#expr)
202                .sql(#suffix)
203            }
204        });
205    }
206
207    tokens.into()
208}