1use proc_macro::TokenStream;
5use quote::{ToTokens, quote, quote_spanned};
6use syn::{
7 Attribute, BinOp, Data, DataEnum, DeriveInput, Expr, ExprBinary, ExprMacro, Item, ItemMacro,
8 Stmt, StmtMacro, Token, UnOp,
9 fold::{Fold, fold_expr, fold_item_macro, fold_stmt},
10 parse::Parser,
11 parse_macro_input, parse2,
12 punctuated::Punctuated,
13 spanned::Spanned,
14};
15
16#[proc_macro_attribute]
17pub fn init_static_initializers(_args: TokenStream, item: TokenStream) -> TokenStream {
18 let mut input = parse_macro_input!(item as syn::ItemFn);
19
20 let body = &input.block;
21 input.block = syn::parse2(quote! {
22 {
23 std::thread::spawn(|| {
39 use sui_protocol_config::ProtocolConfig;
40 ::sui_simulator::telemetry_subscribers::init_for_testing();
41 ::sui_simulator::sui_types::execution_params::get_denied_certificates_for_sim_test();
42 ::sui_simulator::sui_framework::BuiltInFramework::all_package_ids();
43 ::sui_simulator::sui_types::gas::SuiGasStatus::new_unmetered();
44
45 let mut cache = ::sui_simulator::lru::LruCache::new(1.try_into().unwrap());
49 cache.put(1, 1);
50
51 use std::sync::Arc;
52
53 use ::sui_simulator::anemo_tower::callback::CallbackLayer;
54 use ::sui_simulator::anemo_tower::trace::DefaultMakeSpan;
55 use ::sui_simulator::anemo_tower::trace::DefaultOnFailure;
56 use ::sui_simulator::anemo_tower::trace::TraceLayer;
57 use ::sui_simulator::fastcrypto::traits::KeyPair;
58 use ::sui_simulator::mysten_network::metrics::MetricsMakeCallbackHandler;
59 use ::sui_simulator::mysten_network::metrics::NetworkMetrics;
60 use ::sui_simulator::rand_crate::rngs::{StdRng, OsRng};
61 use ::sui_simulator::rand::SeedableRng;
62 use ::sui_simulator::tower::ServiceBuilder;
63
64 let rt = ::sui_simulator::runtime::Runtime::new();
67 rt.block_on(async move {
68 use ::sui_simulator::anemo::{Network, Request};
69
70 {
71 use std::path::PathBuf;
74 use sui_simulator::sui_move_build::BuildConfig;
75 use sui_simulator::tempfile::TempDir;
76
77 let mut path = PathBuf::from(env!("SIMTEST_STATIC_INIT_MOVE"));
78 let mut build_config = BuildConfig::new_for_testing();
79 build_config.config.install_dir = Some(TempDir::new().unwrap().keep());
80 let _all_module_bytes = build_config
81 .build_async(&path)
82 .await
83 .unwrap()
84 .get_package_bytes(false);
85 }
86
87 let make_network = |port: u16| {
88 let registry = prometheus::Registry::new();
89 let inbound_network_metrics =
90 NetworkMetrics::new("sui", "inbound", ®istry);
91 let outbound_network_metrics =
92 NetworkMetrics::new("sui", "outbound", ®istry);
93
94 let service = ServiceBuilder::new()
95 .layer(
96 TraceLayer::new_for_server_errors()
97 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
98 .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
99 )
100 .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
101 Arc::new(inbound_network_metrics),
102 usize::MAX,
103 )))
104 .service(::sui_simulator::anemo::Router::new());
105
106 let outbound_layer = ServiceBuilder::new()
107 .layer(
108 TraceLayer::new_for_client_and_server_errors()
109 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
110 .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
111 )
112 .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
113 Arc::new(outbound_network_metrics),
114 usize::MAX,
115 )))
116 .into_inner();
117
118
119 Network::bind(format!("127.0.0.1:{}", port))
120 .server_name("static-init-network")
121 .private_key(
122 ::sui_simulator::fastcrypto::ed25519::Ed25519KeyPair::generate(&mut StdRng::from_rng(OsRng).unwrap())
123 .private()
124 .0
125 .to_bytes(),
126 )
127 .start(service)
128 .unwrap()
129 };
130 let n1 = make_network(80);
131 let n2 = make_network(81);
132
133 let _peer = n1.connect(n2.local_addr()).await.unwrap();
134 });
135 }).join().unwrap();
136
137 #body
138 }
139 })
140 .expect("Parsing failure");
141
142 let result = quote! {
143 #input
144 };
145
146 result.into()
147}
148
149#[proc_macro_attribute]
154pub fn sui_test(args: TokenStream, item: TokenStream) -> TokenStream {
155 let input = parse_macro_input!(item as syn::ItemFn);
156 let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
157 let args = arg_parser.parse(args).unwrap().into_iter();
158
159 let header = if cfg!(msim) {
160 quote! {
161 #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args)* )]
162 }
163 } else {
164 quote! {
165 #[::tokio::test(#(#args)*)]
166 }
167 };
168
169 let result = quote! {
170 #header
171 #[::sui_macros::init_static_initializers]
172 #input
173 };
174
175 result.into()
176}
177
178#[proc_macro_attribute]
185pub fn sim_test(args: TokenStream, item: TokenStream) -> TokenStream {
186 let input = parse_macro_input!(item as syn::ItemFn);
187 let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
188 let args = arg_parser.parse(args).unwrap().into_iter();
189
190 let ignore = input
191 .attrs
192 .iter()
193 .find(|attr| attr.path().is_ident("ignore"))
194 .map_or(quote! {}, |_| quote! { #[ignore] });
195
196 let result = if cfg!(msim) {
197 let sig = &input.sig;
198 let return_type = &sig.output;
199 let body = &input.block;
200 quote! {
201 #[::sui_simulator::sim_test(crate = "sui_simulator", #(#args),*)]
202 #[::sui_macros::init_static_initializers]
203 #ignore
204 #sig {
205 async fn body_fn() #return_type { #body }
206
207 let timeout_secs: u64 = std::env::var("SUI_SIM_TEST_TIMEOUT_SECS")
208 .ok()
209 .and_then(|s| s.parse().ok())
210 .unwrap_or(1000);
211 let timeout_duration = tokio::time::Duration::from_secs(timeout_secs);
212
213 let ret = tokio::time::timeout(timeout_duration, body_fn())
214 .await
215 .expect("sim_test timed out");
216
217 ::sui_simulator::task::shutdown_all_nodes();
218
219 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
222
223 assert_eq!(
224 sui_simulator::NodeLeakDetector::get_current_node_count(),
225 0,
226 "SuiNode leak detected"
227 );
228
229 ret
230 }
231 }
232 } else {
233 let fn_name = &input.sig.ident;
234 let sig = &input.sig;
235 let body = &input.block;
236 quote! {
237 #[allow(clippy::needless_return)]
238 #[tokio::test]
239 #ignore
240 #sig {
241 if std::env::var("SUI_SKIP_SIMTESTS").is_ok() {
242 println!("not running test {} in `cargo test`: SUI_SKIP_SIMTESTS is set", stringify!(#fn_name));
243
244 struct Ret;
245
246 impl From<Ret> for () {
247 fn from(_ret: Ret) -> Self {
248 }
249 }
250
251 impl<E> From<Ret> for Result<(), E> {
252 fn from(_ret: Ret) -> Self {
253 Ok(())
254 }
255 }
256
257 return Ret.into();
258 }
259
260 #body
261 }
262 }
263 };
264
265 result.into()
266}
267
268#[proc_macro]
269pub fn checked_arithmetic(input: TokenStream) -> TokenStream {
270 let input_file = CheckArithmetic.fold_file(parse_macro_input!(input));
271
272 let output_items = input_file.items;
273
274 let output = quote! {
275 #(#output_items)*
276 };
277
278 TokenStream::from(output)
279}
280
281#[proc_macro_attribute]
282pub fn with_checked_arithmetic(_attr: TokenStream, item: TokenStream) -> TokenStream {
283 let input_item = parse_macro_input!(item as Item);
284 match input_item {
285 Item::Fn(input_fn) => {
286 let transformed_fn = CheckArithmetic.fold_item_fn(input_fn);
287 TokenStream::from(quote! { #transformed_fn })
288 }
289 Item::Impl(input_impl) => {
290 let transformed_impl = CheckArithmetic.fold_item_impl(input_impl);
291 TokenStream::from(quote! { #transformed_impl })
292 }
293 item => {
294 let transformed_impl = CheckArithmetic.fold_item(item);
295 TokenStream::from(quote! { #transformed_impl })
296 }
297 }
298}
299
300struct CheckArithmetic;
301
302impl CheckArithmetic {
303 fn maybe_skip_macro(&self, attrs: &mut Vec<Attribute>) -> bool {
304 if let Some(idx) = attrs
305 .iter()
306 .position(|attr| attr.path().is_ident("skip_checked_arithmetic"))
307 {
308 attrs.remove(idx);
311 true
312 } else {
313 false
314 }
315 }
316
317 fn process_macro_contents(
318 &mut self,
319 tokens: proc_macro2::TokenStream,
320 ) -> syn::Result<proc_macro2::TokenStream> {
321 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
323 let Ok(exprs) = parser.parse(tokens.clone().into()) else {
324 return Err(syn::Error::new_spanned(
325 tokens,
326 "could not process macro contents - use #[skip_checked_arithmetic] to skip this macro",
327 ));
328 };
329
330 let folded_exprs = exprs
332 .into_iter()
333 .map(|expr| self.fold_expr(expr))
334 .collect::<Vec<_>>();
335
336 let mut folded_tokens = proc_macro2::TokenStream::new();
338 for (i, folded_expr) in folded_exprs.into_iter().enumerate() {
339 if i > 0 {
340 folded_tokens.extend(std::iter::once::<proc_macro2::TokenTree>(
341 proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).into(),
342 ));
343 }
344 folded_expr.to_tokens(&mut folded_tokens);
345 }
346
347 Ok(folded_tokens)
348 }
349}
350
351impl Fold for CheckArithmetic {
352 fn fold_stmt(&mut self, stmt: Stmt) -> Stmt {
353 let stmt = fold_stmt(self, stmt);
354 if let Stmt::Macro(stmt_macro) = stmt {
355 let StmtMacro {
356 mut attrs,
357 mut mac,
358 semi_token,
359 } = stmt_macro;
360
361 if self.maybe_skip_macro(&mut attrs) {
362 Stmt::Macro(StmtMacro {
363 attrs,
364 mac,
365 semi_token,
366 })
367 } else {
368 match self.process_macro_contents(mac.tokens.clone()) {
369 Ok(folded_tokens) => {
370 mac.tokens = folded_tokens;
371 Stmt::Macro(StmtMacro {
372 attrs,
373 mac,
374 semi_token,
375 })
376 }
377 Err(error) => parse2(error.to_compile_error()).unwrap(),
378 }
379 }
380 } else {
381 stmt
382 }
383 }
384
385 fn fold_item_macro(&mut self, mut item_macro: ItemMacro) -> ItemMacro {
386 if !self.maybe_skip_macro(&mut item_macro.attrs) {
387 let err = syn::Error::new_spanned(
388 item_macro.to_token_stream(),
389 "cannot process macros - use #[skip_checked_arithmetic] to skip \
390 processing this macro",
391 );
392
393 return parse2(err.to_compile_error()).unwrap();
394 }
395 fold_item_macro(self, item_macro)
396 }
397
398 fn fold_expr(&mut self, expr: Expr) -> Expr {
399 let span = expr.span();
400 let expr = fold_expr(self, expr);
401 let expr = match expr {
402 Expr::Macro(expr_macro) => {
403 let ExprMacro { mut attrs, mut mac } = expr_macro;
404
405 if self.maybe_skip_macro(&mut attrs) {
406 return Expr::Macro(ExprMacro { attrs, mac });
407 } else {
408 match self.process_macro_contents(mac.tokens.clone()) {
409 Ok(folded_tokens) => {
410 mac.tokens = folded_tokens;
411 let expr_macro = Expr::Macro(ExprMacro { attrs, mac });
412 quote!(#expr_macro)
413 }
414 Err(error) => {
415 return Expr::Verbatim(error.to_compile_error());
416 }
417 }
418 }
419 }
420
421 Expr::Binary(expr_binary) => {
422 let ExprBinary {
423 attrs,
424 mut left,
425 op,
426 mut right,
427 } = expr_binary;
428
429 fn remove_parens(expr: &mut Expr) {
430 if let Expr::Paren(paren) = expr {
431 assert!(paren.attrs.is_empty(), "TODO: attrs on parenthesized");
433 *expr = *paren.expr.clone();
434 }
435 }
436
437 macro_rules! wrap_op {
438 ($left: expr, $right: expr, $method: ident, $span: expr) => {{
439 remove_parens(&mut $left);
442 remove_parens(&mut $right);
443
444 quote_spanned!($span => {
445 let (left, right) = (#left, #right);
448 left.$method(right)
449 .unwrap_or_else(||
450 panic!(
451 "Overflow or underflow in {} {} + {}",
452 stringify!($method),
453 left,
454 right,
455 )
456 )
457 })
458 }};
459 }
460
461 macro_rules! wrap_op_assign {
462 ($left: expr, $right: expr, $method: ident, $span: expr) => {{
463 remove_parens(&mut $left);
466 remove_parens(&mut $right);
467
468 quote_spanned!($span => {
469 let (left, right) = (&mut #left, #right);
472 *left = (*left).$method(right)
473 .unwrap_or_else(||
474 panic!(
475 "Overflow or underflow in {} {} + {}",
476 stringify!($method),
477 *left,
478 right
479 )
480 )
481 })
482 }};
483 }
484
485 match op {
486 BinOp::Add(_) => {
487 wrap_op!(left, right, checked_add, span)
488 }
489 BinOp::Sub(_) => {
490 wrap_op!(left, right, checked_sub, span)
491 }
492 BinOp::Mul(_) => {
493 wrap_op!(left, right, checked_mul, span)
494 }
495 BinOp::Div(_) => {
496 wrap_op!(left, right, checked_div, span)
497 }
498 BinOp::Rem(_) => {
499 wrap_op!(left, right, checked_rem, span)
500 }
501 BinOp::AddAssign(_) => {
502 wrap_op_assign!(left, right, checked_add, span)
503 }
504 BinOp::SubAssign(_) => {
505 wrap_op_assign!(left, right, checked_sub, span)
506 }
507 BinOp::MulAssign(_) => {
508 wrap_op_assign!(left, right, checked_mul, span)
509 }
510 BinOp::DivAssign(_) => {
511 wrap_op_assign!(left, right, checked_div, span)
512 }
513 BinOp::RemAssign(_) => {
514 wrap_op_assign!(left, right, checked_rem, span)
515 }
516 _ => {
517 let expr_binary = ExprBinary {
518 attrs,
519 left,
520 op,
521 right,
522 };
523 quote_spanned!(span => #expr_binary)
524 }
525 }
526 }
527 Expr::Unary(expr_unary) => {
528 let op = &expr_unary.op;
529 let operand = &expr_unary.expr;
530 match op {
531 UnOp::Neg(_) => {
532 quote_spanned!(span => #operand.checked_neg().expect("Overflow or underflow in negation"))
533 }
534 _ => quote_spanned!(span => #expr_unary),
535 }
536 }
537 _ => quote_spanned!(span => #expr),
538 };
539
540 parse2(expr).unwrap()
541 }
542}
543
544#[proc_macro_derive(EnumVariantOrder)]
561pub fn enum_variant_order_derive(input: TokenStream) -> TokenStream {
562 let ast = parse_macro_input!(input as DeriveInput);
563 let name = &ast.ident;
564
565 if let Data::Enum(DataEnum { variants, .. }) = ast.data {
566 let variant_entries = variants
567 .iter()
568 .enumerate()
569 .map(|(index, variant)| {
570 let variant_name = variant.ident.to_string();
571 quote! {
572 map.insert( #index as u64, (#variant_name).to_string());
573 }
574 })
575 .collect::<Vec<_>>();
576
577 let deriv = quote! {
578 impl sui_enum_compat_util::EnumOrderMap for #name {
579 fn order_to_variant_map() -> std::collections::BTreeMap<u64, String > {
580 let mut map = std::collections::BTreeMap::new();
581 #(#variant_entries)*
582 map
583 }
584 }
585 };
586
587 deriv.into()
588 } else {
589 panic!("EnumVariantOrder can only be used with enums.");
590 }
591}