sui_sql_macro/lib.rs
1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5
6use quote::quote;
7use syn::Error;
8use syn::Expr;
9use syn::LitStr;
10use syn::Result;
11use syn::Token;
12use syn::Type;
13use syn::parse::Parse;
14use syn::parse::ParseStream;
15use syn::parse_macro_input;
16use syn::punctuated::Punctuated;
17
18use crate::lexer::Lexer;
19use crate::parser::Format;
20use crate::parser::Parser;
21
22mod lexer;
23mod parser;
24
25/// Rust syntax for `sql!(as T, "format", binds,*)`
26struct SqlInput {
27 return_: Type,
28 format_: LitStr,
29 binds: Punctuated<Expr, Token![,]>,
30}
31
32/// Rust syntax for `query!("format", binds,*)`.
33struct QueryInput {
34 format_: LitStr,
35 binds: Punctuated<Expr, Token![,]>,
36}
37
38impl Parse for SqlInput {
39 fn parse(input: ParseStream) -> Result<Self> {
40 input.parse::<Token![as]>()?;
41 let return_ = input.parse()?;
42 input.parse::<Token![,]>()?;
43 let format_ = input.parse()?;
44
45 if input.is_empty() {
46 return Ok(Self {
47 return_,
48 format_,
49 binds: Punctuated::new(),
50 });
51 }
52
53 input.parse::<Token![,]>()?;
54 let binds = Punctuated::parse_terminated(input)?;
55
56 Ok(Self {
57 return_,
58 format_,
59 binds,
60 })
61 }
62}
63
64impl Parse for QueryInput {
65 fn parse(input: ParseStream) -> Result<Self> {
66 let format_ = input.parse()?;
67
68 if input.is_empty() {
69 return Ok(Self {
70 format_,
71 binds: Punctuated::new(),
72 });
73 }
74
75 input.parse::<Token![,]>()?;
76 let binds = Punctuated::parse_terminated(input)?;
77
78 Ok(Self { format_, binds })
79 }
80}
81
82/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
83/// describe the SQL snippet with the following syntax:
84///
85/// ```rust,ignore
86/// sql!(as T, "format", binds,*)
87/// ```
88///
89/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
90/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
91/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
92/// Bound values follow in the order matching their binders in the string:
93///
94/// ```rust,ignore
95/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
96/// ```
97///
98/// The above macro invocation will generate the following code:
99///
100/// ```rust,ignore
101/// sql::<Bool>("")
102/// .bind::<BigInt, _>(5)
103/// .sql(" <= foo AND foo < ")
104/// .bind::<BigInt, _>(10)
105/// .sql("")
106/// ```
107#[proc_macro]
108pub fn sql(input: TokenStream) -> TokenStream {
109 let SqlInput {
110 return_,
111 format_,
112 binds,
113 ..
114 } = parse_macro_input!(input as SqlInput);
115
116 let format_str = format_.value();
117 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
118 let Format { head, tail } = match Parser::new(&lexemes).format() {
119 Ok(format) => format,
120 Err(err) => {
121 return Error::new(format_.span(), err).into_compile_error().into();
122 }
123 };
124
125 let mut tokens = quote! {
126 ::diesel::dsl::sql::<#return_>(#head)
127 };
128
129 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
130 tokens.extend(if let Some(ty) = ty {
131 quote! {
132 .bind::<::diesel::sql_types::#ty, _>(#expr)
133 .sql(#suffix)
134 }
135 } else {
136 // No type was provided for the bind parameter, so we use `Untyped` which will report
137 // an error because it doesn't implement `SqlType`.
138 quote! {
139 .bind::<::diesel::sql_types::Untyped, _>(#expr)
140 .sql(#suffix)
141 }
142 });
143 }
144
145 tokens.into()
146}
147
148/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
149/// SQL query, defined by a format string and binds with the following syntax:
150///
151/// ```rust,ignore
152/// query!("format", binds,*)
153/// ```
154///
155/// The format string introduces binders with curly braces. An empty binder interpolates another
156/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
157/// that will be bound into the query, given a string which must correspond to a type in the
158/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
159/// their binders in the string:
160///
161/// ```rust,ignore
162/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
163/// ```
164///
165/// The above macro invocation will generate the following SQL query:
166///
167/// ```sql
168/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
169/// ```
170#[proc_macro]
171pub fn query(input: TokenStream) -> TokenStream {
172 let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
173
174 let format_str = format_.value();
175 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
176 let Format { head, tail } = match Parser::new(&lexemes).format() {
177 Ok(format) => format,
178 Err(err) => {
179 return Error::new(format_.span(), err).into_compile_error().into();
180 }
181 };
182
183 let mut tokens = quote! {
184 ::sui_pg_db::query::Query::new(#head)
185 };
186
187 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
188 tokens.extend(if let Some(ty) = ty {
189 // If there is a type, this interpolation is for a bind.
190 quote! {
191 .bind::<::diesel::sql_types::#ty, _>(#expr)
192 .sql(#suffix)
193 }
194 } else {
195 // Otherwise, we are interpolating another query.
196 quote! {
197 .query(#expr)
198 .sql(#suffix)
199 }
200 });
201 }
202
203 tokens.into()
204}