1use super::Node;
5use anyhow::Result;
6use futures::future::try_join_all;
7use rand::rngs::OsRng;
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::num::NonZeroUsize;
11use std::time::Duration;
12use std::{
13 ops,
14 path::{Path, PathBuf},
15};
16use sui_types::traffic_control::{PolicyConfig, RemoteFirewallConfig};
17
18#[cfg(msim)]
19use sui_config::node::ExecutionTimeObserverConfig;
20use sui_config::node::{AuthorityOverloadConfig, DBCheckpointConfig, RunWithRange};
21use sui_config::{ExecutionCacheConfig, NodeConfig};
22use sui_macros::nondeterministic;
23use sui_node::SuiNodeHandle;
24use sui_protocol_config::{Chain, ProtocolVersion};
25use sui_swarm_config::genesis_config::{AccountConfig, GenesisConfig, ValidatorGenesisConfig};
26use sui_swarm_config::network_config::NetworkConfig;
27use sui_swarm_config::network_config_builder::{
28 CommitteeConfig, ConfigBuilder, FundsWithdrawSchedulerTypeConfig,
29 GlobalStateHashV2EnabledConfig, ProtocolVersionsConfig, SupportedProtocolVersionsCallback,
30 ValidatorObserverConfigCallback,
31};
32use sui_swarm_config::node_config_builder::FullnodeConfigBuilder;
33use sui_types::base_types::AuthorityName;
34use sui_types::object::Object;
35use sui_types::supported_protocol_versions::SupportedProtocolVersions;
36use tempfile::TempDir;
37use tracing::info;
38
39pub struct SwarmBuilder<R = OsRng> {
40 rng: R,
41 dir: Option<PathBuf>,
43 committee: CommitteeConfig,
44 genesis_config: Option<GenesisConfig>,
45 network_config: Option<NetworkConfig>,
46 chain_override: Option<Chain>,
47 additional_objects: Vec<Object>,
48 fullnode_count: usize,
49 fullnode_rpc_port: Option<u16>,
50 fullnode_rpc_addr: Option<SocketAddr>,
51 fullnode_rpc_config: Option<sui_config::RpcConfig>,
52 supported_protocol_versions_config: ProtocolVersionsConfig,
53 fullnode_supported_protocol_versions_config: Option<ProtocolVersionsConfig>,
55 db_checkpoint_config: DBCheckpointConfig,
56 jwk_fetch_interval: Option<Duration>,
57 num_unpruned_validators: Option<usize>,
58 authority_overload_config: Option<AuthorityOverloadConfig>,
59 execution_cache_config: Option<ExecutionCacheConfig>,
60 data_ingestion_dir: Option<PathBuf>,
61 fullnode_run_with_range: Option<RunWithRange>,
62 fullnode_policy_config: Option<PolicyConfig>,
63 fullnode_fw_config: Option<RemoteFirewallConfig>,
64 global_state_hash_v2_enabled_config: GlobalStateHashV2EnabledConfig,
65 funds_withdraw_scheduler_type_config: Option<FundsWithdrawSchedulerTypeConfig>,
66 disable_fullnode_pruning: bool,
67 state_sync_config: Option<sui_config::p2p::StateSyncConfig>,
68 #[cfg(msim)]
69 execution_time_observer_config: Option<ExecutionTimeObserverConfig>,
70 validator_observer_config: Option<ValidatorObserverConfigCallback>,
71}
72
73impl SwarmBuilder {
74 #[allow(clippy::new_without_default)]
75 pub fn new() -> Self {
76 Self {
77 rng: OsRng,
78 dir: None,
79 committee: CommitteeConfig::Size(NonZeroUsize::new(1).unwrap()),
80 genesis_config: None,
81 network_config: None,
82 chain_override: None,
83 additional_objects: vec![],
84 fullnode_count: 0,
85 fullnode_rpc_port: None,
86 fullnode_rpc_addr: None,
87 fullnode_rpc_config: None,
88 supported_protocol_versions_config: ProtocolVersionsConfig::Default,
89 fullnode_supported_protocol_versions_config: None,
90 db_checkpoint_config: DBCheckpointConfig::default(),
91 jwk_fetch_interval: None,
92 num_unpruned_validators: None,
93 authority_overload_config: None,
94 execution_cache_config: None,
95 data_ingestion_dir: None,
96 fullnode_run_with_range: None,
97 fullnode_policy_config: None,
98 fullnode_fw_config: None,
99 global_state_hash_v2_enabled_config: GlobalStateHashV2EnabledConfig::Global(true),
100 funds_withdraw_scheduler_type_config: None,
101 disable_fullnode_pruning: false,
102 state_sync_config: None,
103 #[cfg(msim)]
104 execution_time_observer_config: None,
105 validator_observer_config: None,
106 }
107 }
108}
109
110impl<R> SwarmBuilder<R> {
111 pub fn rng<N: rand::RngCore + rand::CryptoRng>(self, rng: N) -> SwarmBuilder<N> {
112 SwarmBuilder {
113 rng,
114 dir: self.dir,
115 committee: self.committee,
116 genesis_config: self.genesis_config,
117 network_config: self.network_config,
118 chain_override: self.chain_override,
119 additional_objects: self.additional_objects,
120 fullnode_count: self.fullnode_count,
121 fullnode_rpc_port: self.fullnode_rpc_port,
122 fullnode_rpc_addr: self.fullnode_rpc_addr,
123 fullnode_rpc_config: self.fullnode_rpc_config.clone(),
124 supported_protocol_versions_config: self.supported_protocol_versions_config,
125 fullnode_supported_protocol_versions_config: self
126 .fullnode_supported_protocol_versions_config,
127 db_checkpoint_config: self.db_checkpoint_config,
128 jwk_fetch_interval: self.jwk_fetch_interval,
129 num_unpruned_validators: self.num_unpruned_validators,
130 authority_overload_config: self.authority_overload_config,
131 execution_cache_config: self.execution_cache_config,
132 data_ingestion_dir: self.data_ingestion_dir,
133 fullnode_run_with_range: self.fullnode_run_with_range,
134 fullnode_policy_config: self.fullnode_policy_config,
135 fullnode_fw_config: self.fullnode_fw_config,
136 global_state_hash_v2_enabled_config: self.global_state_hash_v2_enabled_config,
137 funds_withdraw_scheduler_type_config: self.funds_withdraw_scheduler_type_config,
138 disable_fullnode_pruning: self.disable_fullnode_pruning,
139 state_sync_config: self.state_sync_config,
140 #[cfg(msim)]
141 execution_time_observer_config: self.execution_time_observer_config,
142 validator_observer_config: self.validator_observer_config,
143 }
144 }
145
146 pub fn dir<P: Into<PathBuf>>(mut self, dir: P) -> Self {
152 self.dir = Some(dir.into());
153 self
154 }
155
156 pub fn committee_size(mut self, committee_size: NonZeroUsize) -> Self {
160 self.committee = CommitteeConfig::Size(committee_size);
161 self
162 }
163
164 pub fn with_validators(mut self, validators: Vec<ValidatorGenesisConfig>) -> Self {
165 self.committee = CommitteeConfig::Validators(validators);
166 self
167 }
168
169 pub fn with_genesis_config(mut self, genesis_config: GenesisConfig) -> Self {
170 assert!(self.network_config.is_none() && self.genesis_config.is_none());
171 self.genesis_config = Some(genesis_config);
172 self
173 }
174
175 pub fn with_chain_override(mut self, chain: Chain) -> Self {
176 assert!(self.chain_override.is_none());
177 self.chain_override = Some(chain);
178 self
179 }
180
181 pub fn with_num_unpruned_validators(mut self, n: usize) -> Self {
182 assert!(self.network_config.is_none());
183 self.num_unpruned_validators = Some(n);
184 self
185 }
186
187 pub fn with_jwk_fetch_interval(mut self, i: Duration) -> Self {
188 self.jwk_fetch_interval = Some(i);
189 self
190 }
191
192 pub fn with_network_config(mut self, network_config: NetworkConfig) -> Self {
193 assert!(self.network_config.is_none() && self.genesis_config.is_none());
194 self.network_config = Some(network_config);
195 self
196 }
197
198 pub fn with_accounts(mut self, accounts: Vec<AccountConfig>) -> Self {
199 self.get_or_init_genesis_config().accounts = accounts;
200 self
201 }
202
203 pub fn with_objects<I: IntoIterator<Item = Object>>(mut self, objects: I) -> Self {
204 self.additional_objects.extend(objects);
205 self
206 }
207
208 pub fn with_fullnode_count(mut self, fullnode_count: usize) -> Self {
209 self.fullnode_count = fullnode_count;
210 self
211 }
212
213 pub fn with_fullnode_rpc_port(mut self, fullnode_rpc_port: u16) -> Self {
214 assert!(self.fullnode_rpc_addr.is_none());
215 self.fullnode_rpc_port = Some(fullnode_rpc_port);
216 self
217 }
218
219 pub fn with_fullnode_rpc_addr(mut self, fullnode_rpc_addr: SocketAddr) -> Self {
220 assert!(self.fullnode_rpc_port.is_none());
221 self.fullnode_rpc_addr = Some(fullnode_rpc_addr);
222 self
223 }
224
225 pub fn with_fullnode_rpc_config(mut self, fullnode_rpc_config: sui_config::RpcConfig) -> Self {
226 self.fullnode_rpc_config = Some(fullnode_rpc_config);
227 self
228 }
229
230 pub fn with_epoch_duration_ms(mut self, epoch_duration_ms: u64) -> Self {
231 assert!(
232 epoch_duration_ms >= 10000,
233 "Epoch duration must be at least 10s (10000ms) to avoid flaky tests. Got {epoch_duration_ms}ms."
234 );
235 self.get_or_init_genesis_config()
236 .parameters
237 .epoch_duration_ms = epoch_duration_ms;
238 self
239 }
240
241 pub fn with_protocol_version(mut self, v: ProtocolVersion) -> Self {
242 self.get_or_init_genesis_config()
243 .parameters
244 .protocol_version = v;
245 self
246 }
247
248 pub fn with_supported_protocol_versions(mut self, c: SupportedProtocolVersions) -> Self {
249 self.supported_protocol_versions_config = ProtocolVersionsConfig::Global(c);
250 self
251 }
252
253 pub fn with_supported_protocol_version_callback(
254 mut self,
255 func: SupportedProtocolVersionsCallback,
256 ) -> Self {
257 self.supported_protocol_versions_config = ProtocolVersionsConfig::PerValidator(func);
258 self
259 }
260
261 pub fn with_supported_protocol_versions_config(mut self, c: ProtocolVersionsConfig) -> Self {
262 self.supported_protocol_versions_config = c;
263 self
264 }
265
266 pub fn with_global_state_hash_v2_enabled_config(
267 mut self,
268 c: GlobalStateHashV2EnabledConfig,
269 ) -> Self {
270 self.global_state_hash_v2_enabled_config = c;
271 self
272 }
273
274 pub fn with_funds_withdraw_scheduler_type_config(
275 mut self,
276 c: FundsWithdrawSchedulerTypeConfig,
277 ) -> Self {
278 self.funds_withdraw_scheduler_type_config = Some(c);
279 self
280 }
281
282 #[cfg(msim)]
283 pub fn with_execution_time_observer_config(mut self, c: ExecutionTimeObserverConfig) -> Self {
284 self.execution_time_observer_config = Some(c);
285 self
286 }
287
288 pub fn with_validator_observer_config(mut self, c: ValidatorObserverConfigCallback) -> Self {
289 self.validator_observer_config = Some(c);
290 self
291 }
292
293 pub fn with_fullnode_supported_protocol_versions_config(
294 mut self,
295 c: ProtocolVersionsConfig,
296 ) -> Self {
297 self.fullnode_supported_protocol_versions_config = Some(c);
298 self
299 }
300
301 pub fn with_db_checkpoint_config(mut self, db_checkpoint_config: DBCheckpointConfig) -> Self {
302 self.db_checkpoint_config = db_checkpoint_config;
303 self
304 }
305
306 pub fn with_authority_overload_config(
307 mut self,
308 authority_overload_config: AuthorityOverloadConfig,
309 ) -> Self {
310 assert!(self.network_config.is_none());
311 self.authority_overload_config = Some(authority_overload_config);
312 self
313 }
314
315 pub fn with_execution_cache_config(
316 mut self,
317 execution_cache_config: ExecutionCacheConfig,
318 ) -> Self {
319 self.execution_cache_config = Some(execution_cache_config);
320 self
321 }
322
323 pub fn with_data_ingestion_dir(mut self, path: PathBuf) -> Self {
324 self.data_ingestion_dir = Some(path);
325 self
326 }
327
328 pub fn with_state_sync_config(mut self, config: sui_config::p2p::StateSyncConfig) -> Self {
329 self.state_sync_config = Some(config);
330 self
331 }
332
333 pub fn with_fullnode_run_with_range(mut self, run_with_range: Option<RunWithRange>) -> Self {
334 if let Some(run_with_range) = run_with_range {
335 self.fullnode_run_with_range = Some(run_with_range);
336 }
337 self
338 }
339
340 pub fn with_fullnode_policy_config(mut self, config: Option<PolicyConfig>) -> Self {
341 self.fullnode_policy_config = config;
342 self
343 }
344
345 pub fn with_fullnode_fw_config(mut self, config: Option<RemoteFirewallConfig>) -> Self {
346 self.fullnode_fw_config = config;
347 self
348 }
349
350 fn get_or_init_genesis_config(&mut self) -> &mut GenesisConfig {
351 if self.genesis_config.is_none() {
352 assert!(self.network_config.is_none());
353 self.genesis_config = Some(GenesisConfig::for_local_testing());
354 }
355 self.genesis_config.as_mut().unwrap()
356 }
357
358 pub fn with_disable_fullnode_pruning(mut self) -> Self {
359 self.disable_fullnode_pruning = true;
360 self
361 }
362}
363
364impl<R: rand::RngCore + rand::CryptoRng> SwarmBuilder<R> {
365 pub fn build(self) -> Swarm {
367 let dir = if let Some(dir) = self.dir {
368 SwarmDirectory::Persistent(dir)
369 } else {
370 SwarmDirectory::new_temporary()
371 };
372
373 let ingest_data = self.data_ingestion_dir.clone();
374
375 let network_config = self.network_config.unwrap_or_else(|| {
376 let mut config_builder = ConfigBuilder::new(dir.as_ref());
377
378 if let Some(genesis_config) = self.genesis_config {
379 config_builder = config_builder.with_genesis_config(genesis_config);
380 }
381
382 if let Some(chain_override) = self.chain_override {
383 config_builder = config_builder.with_chain_override(chain_override);
384 }
385
386 if let Some(num_unpruned_validators) = self.num_unpruned_validators {
387 config_builder =
388 config_builder.with_num_unpruned_validators(num_unpruned_validators);
389 }
390
391 if let Some(jwk_fetch_interval) = self.jwk_fetch_interval {
392 config_builder = config_builder.with_jwk_fetch_interval(jwk_fetch_interval);
393 }
394
395 if let Some(authority_overload_config) = self.authority_overload_config {
396 config_builder =
397 config_builder.with_authority_overload_config(authority_overload_config);
398 }
399
400 if let Some(execution_cache_config) = self.execution_cache_config {
401 config_builder = config_builder.with_execution_cache_config(execution_cache_config);
402 }
403
404 if let Some(path) = self.data_ingestion_dir {
405 config_builder = config_builder.with_data_ingestion_dir(path);
406 }
407
408 #[allow(unused_mut)]
409 let mut final_builder = config_builder
410 .committee(self.committee)
411 .rng(self.rng)
412 .with_objects(self.additional_objects)
413 .with_supported_protocol_versions_config(
414 self.supported_protocol_versions_config.clone(),
415 )
416 .with_global_state_hash_v2_enabled_config(
417 self.global_state_hash_v2_enabled_config.clone(),
418 );
419
420 if let Some(funds_withdraw_scheduler_type_config) =
421 self.funds_withdraw_scheduler_type_config.clone()
422 {
423 final_builder = final_builder.with_funds_withdraw_scheduler_type_config(
424 funds_withdraw_scheduler_type_config,
425 );
426 }
427
428 if let Some(state_sync_config) = self.state_sync_config.clone() {
429 final_builder = final_builder.with_state_sync_config(state_sync_config);
430 }
431
432 #[cfg(msim)]
433 if let Some(execution_time_observer_config) = self.execution_time_observer_config {
434 final_builder = final_builder
435 .with_execution_time_observer_config(execution_time_observer_config);
436 }
437
438 if let Some(validator_observer_config) = self.validator_observer_config {
439 final_builder =
440 final_builder.with_validator_observer_config(validator_observer_config);
441 }
442
443 final_builder.build()
444 });
445
446 let mut nodes: HashMap<_, _> = network_config
447 .validator_configs()
448 .iter()
449 .map(|config| {
450 info!(
451 "SwarmBuilder configuring validator with name {}",
452 config.protocol_public_key()
453 );
454 (config.protocol_public_key(), Node::new(config.to_owned()))
455 })
456 .collect();
457
458 let mut fullnode_config_builder = FullnodeConfigBuilder::new()
459 .with_config_directory(dir.as_ref().into())
460 .with_db_checkpoint_config(self.db_checkpoint_config.clone())
461 .with_run_with_range(self.fullnode_run_with_range)
462 .with_policy_config(self.fullnode_policy_config)
463 .with_data_ingestion_dir(ingest_data)
464 .with_fw_config(self.fullnode_fw_config)
465 .with_disable_pruning(self.disable_fullnode_pruning);
466
467 if let Some(state_sync_config) = self.state_sync_config.clone() {
468 fullnode_config_builder =
469 fullnode_config_builder.with_state_sync_config(state_sync_config);
470 }
471
472 if let Some(chain) = self.chain_override {
473 fullnode_config_builder = fullnode_config_builder.with_chain_override(chain);
474 }
475
476 if let Some(spvc) = &self.fullnode_supported_protocol_versions_config {
477 let supported_versions = match spvc {
478 ProtocolVersionsConfig::Default => SupportedProtocolVersions::SYSTEM_DEFAULT,
479 ProtocolVersionsConfig::Global(v) => *v,
480 ProtocolVersionsConfig::PerValidator(func) => func(0, None),
481 };
482 fullnode_config_builder =
483 fullnode_config_builder.with_supported_protocol_versions(supported_versions);
484 }
485
486 if self.fullnode_count > 0 {
487 (0..self.fullnode_count).for_each(|idx| {
488 let mut builder = fullnode_config_builder.clone();
489 if idx == 0 {
490 if let Some(rpc_addr) = self.fullnode_rpc_addr {
493 builder = builder.with_rpc_addr(rpc_addr);
494 }
495 if let Some(rpc_port) = self.fullnode_rpc_port {
496 builder = builder.with_rpc_port(rpc_port);
497 }
498 if let Some(rpc_config) = &self.fullnode_rpc_config {
499 builder = builder.with_rpc_config(rpc_config.clone());
500 }
501 }
502 let config = builder.build(&mut OsRng, &network_config);
503 info!(
504 "SwarmBuilder configuring full node with name {}",
505 config.protocol_public_key()
506 );
507 nodes.insert(config.protocol_public_key(), Node::new(config));
508 });
509 }
510 Swarm {
511 dir,
512 network_config,
513 nodes,
514 fullnode_config_builder,
515 }
516 }
517}
518
/// A local network of validators and fullnodes assembled by [`SwarmBuilder`].
///
/// Dropping a `Swarm` calls `stop` on every node (see its `Drop` impl).
// Field order matters: it determines `Debug` output and drop order. `dir` may be a
// `TempDir`, which is removed when it is dropped.
#[derive(Debug)]
pub struct Swarm {
    /// Directory the swarm's configuration was generated under (temporary unless
    /// `SwarmBuilder::dir` was set).
    dir: SwarmDirectory,
    /// Network configuration the validator nodes were created from.
    network_config: NetworkConfig,
    /// Every validator and fullnode, keyed by protocol public key.
    nodes: HashMap<AuthorityName, Node>,
    /// Template used to configure fullnodes; copies are handed out by
    /// `get_fullnode_config_builder`.
    fullnode_config_builder: FullnodeConfigBuilder,
}
528
529impl Drop for Swarm {
530 fn drop(&mut self) {
531 self.nodes_iter_mut().for_each(|node| node.stop());
532 }
533}
534
535impl Swarm {
536 fn nodes_iter_mut(&mut self) -> impl Iterator<Item = &mut Node> {
537 self.nodes.values_mut()
538 }
539
540 pub fn builder() -> SwarmBuilder {
542 SwarmBuilder::new()
543 }
544
545 pub async fn launch(&mut self) -> Result<()> {
547 try_join_all(self.nodes_iter_mut().map(|node| node.start())).await?;
548 tracing::info!("Successfully launched Swarm");
549 Ok(())
550 }
551
552 pub fn dir(&self) -> &Path {
554 self.dir.as_ref()
555 }
556
557 pub fn config(&self) -> &NetworkConfig {
559 &self.network_config
560 }
561
562 pub fn config_mut(&mut self) -> &mut NetworkConfig {
565 &mut self.network_config
566 }
567
568 pub fn all_nodes(&self) -> impl Iterator<Item = &Node> {
569 self.nodes.values()
570 }
571
572 pub fn node(&self, name: &AuthorityName) -> Option<&Node> {
573 self.nodes.get(name)
574 }
575
576 pub fn node_mut(&mut self, name: &AuthorityName) -> Option<&mut Node> {
577 self.nodes.get_mut(name)
578 }
579
580 pub fn validator_nodes(&self) -> impl Iterator<Item = &Node> {
584 self.nodes
585 .values()
586 .filter(|node| node.config().consensus_config.is_some())
587 }
588
589 pub fn validator_node_handles(&self) -> Vec<SuiNodeHandle> {
590 self.validator_nodes()
591 .map(|node| node.get_node_handle().unwrap())
592 .collect()
593 }
594
595 pub fn active_validators(&self) -> impl Iterator<Item = &Node> {
597 self.validator_nodes().filter(|node| {
598 node.get_node_handle().is_some_and(|handle| {
599 let state = handle.state();
600 state.is_validator(&state.epoch_store_for_testing())
601 })
602 })
603 }
604
605 pub fn fullnodes(&self) -> impl Iterator<Item = &Node> {
607 self.nodes
608 .values()
609 .filter(|node| node.config().consensus_config.is_none())
610 }
611
612 pub async fn spawn_new_node(&mut self, config: NodeConfig) -> SuiNodeHandle {
613 let name = config.protocol_public_key();
614 let node = Node::new(config);
615 node.start().await.unwrap();
616 let handle = node.get_node_handle().unwrap();
617 self.nodes.insert(name, node);
618 handle
619 }
620
621 pub fn get_fullnode_config_builder(&self) -> FullnodeConfigBuilder {
622 self.fullnode_config_builder.clone()
623 }
624}
625
/// Where a swarm keeps its on-disk configuration.
#[derive(Debug)]
enum SwarmDirectory {
    /// Caller-supplied directory (set via `SwarmBuilder::dir`).
    Persistent(PathBuf),
    /// Directory created by `new_temporary`; `TempDir` deletes it when dropped.
    Temporary(TempDir),
}
631
632impl SwarmDirectory {
633 fn new_temporary() -> Self {
634 SwarmDirectory::Temporary(nondeterministic!(TempDir::new().unwrap()))
635 }
636}
637
638impl ops::Deref for SwarmDirectory {
639 type Target = Path;
640
641 fn deref(&self) -> &Self::Target {
642 match self {
643 SwarmDirectory::Persistent(dir) => dir.deref(),
644 SwarmDirectory::Temporary(dir) => dir.path(),
645 }
646 }
647}
648
649impl AsRef<Path> for SwarmDirectory {
650 fn as_ref(&self) -> &Path {
651 match self {
652 SwarmDirectory::Persistent(dir) => dir.as_ref(),
653 SwarmDirectory::Temporary(dir) => dir.as_ref(),
654 }
655 }
656}
657
#[cfg(test)]
mod test {
    use super::Swarm;
    use std::num::NonZeroUsize;

    /// Smoke test: build a 4-validator swarm with one fullnode, launch it, and health-check
    /// every node.
    #[tokio::test]
    async fn launch() {
        telemetry_subscribers::init_for_testing();
        let mut swarm = Swarm::builder()
            .committee_size(NonZeroUsize::new(4).unwrap())
            .with_fullnode_count(1)
            .build();

        // Guard against the loops below passing vacuously on an empty node set.
        assert_eq!(swarm.validator_nodes().count(), 4);
        assert_eq!(swarm.fullnodes().count(), 1);

        swarm.launch().await.unwrap();

        for validator in swarm.validator_nodes() {
            validator.health_check(true).await.unwrap();
        }

        for fullnode in swarm.fullnodes() {
            fullnode.health_check(false).await.unwrap();
        }
    }
}