sui_sql_macro/lib.rs
1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5
6use quote::quote;
7use syn::Error;
8use syn::Expr;
9use syn::LitStr;
10use syn::Result;
11use syn::Token;
12use syn::Type;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17
18use crate::lexer::Lexer;
19use crate::parser::Format;
20use crate::parser::Parser;
21
22mod lexer;
23mod parser;
24
25/// Rust syntax for `sql!(as T, "format", binds,*)`
26struct SqlInput {
27 return_: Type,
28 format_: LitStr,
29 binds: Punctuated<Expr, Token![,]>,
30}
31
32/// Rust syntax for `query!("format", binds,*)`.
33struct QueryInput {
34 format_: LitStr,
35 binds: Punctuated<Expr, Token![,]>,
36}
37
38impl Parse for SqlInput {
39 fn parse(input: ParseStream) -> Result<Self> {
40 input.parse::<Token![as]>()?;
41 let return_ = input.parse()?;
42 input.parse::<Token![,]>()?;
43 let format_ = input.parse()?;
44
45 if input.is_empty() {
46 return Ok(Self {
47 return_,
48 format_,
49 binds: Punctuated::new(),
50 });
51 }
52
53 input.parse::<Token![,]>()?;
54 let binds = Punctuated::parse_terminated(input)?;
55
56 Ok(Self {
57 return_,
58 format_,
59 binds,
60 })
61 }
62}
63
64impl Parse for QueryInput {
65 fn parse(input: ParseStream) -> Result<Self> {
66 let format_ = input.parse()?;
67
68 if input.is_empty() {
69 return Ok(Self {
70 format_,
71 binds: Punctuated::new(),
72 });
73 }
74
75 input.parse::<Token![,]>()?;
76 let binds = Punctuated::parse_terminated(input)?;
77
78 Ok(Self { format_, binds })
79 }
80}
81
82/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
83/// describe the SQL snippet with the following syntax:
84///
85/// ```rust,ignore
86/// sql!(as T, "format", binds,*)
87/// ```
88///
89/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
90/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
91/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
92/// Bound values follow in the order matching their binders in the string:
93///
94/// ```rust,ignore
95/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
96/// ```
97///
98/// The above macro invocation will generate the following code:
99///
100/// ```rust,ignore
101/// sql::<Bool>("")
102/// .bind::<BigInt, _>(5)
103/// .sql(" <= foo AND foo < ")
104/// .bind::<BigInt, _>(10)
105/// .sql("")
106/// ```
107#[proc_macro]
108pub fn sql(input: TokenStream) -> TokenStream {
109 let SqlInput {
110 return_,
111 format_,
112 binds,
113 ..
114 } = parse_macro_input!(input as SqlInput);
115
116 let format_str = format_.value();
117 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
118 let Format { head, tail } = match Parser::new(&lexemes).format() {
119 Ok(format) => format,
120 Err(err) => {
121 return Error::new(format_.span(), err).into_compile_error().into();
122 }
123 };
124
125 let mut tokens = quote! {
126 ::diesel::dsl::sql::<#return_>(#head)
127 };
128
129 // Intentional zip: proc-macro crate, zip_debug_eq not applicable at compile time
130 #[allow(clippy::disallowed_methods)]
131 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
132 tokens.extend(if let Some(ty) = ty {
133 quote! {
134 .bind::<::diesel::sql_types::#ty, _>(#expr)
135 .sql(#suffix)
136 }
137 } else {
138 // No type was provided for the bind parameter, so we use `Untyped` which will report
139 // an error because it doesn't implement `SqlType`.
140 quote! {
141 .bind::<::diesel::sql_types::Untyped, _>(#expr)
142 .sql(#suffix)
143 }
144 });
145 }
146
147 tokens.into()
148}
149
150/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
151/// SQL query, defined by a format string and binds with the following syntax:
152///
153/// ```rust,ignore
154/// query!("format", binds,*)
155/// ```
156///
157/// The format string introduces binders with curly braces. An empty binder interpolates another
158/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
159/// that will be bound into the query, given a string which must correspond to a type in the
160/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
161/// their binders in the string:
162///
163/// ```rust,ignore
164/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
165/// ```
166///
167/// The above macro invocation will generate the following SQL query:
168///
169/// ```sql
170/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
171/// ```
172#[proc_macro]
173pub fn query(input: TokenStream) -> TokenStream {
174 let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
175
176 let format_str = format_.value();
177 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
178 let Format { head, tail } = match Parser::new(&lexemes).format() {
179 Ok(format) => format,
180 Err(err) => {
181 return Error::new(format_.span(), err).into_compile_error().into();
182 }
183 };
184
185 let mut tokens = quote! {
186 ::sui_pg_db::query::Query::new(#head)
187 };
188
189 // Intentional zip: proc-macro crate, zip_debug_eq not applicable at compile time
190 #[allow(clippy::disallowed_methods)]
191 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
192 tokens.extend(if let Some(ty) = ty {
193 // If there is a type, this interpolation is for a bind.
194 quote! {
195 .bind::<::diesel::sql_types::#ty, _>(#expr)
196 .sql(#suffix)
197 }
198 } else {
199 // Otherwise, we are interpolating another query.
200 quote! {
201 .query(#expr)
202 .sql(#suffix)
203 }
204 });
205 }
206
207 tokens.into()
208}