1use proc_macro::TokenStream;
5use quote::{ToTokens, quote, quote_spanned};
6use syn::{
7 Attribute, BinOp, Data, DataEnum, DeriveInput, Expr, ExprBinary, ExprMacro, Item, ItemMacro,
8 Stmt, StmtMacro, Token, UnOp,
9 fold::{Fold, fold_expr, fold_item_macro, fold_stmt},
10 parse::Parser,
11 parse_macro_input, parse2,
12 punctuated::Punctuated,
13 spanned::Spanned,
14};
15
16#[proc_macro_attribute]
17pub fn init_static_initializers(_args: TokenStream, item: TokenStream) -> TokenStream {
18 let mut input = parse_macro_input!(item as syn::ItemFn);
19
20 let body = &input.block;
21 input.block = syn::parse2(quote! {
22 {
23 std::thread::spawn(|| {
39 use sui_protocol_config::ProtocolConfig;
40 ::sui_simulator::telemetry_subscribers::init_for_testing();
41 ::sui_simulator::sui_types::execution_params::get_denied_certificates_for_sim_test();
42 ::sui_simulator::sui_framework::BuiltInFramework::all_package_ids();
43 ::sui_simulator::sui_types::gas::SuiGasStatus::new_unmetered();
44
45 let mut cache = ::sui_simulator::lru::LruCache::new(1.try_into().unwrap());
49 cache.put(1, 1);
50
51 use std::sync::Arc;
52
53 use ::sui_simulator::anemo_tower::callback::CallbackLayer;
54 use ::sui_simulator::anemo_tower::trace::DefaultMakeSpan;
55 use ::sui_simulator::anemo_tower::trace::DefaultOnFailure;
56 use ::sui_simulator::anemo_tower::trace::TraceLayer;
57 use ::sui_simulator::fastcrypto::traits::KeyPair;
58 use ::sui_simulator::mysten_network::metrics::MetricsMakeCallbackHandler;
59 use ::sui_simulator::mysten_network::metrics::NetworkMetrics;
60 use ::sui_simulator::rand_crate::rngs::{StdRng, OsRng};
61 use ::sui_simulator::rand::SeedableRng;
62 use ::sui_simulator::tower::ServiceBuilder;
63
64 let rt = ::sui_simulator::runtime::Runtime::new();
67 rt.block_on(async move {
68 use ::sui_simulator::anemo::{Network, Request};
69
70 {
71 use std::path::PathBuf;
74 use sui_simulator::sui_move_build::BuildConfig;
75 use sui_simulator::tempfile::TempDir;
76
77 let mut path = PathBuf::from(env!("SIMTEST_STATIC_INIT_MOVE"));
78 let mut build_config = BuildConfig::new_for_testing();
79 build_config.config.install_dir = Some(TempDir::new().unwrap().keep());
80 let _all_module_bytes = build_config
81 .build_async(&path)
82 .await
83 .unwrap()
84 .get_package_bytes(false);
85 }
86
87 let make_network = |port: u16| {
88 let registry = prometheus::Registry::new();
89 let inbound_network_metrics =
90 NetworkMetrics::new("sui", "inbound", ®istry);
91 let outbound_network_metrics =
92 NetworkMetrics::new("sui", "outbound", ®istry);
93
94 let service = ServiceBuilder::new()
95 .layer(
96 TraceLayer::new_for_server_errors()
97 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
98 .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
99 )
100 .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
101 Arc::new(inbound_network_metrics),
102 usize::MAX,
103 )))
104 .service(::sui_simulator::anemo::Router::new());
105
106 let outbound_layer = ServiceBuilder::new()
107 .layer(
108 TraceLayer::new_for_client_and_server_errors()
109 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
110 .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
111 )
112 .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
113 Arc::new(outbound_network_metrics),
114 usize::MAX,
115 )))
116 .into_inner();
117
118
119 Network::bind(format!("127.0.0.1:{}", port))
120 .server_name("static-init-network")
121 .private_key(
122 ::sui_simulator::fastcrypto::ed25519::Ed25519KeyPair::generate(&mut StdRng::from_rng(OsRng).unwrap())
123 .private()
124 .0
125 .to_bytes(),
126 )
127 .start(service)
128 .unwrap()
129 };
130 let n1 = make_network(80);
131 let n2 = make_network(81);
132
133 let _peer = n1.connect(n2.local_addr()).await.unwrap();
134 });
135 }).join().unwrap();
136
137 #body
138 }
139 })
140 .expect("Parsing failure");
141
142 let result = quote! {
143 #input
144 };
145
146 result.into()
147}
148
149#[proc_macro_attribute]
154pub fn sui_test(args: TokenStream, item: TokenStream) -> TokenStream {
155 let input = parse_macro_input!(item as syn::ItemFn);
156 let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
157 let args = arg_parser.parse(args).unwrap().into_iter();
158
159 let header = if cfg!(msim) {
160 quote! {
161 #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args)* )]
162 }
163 } else {
164 quote! {
165 #[::tokio::test(#(#args)*)]
166 }
167 };
168
169 let result = quote! {
170 #header
171 #[::sui_macros::init_static_initializers]
172 #input
173 };
174
175 result.into()
176}
177
178#[proc_macro_attribute]
185pub fn sim_test(args: TokenStream, item: TokenStream) -> TokenStream {
186 let input = parse_macro_input!(item as syn::ItemFn);
187 let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
188 let args = arg_parser.parse(args).unwrap().into_iter();
189
190 let ignore = input
191 .attrs
192 .iter()
193 .find(|attr| attr.path().is_ident("ignore"))
194 .map_or(quote! {}, |_| quote! { #[ignore] });
195
196 let result = if cfg!(msim) {
197 let sig = &input.sig;
198 let return_type = &sig.output;
199 let body = &input.block;
200 quote! {
201 #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args),*)]
202 #[::sui_macros::init_static_initializers]
203 #ignore
204 #sig {
205 async fn body_fn() #return_type { #body }
206
207 let ret = body_fn().await;
208
209 ::sui_simulator::task::shutdown_all_nodes();
210
211 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
214
215 assert_eq!(
216 sui_simulator::NodeLeakDetector::get_current_node_count(),
217 0,
218 "SuiNode leak detected"
219 );
220
221 ret
222 }
223 }
224 } else {
225 let fn_name = &input.sig.ident;
226 let sig = &input.sig;
227 let body = &input.block;
228 quote! {
229 #[allow(clippy::needless_return)]
230 #[tokio::test]
231 #ignore
232 #sig {
233 if std::env::var("SUI_SKIP_SIMTESTS").is_ok() {
234 println!("not running test {} in `cargo test`: SUI_SKIP_SIMTESTS is set", stringify!(#fn_name));
235
236 struct Ret;
237
238 impl From<Ret> for () {
239 fn from(_ret: Ret) -> Self {
240 }
241 }
242
243 impl<E> From<Ret> for Result<(), E> {
244 fn from(_ret: Ret) -> Self {
245 Ok(())
246 }
247 }
248
249 return Ret.into();
250 }
251
252 #body
253 }
254 }
255 };
256
257 result.into()
258}
259
260#[proc_macro]
261pub fn checked_arithmetic(input: TokenStream) -> TokenStream {
262 let input_file = CheckArithmetic.fold_file(parse_macro_input!(input));
263
264 let output_items = input_file.items;
265
266 let output = quote! {
267 #(#output_items)*
268 };
269
270 TokenStream::from(output)
271}
272
273#[proc_macro_attribute]
274pub fn with_checked_arithmetic(_attr: TokenStream, item: TokenStream) -> TokenStream {
275 let input_item = parse_macro_input!(item as Item);
276 match input_item {
277 Item::Fn(input_fn) => {
278 let transformed_fn = CheckArithmetic.fold_item_fn(input_fn);
279 TokenStream::from(quote! { #transformed_fn })
280 }
281 Item::Impl(input_impl) => {
282 let transformed_impl = CheckArithmetic.fold_item_impl(input_impl);
283 TokenStream::from(quote! { #transformed_impl })
284 }
285 item => {
286 let transformed_impl = CheckArithmetic.fold_item(item);
287 TokenStream::from(quote! { #transformed_impl })
288 }
289 }
290}
291
292struct CheckArithmetic;
293
294impl CheckArithmetic {
295 fn maybe_skip_macro(&self, attrs: &mut Vec<Attribute>) -> bool {
296 if let Some(idx) = attrs
297 .iter()
298 .position(|attr| attr.path().is_ident("skip_checked_arithmetic"))
299 {
300 attrs.remove(idx);
303 true
304 } else {
305 false
306 }
307 }
308
309 fn process_macro_contents(
310 &mut self,
311 tokens: proc_macro2::TokenStream,
312 ) -> syn::Result<proc_macro2::TokenStream> {
313 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
315 let Ok(exprs) = parser.parse(tokens.clone().into()) else {
316 return Err(syn::Error::new_spanned(
317 tokens,
318 "could not process macro contents - use #[skip_checked_arithmetic] to skip this macro",
319 ));
320 };
321
322 let folded_exprs = exprs
324 .into_iter()
325 .map(|expr| self.fold_expr(expr))
326 .collect::<Vec<_>>();
327
328 let mut folded_tokens = proc_macro2::TokenStream::new();
330 for (i, folded_expr) in folded_exprs.into_iter().enumerate() {
331 if i > 0 {
332 folded_tokens.extend(std::iter::once::<proc_macro2::TokenTree>(
333 proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).into(),
334 ));
335 }
336 folded_expr.to_tokens(&mut folded_tokens);
337 }
338
339 Ok(folded_tokens)
340 }
341}
342
343impl Fold for CheckArithmetic {
344 fn fold_stmt(&mut self, stmt: Stmt) -> Stmt {
345 let stmt = fold_stmt(self, stmt);
346 if let Stmt::Macro(stmt_macro) = stmt {
347 let StmtMacro {
348 mut attrs,
349 mut mac,
350 semi_token,
351 } = stmt_macro;
352
353 if self.maybe_skip_macro(&mut attrs) {
354 Stmt::Macro(StmtMacro {
355 attrs,
356 mac,
357 semi_token,
358 })
359 } else {
360 match self.process_macro_contents(mac.tokens.clone()) {
361 Ok(folded_tokens) => {
362 mac.tokens = folded_tokens;
363 Stmt::Macro(StmtMacro {
364 attrs,
365 mac,
366 semi_token,
367 })
368 }
369 Err(error) => parse2(error.to_compile_error()).unwrap(),
370 }
371 }
372 } else {
373 stmt
374 }
375 }
376
377 fn fold_item_macro(&mut self, mut item_macro: ItemMacro) -> ItemMacro {
378 if !self.maybe_skip_macro(&mut item_macro.attrs) {
379 let err = syn::Error::new_spanned(
380 item_macro.to_token_stream(),
381 "cannot process macros - use #[skip_checked_arithmetic] to skip \
382 processing this macro",
383 );
384
385 return parse2(err.to_compile_error()).unwrap();
386 }
387 fold_item_macro(self, item_macro)
388 }
389
390 fn fold_expr(&mut self, expr: Expr) -> Expr {
391 let span = expr.span();
392 let expr = fold_expr(self, expr);
393 let expr = match expr {
394 Expr::Macro(expr_macro) => {
395 let ExprMacro { mut attrs, mut mac } = expr_macro;
396
397 if self.maybe_skip_macro(&mut attrs) {
398 return Expr::Macro(ExprMacro { attrs, mac });
399 } else {
400 match self.process_macro_contents(mac.tokens.clone()) {
401 Ok(folded_tokens) => {
402 mac.tokens = folded_tokens;
403 let expr_macro = Expr::Macro(ExprMacro { attrs, mac });
404 quote!(#expr_macro)
405 }
406 Err(error) => {
407 return Expr::Verbatim(error.to_compile_error());
408 }
409 }
410 }
411 }
412
413 Expr::Binary(expr_binary) => {
414 let ExprBinary {
415 attrs,
416 mut left,
417 op,
418 mut right,
419 } = expr_binary;
420
421 fn remove_parens(expr: &mut Expr) {
422 if let Expr::Paren(paren) = expr {
423 assert!(paren.attrs.is_empty(), "TODO: attrs on parenthesized");
425 *expr = *paren.expr.clone();
426 }
427 }
428
429 macro_rules! wrap_op {
430 ($left: expr, $right: expr, $method: ident, $span: expr) => {{
431 remove_parens(&mut $left);
434 remove_parens(&mut $right);
435
436 quote_spanned!($span => {
437 let (left, right) = (#left, #right);
440 left.$method(right)
441 .unwrap_or_else(||
442 panic!(
443 "Overflow or underflow in {} {} + {}",
444 stringify!($method),
445 left,
446 right,
447 )
448 )
449 })
450 }};
451 }
452
453 macro_rules! wrap_op_assign {
454 ($left: expr, $right: expr, $method: ident, $span: expr) => {{
455 remove_parens(&mut $left);
458 remove_parens(&mut $right);
459
460 quote_spanned!($span => {
461 let (left, right) = (&mut #left, #right);
464 *left = (*left).$method(right)
465 .unwrap_or_else(||
466 panic!(
467 "Overflow or underflow in {} {} + {}",
468 stringify!($method),
469 *left,
470 right
471 )
472 )
473 })
474 }};
475 }
476
477 match op {
478 BinOp::Add(_) => {
479 wrap_op!(left, right, checked_add, span)
480 }
481 BinOp::Sub(_) => {
482 wrap_op!(left, right, checked_sub, span)
483 }
484 BinOp::Mul(_) => {
485 wrap_op!(left, right, checked_mul, span)
486 }
487 BinOp::Div(_) => {
488 wrap_op!(left, right, checked_div, span)
489 }
490 BinOp::Rem(_) => {
491 wrap_op!(left, right, checked_rem, span)
492 }
493 BinOp::AddAssign(_) => {
494 wrap_op_assign!(left, right, checked_add, span)
495 }
496 BinOp::SubAssign(_) => {
497 wrap_op_assign!(left, right, checked_sub, span)
498 }
499 BinOp::MulAssign(_) => {
500 wrap_op_assign!(left, right, checked_mul, span)
501 }
502 BinOp::DivAssign(_) => {
503 wrap_op_assign!(left, right, checked_div, span)
504 }
505 BinOp::RemAssign(_) => {
506 wrap_op_assign!(left, right, checked_rem, span)
507 }
508 _ => {
509 let expr_binary = ExprBinary {
510 attrs,
511 left,
512 op,
513 right,
514 };
515 quote_spanned!(span => #expr_binary)
516 }
517 }
518 }
519 Expr::Unary(expr_unary) => {
520 let op = &expr_unary.op;
521 let operand = &expr_unary.expr;
522 match op {
523 UnOp::Neg(_) => {
524 quote_spanned!(span => #operand.checked_neg().expect("Overflow or underflow in negation"))
525 }
526 _ => quote_spanned!(span => #expr_unary),
527 }
528 }
529 _ => quote_spanned!(span => #expr),
530 };
531
532 parse2(expr).unwrap()
533 }
534}
535
536#[proc_macro_derive(EnumVariantOrder)]
553pub fn enum_variant_order_derive(input: TokenStream) -> TokenStream {
554 let ast = parse_macro_input!(input as DeriveInput);
555 let name = &ast.ident;
556
557 if let Data::Enum(DataEnum { variants, .. }) = ast.data {
558 let variant_entries = variants
559 .iter()
560 .enumerate()
561 .map(|(index, variant)| {
562 let variant_name = variant.ident.to_string();
563 quote! {
564 map.insert( #index as u64, (#variant_name).to_string());
565 }
566 })
567 .collect::<Vec<_>>();
568
569 let deriv = quote! {
570 impl sui_enum_compat_util::EnumOrderMap for #name {
571 fn order_to_variant_map() -> std::collections::BTreeMap<u64, String > {
572 let mut map = std::collections::BTreeMap::new();
573 #(#variant_entries)*
574 map
575 }
576 }
577 };
578
579 deriv.into()
580 } else {
581 panic!("EnumVariantOrder can only be used with enums.");
582 }
583}