sui_sql_macro/lib.rs
1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{
7 Error, Expr, LitStr, Result, Token, Type,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11};
12
13use crate::{
14 lexer::Lexer,
15 parser::{Format, Parser},
16};
17
18mod lexer;
19mod parser;
20
21/// Rust syntax for `sql!(as T, "format", binds,*)`
22struct SqlInput {
23 return_: Type,
24 format_: LitStr,
25 binds: Punctuated<Expr, Token![,]>,
26}
27
28/// Rust syntax for `query!("format", binds,*)`.
29struct QueryInput {
30 format_: LitStr,
31 binds: Punctuated<Expr, Token![,]>,
32}
33
34impl Parse for SqlInput {
35 fn parse(input: ParseStream) -> Result<Self> {
36 input.parse::<Token![as]>()?;
37 let return_ = input.parse()?;
38 input.parse::<Token![,]>()?;
39 let format_ = input.parse()?;
40
41 if input.is_empty() {
42 return Ok(Self {
43 return_,
44 format_,
45 binds: Punctuated::new(),
46 });
47 }
48
49 input.parse::<Token![,]>()?;
50 let binds = Punctuated::parse_terminated(input)?;
51
52 Ok(Self {
53 return_,
54 format_,
55 binds,
56 })
57 }
58}
59
60impl Parse for QueryInput {
61 fn parse(input: ParseStream) -> Result<Self> {
62 let format_ = input.parse()?;
63
64 if input.is_empty() {
65 return Ok(Self {
66 format_,
67 binds: Punctuated::new(),
68 });
69 }
70
71 input.parse::<Token![,]>()?;
72 let binds = Punctuated::parse_terminated(input)?;
73
74 Ok(Self { format_, binds })
75 }
76}
77
78/// The `sql!` macro is used to construct a `diesel::SqlLiteral<T>` using a format string to
79/// describe the SQL snippet with the following syntax:
80///
81/// ```rust,ignore
82/// sql!(as T, "format", binds,*)
83/// ```
84///
85/// `T` is the `SqlType` that the literal will be interpreted as, as a Rust expression. The format
86/// string introduces binders with curly braces, surrounding the `SqlType` of the bound value. This
87/// type is given as a string which must correspond to a type in the `diesel::sql_types` module.
88/// Bound values follow in the order matching their binders in the string:
89///
90/// ```rust,ignore
91/// sql!(as Bool, "{BigInt} <= foo AND foo < {BigInt}", 5, 10)
92/// ```
93///
94/// The above macro invocation will generate the following code:
95///
96/// ```rust,ignore
97/// sql::<Bool>("")
98/// .bind::<BigInt, _>(5)
99/// .sql(" <= foo AND foo < ")
100/// .bind::<BigInt, _>(10)
101/// .sql("")
102/// ```
103#[proc_macro]
104pub fn sql(input: TokenStream) -> TokenStream {
105 let SqlInput {
106 return_,
107 format_,
108 binds,
109 ..
110 } = parse_macro_input!(input as SqlInput);
111
112 let format_str = format_.value();
113 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
114 let Format { head, tail } = match Parser::new(&lexemes).format() {
115 Ok(format) => format,
116 Err(err) => {
117 return Error::new(format_.span(), err).into_compile_error().into();
118 }
119 };
120
121 let mut tokens = quote! {
122 ::diesel::dsl::sql::<#return_>(#head)
123 };
124
125 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
126 tokens.extend(if let Some(ty) = ty {
127 quote! {
128 .bind::<::diesel::sql_types::#ty, _>(#expr)
129 .sql(#suffix)
130 }
131 } else {
132 // No type was provided for the bind parameter, so we use `Untyped` which will report
133 // an error because it doesn't implement `SqlType`.
134 quote! {
135 .bind::<::diesel::sql_types::Untyped, _>(#expr)
136 .sql(#suffix)
137 }
138 });
139 }
140
141 tokens.into()
142}
143
144/// The `query!` macro constructs a value that implements `diesel::query_builder::Query` -- a full
145/// SQL query, defined by a format string and binds with the following syntax:
146///
147/// ```rust,ignore
148/// query!("format", binds,*)
149/// ```
150///
151/// The format string introduces binders with curly braces. An empty binder interpolates another
152/// query at that position, otherwise the binder is expected to contain a `SqlType` for a value
153/// that will be bound into the query, given a string which must correspond to a type in the
154/// `diesel::sql_types` module. Bound values or queries to interpolate follow in the order matching
155/// their binders in the string:
156///
157/// ```rust,ignore
158/// query!("SELECT * FROM foo WHERE {BigInt} <= cursor AND {}", 5, query!("cursor < {BigInt}", 10))
159/// ```
160///
161/// The above macro invocation will generate the following SQL query:
162///
163/// ```sql
164/// SELECT * FROM foo WHERE $1 <= cursor AND cursor < $2 -- binds [5, 10]
165/// ```
166#[proc_macro]
167pub fn query(input: TokenStream) -> TokenStream {
168 let QueryInput { format_, binds } = parse_macro_input!(input as QueryInput);
169
170 let format_str = format_.value();
171 let lexemes: Vec<_> = Lexer::new(&format_str).collect();
172 let Format { head, tail } = match Parser::new(&lexemes).format() {
173 Ok(format) => format,
174 Err(err) => {
175 return Error::new(format_.span(), err).into_compile_error().into();
176 }
177 };
178
179 let mut tokens = quote! {
180 ::sui_pg_db::query::Query::new(#head)
181 };
182
183 for (expr, (ty, suffix)) in binds.iter().zip(tail.into_iter()) {
184 tokens.extend(if let Some(ty) = ty {
185 // If there is a type, this interpolation is for a bind.
186 quote! {
187 .bind::<::diesel::sql_types::#ty, _>(#expr)
188 .sql(#suffix)
189 }
190 } else {
191 // Otherwise, we are interpolating another query.
192 quote! {
193 .query(#expr)
194 .sql(#suffix)
195 }
196 });
197 }
198
199 tokens.into()
200}