sui_rpc_loadgen/payload/
rpc_command_processor.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use dashmap::{DashMap, DashSet};
7use futures::future::join_all;
8use serde::Serialize;
9use serde::de::DeserializeOwned;
10use shared_crypto::intent::{Intent, IntentMessage};
11use std::fmt;
12use std::fs::{self, File};
13use std::path::PathBuf;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use sui_json_rpc_types::{
17    SuiExecutionStatus, SuiObjectDataOptions, SuiTransactionBlockDataAPI,
18    SuiTransactionBlockEffectsAPI, SuiTransactionBlockResponse, SuiTransactionBlockResponseOptions,
19};
20use sui_types::digests::TransactionDigest;
21use tokio::sync::RwLock;
22use tokio::time::sleep;
23use tracing::{debug, info};
24
25use crate::load_test::LoadTestConfig;
26use sui_sdk::{SuiClient, SuiClientBuilder};
27use sui_types::base_types::{ObjectID, ObjectRef, SuiAddress};
28use sui_types::crypto::{AccountKeyPair, EncodeDecodeBase64, Signature, SuiKeyPair, get_key_pair};
29use sui_types::quorum_driver_types::ExecuteTransactionRequestType;
30use sui_types::transaction::{Transaction, TransactionData};
31
32use crate::payload::checkpoint_utils::get_latest_checkpoint_stats;
33use crate::payload::validation::chunk_entities;
34use crate::payload::{
35    Command, CommandData, DryRun, GetAllBalances, GetCheckpoints, GetObject, MultiGetObjects,
36    Payload, ProcessPayload, Processor, QueryTransactionBlocks, SignerInfo,
37};
38
39use super::MultiGetTransactionBlocks;
40
/// Gas budget (in MIST) used for ordinary transactions such as `pay`.
pub(crate) const DEFAULT_GAS_BUDGET: u64 = 500_000_000;
/// Gas budget (in MIST) used for each coin-splitting transaction.
pub(crate) const DEFAULT_LARGE_GAS_BUDGET: u64 = 50_000_000_000;
/// Upper bound on how many new coin objects a single split transaction is asked to create.
pub(crate) const MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION: usize = 120;
44
45#[derive(Clone)]
46pub struct RpcCommandProcessor {
47    clients: Arc<RwLock<Vec<SuiClient>>>,
48    // for equivocation prevention in `WaitForEffectsCert` mode
49    object_ref_cache: Arc<DashMap<ObjectID, ObjectRef>>,
50    transaction_digests: Arc<DashSet<TransactionDigest>>,
51    addresses: Arc<DashSet<SuiAddress>>,
52    data_dir: String,
53}
54
55impl RpcCommandProcessor {
56    pub async fn new(urls: &[String], data_dir: String) -> Self {
57        let clients = join_all(urls.iter().map(|url| async {
58            SuiClientBuilder::default()
59                .max_concurrent_requests(usize::MAX)
60                .request_timeout(Duration::from_secs(60))
61                .build(url.clone())
62                .await
63                .unwrap()
64        }))
65        .await;
66
67        Self {
68            clients: Arc::new(RwLock::new(clients)),
69            object_ref_cache: Arc::new(DashMap::new()),
70            transaction_digests: Arc::new(DashSet::new()),
71            addresses: Arc::new(DashSet::new()),
72            data_dir,
73        }
74    }
75
76    async fn process_command_data(
77        &self,
78        command: &CommandData,
79        signer_info: &Option<SignerInfo>,
80    ) -> Result<()> {
81        match command {
82            CommandData::DryRun(v) => self.process(v, signer_info).await,
83            CommandData::GetCheckpoints(v) => self.process(v, signer_info).await,
84            CommandData::PaySui(v) => self.process(v, signer_info).await,
85            CommandData::QueryTransactionBlocks(v) => self.process(v, signer_info).await,
86            CommandData::MultiGetTransactionBlocks(v) => self.process(v, signer_info).await,
87            CommandData::MultiGetObjects(v) => self.process(v, signer_info).await,
88            CommandData::GetObject(v) => self.process(v, signer_info).await,
89            CommandData::GetAllBalances(v) => self.process(v, signer_info).await,
90            CommandData::GetReferenceGasPrice(v) => self.process(v, signer_info).await,
91        }
92    }
93
94    pub(crate) async fn get_clients(&self) -> Result<Vec<SuiClient>> {
95        let read = self.clients.read().await;
96        Ok(read.clone())
97    }
98
99    /// sign_and_execute transaction and update `object_ref_cache`
100    pub(crate) async fn sign_and_execute(
101        &self,
102        client: &SuiClient,
103        keypair: &SuiKeyPair,
104        txn_data: TransactionData,
105        request_type: ExecuteTransactionRequestType,
106    ) -> SuiTransactionBlockResponse {
107        let resp = sign_and_execute(client, keypair, txn_data, request_type).await;
108        let effects = resp.effects.as_ref().unwrap();
109        let object_ref_cache = self.object_ref_cache.clone();
110        // NOTE: for now we don't need to care about deleted objects
111        for (owned_object_ref, _) in effects.all_changed_objects() {
112            let id = owned_object_ref.object_id();
113            let current = object_ref_cache.get_mut(&id);
114            match current {
115                Some(mut c) => {
116                    if c.1 < owned_object_ref.version() {
117                        *c = owned_object_ref.reference.to_object_ref();
118                    }
119                }
120                None => {
121                    object_ref_cache.insert(id, owned_object_ref.reference.to_object_ref());
122                }
123            };
124        }
125        resp
126    }
127
128    /// get the latest object ref from local cache, and if not exist, fetch from fullnode
129    pub(crate) async fn get_object_ref(
130        &self,
131        client: &SuiClient,
132        object_id: &ObjectID,
133    ) -> ObjectRef {
134        let object_ref_cache = self.object_ref_cache.clone();
135        let current = object_ref_cache.get_mut(object_id);
136        match current {
137            Some(c) => *c,
138            None => {
139                let resp = client
140                    .read_api()
141                    .get_object_with_options(*object_id, SuiObjectDataOptions::new())
142                    .await
143                    .unwrap_or_else(|_| panic!("Unable to fetch object reference {object_id}"));
144                let object_ref = resp.object_ref_if_exists().unwrap_or_else(|| {
145                    panic!("Unable to extract object reference {object_id} from response {resp:?}")
146                });
147                object_ref_cache.insert(*object_id, object_ref);
148                object_ref
149            }
150        }
151    }
152
153    pub(crate) fn add_transaction_digests(&self, digests: Vec<TransactionDigest>) {
154        // extend method requires mutable access to the underlying DashSet, which is not allowed by the Arc
155        for digest in digests {
156            self.transaction_digests.insert(digest);
157        }
158    }
159
160    pub(crate) fn add_addresses_from_response(&self, responses: &[SuiTransactionBlockResponse]) {
161        for response in responses {
162            let transaction = &response.transaction;
163            if let Some(transaction) = transaction {
164                let data = &transaction.data;
165                self.addresses.insert(*data.sender());
166            }
167        }
168    }
169
170    pub(crate) fn add_object_ids_from_response(&self, responses: &[SuiTransactionBlockResponse]) {
171        for response in responses {
172            let effects = &response.effects;
173            if let Some(effects) = effects {
174                let all_changed_objects = effects.all_changed_objects();
175                for (object_ref, _) in all_changed_objects {
176                    self.object_ref_cache
177                        .insert(object_ref.object_id(), object_ref.reference.to_object_ref());
178                }
179            }
180        }
181    }
182
183    pub(crate) fn dump_cache_to_file(&self) {
184        // TODO: be more granular
185        let digests: Vec<TransactionDigest> = self.transaction_digests.iter().map(|x| *x).collect();
186        if !digests.is_empty() {
187            debug!("dumping transaction digests to file {:?}", digests.len());
188            write_data_to_file(
189                &digests,
190                &format!("{}/{}", &self.data_dir, CacheType::TransactionDigest),
191            )
192            .unwrap();
193        }
194
195        let addresses: Vec<SuiAddress> = self.addresses.iter().map(|x| *x).collect();
196        if !addresses.is_empty() {
197            debug!("dumping addresses to file {:?}", addresses.len());
198            write_data_to_file(
199                &addresses,
200                &format!("{}/{}", &self.data_dir, CacheType::SuiAddress),
201            )
202            .unwrap();
203        }
204
205        let mut object_ids: Vec<ObjectID> = Vec::new();
206        let cloned_object_cache = self.object_ref_cache.clone();
207
208        for item in cloned_object_cache.iter() {
209            let object_id = item.key();
210            object_ids.push(*object_id);
211        }
212
213        if !object_ids.is_empty() {
214            debug!("dumping object_ids to file {:?}", object_ids.len());
215            write_data_to_file(
216                &object_ids,
217                &format!("{}/{}", &self.data_dir, CacheType::ObjectID),
218            )
219            .unwrap();
220        }
221    }
222}
223
224#[async_trait]
225impl Processor for RpcCommandProcessor {
226    async fn apply(&self, payload: &Payload) -> Result<()> {
227        let commands = &payload.commands;
228        for command in commands.iter() {
229            let repeat_interval = command.repeat_interval;
230            let repeat_n_times = command.repeat_n_times;
231            for i in 0..=repeat_n_times {
232                let start_time = Instant::now();
233
234                self.process_command_data(&command.data, &payload.signer_info)
235                    .await?;
236
237                let elapsed_time = start_time.elapsed();
238                if elapsed_time < repeat_interval {
239                    let sleep_duration = repeat_interval - elapsed_time;
240                    sleep(sleep_duration).await;
241                }
242                let clients = self.get_clients().await?;
243                let checkpoint_stats = get_latest_checkpoint_stats(&clients, None).await;
244                info!(
245                    "Repeat {i}: Checkpoint stats {checkpoint_stats}, elapse {:.4} since last repeat",
246                    elapsed_time.as_secs_f64()
247                );
248            }
249        }
250        Ok(())
251    }
252
253    async fn prepare(&self, config: &LoadTestConfig) -> Result<Vec<Payload>> {
254        let clients = self.get_clients().await?;
255        let Command {
256            repeat_n_times,
257            repeat_interval,
258            ..
259        } = &config.command;
260        let command_payloads = match &config.command.data {
261            CommandData::GetCheckpoints(data) => {
262                if !config.divide_tasks {
263                    vec![config.command.clone(); config.num_threads]
264                } else {
265                    divide_checkpoint_tasks(&clients, data, config.num_threads).await
266                }
267            }
268            CommandData::QueryTransactionBlocks(data) => {
269                if !config.divide_tasks {
270                    vec![config.command.clone(); config.num_threads]
271                } else {
272                    divide_query_transaction_blocks_tasks(data, config.num_threads).await
273                }
274            }
275            CommandData::MultiGetTransactionBlocks(data) => {
276                if !config.divide_tasks {
277                    vec![config.command.clone(); config.num_threads]
278                } else {
279                    divide_multi_get_transaction_blocks_tasks(data, config.num_threads).await
280                }
281            }
282            CommandData::GetAllBalances(data) => {
283                if !config.divide_tasks {
284                    vec![config.command.clone(); config.num_threads]
285                } else {
286                    divide_get_all_balances_tasks(data, config.num_threads).await
287                }
288            }
289            CommandData::MultiGetObjects(data) => {
290                if !config.divide_tasks {
291                    vec![config.command.clone(); config.num_threads]
292                } else {
293                    divide_multi_get_objects_tasks(data, config.num_threads).await
294                }
295            }
296            CommandData::GetObject(data) => {
297                if !config.divide_tasks {
298                    vec![config.command.clone(); config.num_threads]
299                } else {
300                    divide_get_object_tasks(data, config.num_threads).await
301                }
302            }
303            _ => vec![config.command.clone(); config.num_threads],
304        };
305
306        let command_payloads = command_payloads.into_iter().map(|command| {
307            command
308                .with_repeat_interval(*repeat_interval)
309                .with_repeat_n_times(*repeat_n_times)
310        });
311
312        let coins_and_keys = if config.signer_info.is_some() {
313            Some(
314                prepare_new_signer_and_coins(
315                    clients.first().unwrap(),
316                    config.signer_info.as_ref().unwrap(),
317                    config.num_threads * config.num_chunks_per_thread,
318                    config.max_repeat as u64 + 1,
319                )
320                .await,
321            )
322        } else {
323            None
324        };
325
326        let num_chunks = config.num_chunks_per_thread;
327        Ok(command_payloads
328            .into_iter()
329            .enumerate()
330            .map(|(i, command)| Payload {
331                commands: vec![command], // note commands is also a vector
332                signer_info: coins_and_keys
333                    .as_ref()
334                    .map(|(coins, encoded_keypair)| SignerInfo {
335                        encoded_keypair: encoded_keypair.clone(),
336                        gas_payment: Some(coins[num_chunks * i..(i + 1) * num_chunks].to_vec()),
337                        gas_budget: None,
338                    }),
339            })
340            .collect())
341    }
342
343    fn dump_cache_to_file(&self, config: &LoadTestConfig) {
344        if let CommandData::GetCheckpoints(data) = &config.command.data
345            && data.record
346        {
347            self.dump_cache_to_file();
348        }
349    }
350}
351
352#[async_trait]
353impl<'a> ProcessPayload<'a, &'a DryRun> for RpcCommandProcessor {
354    async fn process(&'a self, _op: &'a DryRun, _signer_info: &Option<SignerInfo>) -> Result<()> {
355        debug!("DryRun");
356        Ok(())
357    }
358}
359
360fn write_data_to_file<T: Serialize>(data: &T, file_path: &str) -> Result<(), anyhow::Error> {
361    let mut path_buf = PathBuf::from(&file_path);
362    path_buf.pop();
363    fs::create_dir_all(&path_buf).map_err(|e| anyhow!("Error creating directory: {}", e))?;
364
365    let file_name = format!("{}.json", file_path);
366    let file = File::create(file_name).map_err(|e| anyhow!("Error creating file: {}", e))?;
367    serde_json::to_writer(file, data).map_err(|e| anyhow!("Error writing to file: {}", e))?;
368
369    Ok(())
370}
371
/// Identifies which cached collection a dump file holds.
///
/// The `Display` form is used as the file-name stem inside the data directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheType {
    SuiAddress,
    TransactionDigest,
    ObjectID,
}

impl fmt::Display for CacheType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            CacheType::SuiAddress => "SuiAddress",
            CacheType::TransactionDigest => "TransactionDigest",
            CacheType::ObjectID => "ObjectID",
        };
        f.write_str(name)
    }
}
387
388// TODO(Will): Consider using enums for input and output? Would mean we need to do checks any time we use generic load_cache_from_file
389pub fn load_addresses_from_file(filepath: String) -> Vec<SuiAddress> {
390    let path = format!("{}/{}", filepath, CacheType::SuiAddress);
391    let addresses: Vec<SuiAddress> = read_data_from_file(&path).expect("Failed to read addresses");
392    addresses
393}
394
395pub fn load_objects_from_file(filepath: String) -> Vec<ObjectID> {
396    let path = format!("{}/{}", filepath, CacheType::ObjectID);
397    let objects: Vec<ObjectID> = read_data_from_file(&path).expect("Failed to read objects");
398    objects
399}
400
401pub fn load_digests_from_file(filepath: String) -> Vec<TransactionDigest> {
402    let path = format!("{}/{}", filepath, CacheType::TransactionDigest);
403    let digests: Vec<TransactionDigest> =
404        read_data_from_file(&path).expect("Failed to read transaction digests");
405    digests
406}
407
408fn read_data_from_file<T: DeserializeOwned>(file_path: &str) -> Result<T, anyhow::Error> {
409    let mut path_buf = PathBuf::from(file_path);
410
411    // Check if the file has a JSON extension
412    if path_buf.extension().is_none_or(|ext| ext != "json") {
413        // If not, add .json to the filename
414        path_buf.set_extension("json");
415    }
416
417    let path = path_buf.as_path();
418    if !path.exists() {
419        return Err(anyhow!("File not found: {}", file_path));
420    }
421
422    let file = File::open(path).map_err(|e| anyhow::anyhow!("Error opening file: {}", e))?;
423    let deserialized_data: T =
424        serde_json::from_reader(file).map_err(|e| anyhow!("Deserialization error: {}", e))?;
425
426    Ok(deserialized_data)
427}
428
429async fn divide_checkpoint_tasks(
430    clients: &[SuiClient],
431    data: &GetCheckpoints,
432    num_chunks: usize,
433) -> Vec<Command> {
434    let start = data.start;
435    let end = match data.end {
436        Some(end) => end,
437        None => {
438            let end_checkpoints = join_all(clients.iter().map(|client| async {
439                client
440                    .read_api()
441                    .get_latest_checkpoint_sequence_number()
442                    .await
443                    .expect("get_latest_checkpoint_sequence_number should not fail")
444            }))
445            .await;
446            *end_checkpoints
447                .iter()
448                .max()
449                .expect("get_latest_checkpoint_sequence_number should not return empty")
450        }
451    };
452
453    let chunk_size = (end - start) / num_chunks as u64;
454    (0..num_chunks)
455        .map(|i| {
456            let start_checkpoint = start + (i as u64) * chunk_size;
457            let end_checkpoint = end.min(start + ((i + 1) as u64) * chunk_size);
458            Command::new_get_checkpoints(
459                start_checkpoint,
460                Some(end_checkpoint),
461                data.verify_transactions,
462                data.verify_objects,
463                data.record,
464            )
465        })
466        .collect()
467}
468
469async fn divide_query_transaction_blocks_tasks(
470    data: &QueryTransactionBlocks,
471    num_chunks: usize,
472) -> Vec<Command> {
473    let chunk_size = if data.addresses.len() < num_chunks {
474        1
475    } else {
476        data.addresses.len() as u64 / num_chunks as u64
477    };
478    let chunked = chunk_entities(data.addresses.as_slice(), Some(chunk_size as usize));
479    chunked
480        .into_iter()
481        .map(|chunk| Command::new_query_transaction_blocks(data.address_type.clone(), chunk))
482        .collect()
483}
484
485async fn divide_multi_get_transaction_blocks_tasks(
486    data: &MultiGetTransactionBlocks,
487    num_chunks: usize,
488) -> Vec<Command> {
489    let chunk_size = if data.digests.len() < num_chunks {
490        1
491    } else {
492        data.digests.len() as u64 / num_chunks as u64
493    };
494    let chunked = chunk_entities(data.digests.as_slice(), Some(chunk_size as usize));
495    chunked
496        .into_iter()
497        .map(Command::new_multi_get_transaction_blocks)
498        .collect()
499}
500
501async fn divide_get_all_balances_tasks(data: &GetAllBalances, num_threads: usize) -> Vec<Command> {
502    let per_thread_size = if data.addresses.len() < num_threads {
503        1
504    } else {
505        data.addresses.len() / num_threads
506    };
507
508    let chunked = chunk_entities(data.addresses.as_slice(), Some(per_thread_size));
509    chunked
510        .into_iter()
511        .map(|chunk| Command::new_get_all_balances(chunk, data.chunk_size))
512        .collect()
513}
514
515// TODO: probs can do generic divide tasks
516async fn divide_multi_get_objects_tasks(data: &MultiGetObjects, num_chunks: usize) -> Vec<Command> {
517    let chunk_size = if data.object_ids.len() < num_chunks {
518        1
519    } else {
520        data.object_ids.len() as u64 / num_chunks as u64
521    };
522    let chunked = chunk_entities(data.object_ids.as_slice(), Some(chunk_size as usize));
523    chunked
524        .into_iter()
525        .map(Command::new_multi_get_objects)
526        .collect()
527}
528
529async fn divide_get_object_tasks(data: &GetObject, num_threads: usize) -> Vec<Command> {
530    let per_thread_size = if data.object_ids.len() < num_threads {
531        1
532    } else {
533        data.object_ids.len() / num_threads
534    };
535
536    let chunked = chunk_entities(data.object_ids.as_slice(), Some(per_thread_size));
537    chunked
538        .into_iter()
539        .map(|chunk| Command::new_get_object(chunk, data.chunk_size))
540        .collect()
541}
542
543async fn prepare_new_signer_and_coins(
544    client: &SuiClient,
545    signer_info: &SignerInfo,
546    num_coins: usize,
547    num_transactions_per_coin: u64,
548) -> (Vec<ObjectID>, String) {
549    // TODO(chris): consider reference gas price
550    let amount_per_coin = num_transactions_per_coin * DEFAULT_GAS_BUDGET;
551    let pay_amount = amount_per_coin * num_coins as u64;
552    let num_split_txns =
553        num_transactions_needed(num_coins, MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION);
554    let (gas_fee_for_split, gas_fee_for_pay_sui) = (
555        DEFAULT_LARGE_GAS_BUDGET * num_split_txns as u64,
556        DEFAULT_GAS_BUDGET,
557    );
558
559    let primary_keypair = SuiKeyPair::decode_base64(&signer_info.encoded_keypair)
560        .expect("Decoding keypair should not fail");
561    let sender = SuiAddress::from(&primary_keypair.public());
562    let (coin, balance) = get_coin_with_max_balance(client, sender).await;
563    // The balance needs to cover `pay_amount` plus
564    // 1. gas fee for pay_sui from the primary address to the burner address
565    // 2. gas fee for splitting the primary coin into `num_coins`
566    let required_balance = pay_amount + gas_fee_for_split + gas_fee_for_pay_sui;
567    if required_balance > balance {
568        panic!(
569            "Current balance {balance} is smaller than require amount of MIST to fund the operation {required_balance}"
570        );
571    }
572
573    // There is a limit for the number of new objects in a transactions, therefore we need
574    // multiple split transactions if the `num_coins` is large
575    let split_amounts = calculate_split_amounts(
576        num_coins,
577        amount_per_coin,
578        MAX_NUM_NEW_OBJECTS_IN_SINGLE_TRANSACTION,
579    );
580
581    debug!("split_amounts {split_amounts:?}");
582
583    // We don't want to split coins in our primary address because we want to avoid having
584    // a million coin objects in our address. We can also fetch directly from the faucet, but in
585    // some environment that might not be possible when faucet resource is scarce
586    let (burner_address, burner_keypair): (_, AccountKeyPair) = get_key_pair();
587    let burner_keypair = SuiKeyPair::Ed25519(burner_keypair);
588    let pay_amounts = split_amounts
589        .iter()
590        .map(|(amount, _)| *amount)
591        .chain(std::iter::once(gas_fee_for_split))
592        .collect::<Vec<_>>();
593
594    debug!("pay_amounts {pay_amounts:?}");
595
596    pay_sui(
597        client,
598        &primary_keypair,
599        vec![coin],
600        DEFAULT_GAS_BUDGET,
601        vec![burner_address; pay_amounts.len()],
602        pay_amounts,
603    )
604    .await;
605
606    let coins = get_sui_coin_ids(client, burner_address).await;
607    let gas_coin_id = get_coin_with_balance(&coins, gas_fee_for_split);
608    let primary_coin = get_coin_with_balance(&coins, split_amounts[0].0);
609    assert!(!coins.is_empty());
610    let mut results: Vec<ObjectID> = vec![];
611    assert!(!split_amounts.is_empty());
612    if split_amounts.len() == 1 && split_amounts[0].1 == 0 {
613        results.push(get_coin_with_balance(&coins, split_amounts[0].0));
614    } else if split_amounts.len() == 1 {
615        results.extend(
616            split_coins(
617                client,
618                &burner_keypair,
619                primary_coin,
620                gas_coin_id,
621                split_amounts[0].1 as u64,
622            )
623            .await,
624        );
625    } else {
626        let (max_amount, max_split) = &split_amounts[0];
627        let (remainder_amount, remainder_split) = split_amounts.last().unwrap();
628        let primary_coins = coins
629            .iter()
630            .filter(|(_, balance)| balance == max_amount)
631            .map(|(id, _)| (*id, *max_split as u64))
632            .chain(
633                coins
634                    .iter()
635                    .filter(|(_, balance)| balance == remainder_amount)
636                    .map(|(id, _)| (*id, *remainder_split as u64)),
637            )
638            .collect::<Vec<_>>();
639
640        for (coin_id, splits) in primary_coins {
641            results
642                .extend(split_coins(client, &burner_keypair, coin_id, gas_coin_id, splits).await);
643        }
644    }
645    assert_eq!(results.len(), num_coins);
646    debug!("Split off {} coins for gas payment {results:?}", num_coins);
647    (results, burner_keypair.encode_base64())
648}
649
/// Returns how many split transactions are needed to produce `num_coins` coins when a single
/// transaction may create at most `new_coins_per_txn` of them.
///
/// One coin needs no split, so `num_coins == 1` yields 0. `num_coins == 0` also yields 0.
///
/// # Panics
/// Panics if `new_coins_per_txn` is 0.
fn num_transactions_needed(num_coins: usize, new_coins_per_txn: usize) -> usize {
    assert!(new_coins_per_txn > 0);
    match num_coins {
        1 => 0,
        n => n.div_ceil(new_coins_per_txn),
    }
}

/// Plans how to fund and split `num_coins` coins of `amount_per_coin` each, when at most
/// `max_coins_per_txn` coins can come out of one split transaction.
///
/// Returns one `(primary_coin_amount, split_into_n_coins)` pair per split transaction:
/// - all but the last pair are full `(max_coins_per_txn * amount_per_coin, max_coins_per_txn)`;
/// - the last pair covers whatever coins remain.
///
/// When no split is needed, the result is the single pair `(total_amount, 0)`.
///
/// # Panics
/// Panics if `max_coins_per_txn` is 0.
fn calculate_split_amounts(
    num_coins: usize,
    amount_per_coin: u64,
    max_coins_per_txn: usize,
) -> Vec<(u64, usize)> {
    let total_amount = amount_per_coin * num_coins as u64;
    let txn_count = num_transactions_needed(num_coins, max_coins_per_txn);
    if txn_count == 0 {
        return vec![(total_amount, 0)];
    }

    let full_txns = txn_count - 1;
    let full_txn_amount = max_coins_per_txn as u64 * amount_per_coin;
    let leftover_amount = total_amount - full_txn_amount * full_txns as u64;
    let leftover_coins = num_coins - max_coins_per_txn * full_txns;

    (0..full_txns)
        .map(|_| (full_txn_amount, max_coins_per_txn))
        .chain(std::iter::once((leftover_amount, leftover_coins)))
        .collect()
}
684
685async fn get_coin_with_max_balance(client: &SuiClient, address: SuiAddress) -> (ObjectID, u64) {
686    let coins = get_sui_coin_ids(client, address).await;
687    assert!(!coins.is_empty());
688    coins.into_iter().max_by(|a, b| a.1.cmp(&b.1)).unwrap()
689}
690
691fn get_coin_with_balance(coins: &[(ObjectID, u64)], target: u64) -> ObjectID {
692    coins.iter().find(|(_, b)| b == &target).unwrap().0
693}
694
695// TODO: move this to the Rust SDK
696async fn get_sui_coin_ids(client: &SuiClient, address: SuiAddress) -> Vec<(ObjectID, u64)> {
697    match client
698        .coin_read_api()
699        .get_coins(address, None, None, None)
700        .await
701    {
702        Ok(page) => page
703            .data
704            .into_iter()
705            .map(|c| (c.coin_object_id, c.balance))
706            .collect::<Vec<_>>(),
707        Err(e) => {
708            panic!("get_sui_coin_ids error for address {address} {e}")
709        }
710    }
711    // TODO: implement iteration over next page
712}
713
714async fn pay_sui(
715    client: &SuiClient,
716    keypair: &SuiKeyPair,
717    input_coins: Vec<ObjectID>,
718    gas_budget: u64,
719    recipients: Vec<SuiAddress>,
720    amounts: Vec<u64>,
721) -> SuiTransactionBlockResponse {
722    let sender = SuiAddress::from(&keypair.public());
723    let tx = client
724        .transaction_builder()
725        .pay(sender, input_coins, recipients, amounts, None, gas_budget)
726        .await
727        .expect("Failed to construct pay sui transaction");
728    sign_and_execute(
729        client,
730        keypair,
731        tx,
732        ExecuteTransactionRequestType::WaitForLocalExecution,
733    )
734    .await
735}
736
737async fn split_coins(
738    client: &SuiClient,
739    keypair: &SuiKeyPair,
740    coin_to_split: ObjectID,
741    gas_payment: ObjectID,
742    num_coins: u64,
743) -> Vec<ObjectID> {
744    let sender = SuiAddress::from(&keypair.public());
745    let split_coin_tx = client
746        .transaction_builder()
747        .split_coin_equal(
748            sender,
749            coin_to_split,
750            num_coins,
751            Some(gas_payment),
752            DEFAULT_LARGE_GAS_BUDGET,
753        )
754        .await
755        .expect("Failed to construct split coin transaction");
756    sign_and_execute(
757        client,
758        keypair,
759        split_coin_tx,
760        ExecuteTransactionRequestType::WaitForLocalExecution,
761    )
762    .await
763    .effects
764    .unwrap()
765    .created()
766    .iter()
767    .map(|owned_object_ref| owned_object_ref.reference.object_id)
768    .chain(std::iter::once(coin_to_split))
769    .collect::<Vec<_>>()
770}
771
772pub(crate) async fn sign_and_execute(
773    client: &SuiClient,
774    keypair: &SuiKeyPair,
775    txn_data: TransactionData,
776    request_type: ExecuteTransactionRequestType,
777) -> SuiTransactionBlockResponse {
778    let signature = Signature::new_secure(
779        &IntentMessage::new(Intent::sui_transaction(), &txn_data),
780        keypair,
781    );
782
783    let transaction_response = match client
784        .quorum_driver_api()
785        .execute_transaction_block(
786            Transaction::from_data(txn_data, vec![signature]),
787            SuiTransactionBlockResponseOptions::new().with_effects(),
788            Some(request_type),
789        )
790        .await
791    {
792        Ok(response) => response,
793        Err(e) => {
794            panic!("sign_and_execute error {e}")
795        }
796    };
797
798    match &transaction_response.effects {
799        Some(effects) => {
800            if let SuiExecutionStatus::Failure { error } = effects.status() {
801                panic!(
802                    "Transaction {} failed with error: {}. Transaction Response: {:?}",
803                    transaction_response.digest, error, &transaction_response
804                );
805            }
806        }
807        None => {
808            panic!(
809                "Transaction {} has no effects. Response {:?}",
810                transaction_response.digest, &transaction_response
811            );
812        }
813    };
814    transaction_response
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_num_transactions_needed() {
        // A single coin never needs a split transaction.
        assert_eq!(num_transactions_needed(1, 5), 0);
        assert_eq!(num_transactions_needed(0, 5), 0);
        assert_eq!(num_transactions_needed(5, 5), 1);
        assert_eq!(num_transactions_needed(6, 5), 2);
        assert_eq!(num_transactions_needed(12, 5), 3);
    }

    #[test]
    fn test_calculate_split_amounts_no_split_needed() {
        let num_coins = 10;
        let amount_per_coin = 100;
        let max_coins_per_txn = 20;
        let expected = vec![(1000, 10)];
        let result = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);

        assert_eq!(expected, result);
    }

    #[test]
    fn test_calculate_split_amounts_exact_split() {
        let num_coins = 10;
        let amount_per_coin = 100;
        let max_coins_per_txn = 5;
        let expected = vec![(500, 5), (500, 5)];
        let result = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);

        assert_eq!(expected, result);
    }

    #[test]
    fn test_calculate_split_amounts_with_remainder() {
        let num_coins = 12;
        let amount_per_coin = 100;
        let max_coins_per_txn = 5;
        let expected = vec![(500, 5), (500, 5), (200, 2)];
        let result = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);

        assert_eq!(expected, result);
    }

    #[test]
    fn test_calculate_split_amounts_single_coin() {
        let num_coins = 1;
        let amount_per_coin = 100;
        let max_coins_per_txn = 5;
        let expected = vec![(100, 0)];
        let result = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);

        assert_eq!(expected, result);
    }

    #[test]
    fn test_calculate_split_amounts_max_coins_equals_num_coins() {
        let num_coins = 5;
        let amount_per_coin = 100;
        let max_coins_per_txn = 5;
        let expected = vec![(500, 5)];
        let result = calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);

        assert_eq!(expected, result);
    }

    #[test]
    #[should_panic(expected = "assertion failed: new_coins_per_txn > 0")]
    fn test_calculate_split_amounts_zero_max_coins() {
        let num_coins = 5;
        let amount_per_coin = 100;
        let max_coins_per_txn = 0;

        calculate_split_amounts(num_coins, amount_per_coin, max_coins_per_txn);
    }
}