sui_sql_macro/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5
6use quote::quote;
7use syn::Error;
8use syn::Expr;
9use syn::LitStr;
10use syn::Result;
11use syn::Token;
12use syn::Type;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17
18use crate::lexer::Lexer;
19use crate::parser::Format;
20use crate::parser::Parser;
21
22mod lexer;
23mod parser;
24
/// Rust syntax for `sql!(as T, "format", binds,*)`.
///
/// Produced by the `Parse` impl for this type from the raw macro input.
struct SqlInput {
    /// The SQL type `T` written after `as`; the generated literal is interpreted as this type.
    return_: Type,
    /// The format string describing the SQL snippet and its `{...}` binders.
    format_: LitStr,
    /// Values to bind, in the order their binders appear in `format_`.
    binds: Punctuated<Expr, Token![,]>,
}
31
/// Rust syntax for `query!("format", binds,*)`.
///
/// Produced by the `Parse` impl for this type from the raw macro input.
struct QueryInput {
    /// The format string describing the query and its `{...}` binders.
    format_: LitStr,
    /// Bound values (for typed binders) or nested queries (for empty `{}` binders), in the order
    /// their binders appear in `format_`.
    binds: Punctuated<Expr, Token![,]>,
}
37
38impl Parse for SqlInput {
39    fn parse(input: ParseStream) -> Result<Self> {
40        input.parse::<Token![as]>()?;
41        let return_ = input.parse()?;
42        input.parse::<Token![,]>()?;
43        let format_ = input.parse()?;
44
45        if input.is_empty() {
46            return Ok(Self {
47                return_,
48                format_,
49                binds: Punctuated::new(),
50            });
51        }
52
53        input.parse::<Token![,]>()?;
54        let binds = Punctuated::parse_terminated(input)?;
55
56        Ok(Self {
57            return_,
58            format_,
59            binds,
60        })
61    }
62}
63
64impl Parse for QueryInput {
65    fn parse(input: ParseStream) -> Result<Self> {
66        let format_ = input.parse()?;
67
68        if input.is_empty() {
69            return Ok(Self {
70                format_,
71                binds: Punctuated::new(),
72            });
73        }
74
75        input.parse::<Token![,]>()?;
76        let binds = Punctuated::parse_terminated(input)?;
77
78        Ok(Self { format_, binds })
79    }
80}
81
82/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
83/// describe the SQL snippet with the following syntax:
84///
85/// ```rust,ignore
86/// sql!(as T, "format", binds,*)
87/// ```
88///
89/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
90/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
91/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
92/// Bound values follow in the order matching their binders in the string:
93///
94/// ```rust,ignore
95/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
96/// ```
97///
98/// The above macro invocation will generate the following code:
99///
100/// ```rust,ignore
101/// sql::<Bool>("")
102///    .bind::<BigInt, _>(5)
103///    .sql(" <= foo AND foo < ")
104///    .bind::<BigInt, _>(10)
105///    .sql("")
106/// ```
107#[proc_macro]
108pub fn sql(input: TokenStream) -> TokenStream {
109    let SqlInput {
110        return_,
111        format_,
112        binds,
113        ..
114    } = parse_macro_input!(input as SqlInput);
115
116    let format_str = format_.value();
117    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
118    let Format { head, tail } = match Parser::new(&lexemes).format() {
119        Ok(format) => format,
120        Err(err) => {
121            return Error::new(format_.span(), err).into_compile_error().into();
122        }
123    };
124
125    let mut tokens = quote! {
126        ::diesel::dsl::sql::<#return_>(#head)
127    };
128
129    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
130        tokens.extend(if let Some(ty) = ty {
131            quote! {
132                .bind::<::diesel::sql_types::#ty, _>(#expr)
133                .sql(#suffix)
134            }
135        } else {
136            // No type was provided for the bind parameter, so we use `Untyped` which will report
137            // an error because it doesn't implement `SqlType`.
138            quote! {
139                .bind::<::diesel::sql_types::Untyped, _>(#expr)
140                .sql(#suffix)
141            }
142        });
143    }
144
145    tokens.into()
146}
147
148/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
149/// SQL query, defined by a format string and binds with the following syntax:
150///
151/// ```rust,ignore
152/// query!("format", binds,*)
153/// ```
154///
155/// The format string introduces binders with curly braces. An empty binder interpolates another
156/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
157/// that will be bound into the query, given a string which must correspond to a type in the
158/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
159/// their binders in the string:
160///
161/// ```rust,ignore
162/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
163/// ```
164///
165/// The above macro invocation will generate the following SQL query:
166///
167/// ```sql
168/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
169/// ```
170#[proc_macro]
171pub fn query(input: TokenStream) -> TokenStream {
172    let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
173
174    let format_str = format_.value();
175    let lexemes: Vec<_> = Lexer::new(&format_str).collect();
176    let Format { head, tail } = match Parser::new(&lexemes).format() {
177        Ok(format) => format,
178        Err(err) => {
179            return Error::new(format_.span(), err).into_compile_error().into();
180        }
181    };
182
183    let mut tokens = quote! {
184        ::sui_pg_db::query::Query::new(#head)
185    };
186
187    for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
188        tokens.extend(if let Some(ty) = ty {
189            // If there is a type, this interpolation is for a bind.
190            quote! {
191                .bind::<::diesel::sql_types::#ty, _>(#expr)
192                .sql(#suffix)
193            }
194        } else {
195            // Otherwise, we are interpolating another query.
196            quote! {
197                .query(#expr)
198                .sql(#suffix)
199            }
200        });
201    }
202
203    tokens.into()
204}