1use crate::SuiClient;
5use crate::sui_client_config::{SuiClientConfig, SuiEnv};
6use anyhow::{anyhow, ensure};
7use futures::future;
8use shared_crypto::intent::Intent;
9use std::collections::BTreeSet;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use sui_config::{Config, PersistedConfig};
13use sui_json_rpc_types::{
14    SuiObjectData, SuiObjectDataFilter, SuiObjectDataOptions, SuiObjectResponse,
15    SuiObjectResponseQuery, SuiTransactionBlockResponse, SuiTransactionBlockResponseOptions,
16};
17use sui_keys::key_identity::KeyIdentity;
18use sui_keys::keystore::{AccountKeystore, Keystore};
19use sui_types::base_types::{FullObjectRef, ObjectID, ObjectRef, SuiAddress};
20use sui_types::crypto::{Signature, SuiKeyPair};
21
22use sui_types::gas_coin::GasCoin;
23use sui_types::transaction::{Transaction, TransactionData, TransactionDataAPI};
24use tokio::sync::RwLock;
25use tracing::info;
26
/// Client-side wallet state: the persisted client configuration (keystores,
/// environments, active address) plus a lazily created RPC client and the
/// options used to build it.
pub struct WalletContext {
    /// Client configuration together with the path it is persisted to.
    pub config: PersistedConfig<SuiClientConfig>,
    /// Request timeout forwarded to `create_rpc_client` when the client is built.
    request_timeout: Option<std::time::Duration>,
    /// Cached RPC client; `None` until the first successful `get_client` call.
    client: Arc<RwLock<Option<SuiClient>>>,
    /// Concurrency limit forwarded to `create_rpc_client` when the client is built.
    max_concurrent_requests: Option<u64>,
    /// When set, `get_active_env` looks up this env alias instead of the
    /// config's active environment.
    env_override: Option<String>,
}
34
35impl WalletContext {
36    pub fn new(config_path: &Path) -> Result<Self, anyhow::Error> {
37        let config: SuiClientConfig = PersistedConfig::read(config_path).map_err(|err| {
38            anyhow!(
39                "Cannot open wallet config file at {:?}. Err: {err}",
40                config_path
41            )
42        })?;
43
44        let config = config.persisted(config_path);
45        let context = Self {
46            config,
47            request_timeout: None,
48            client: Default::default(),
49            max_concurrent_requests: None,
50            env_override: None,
51        };
52        Ok(context)
53    }
54
55    pub fn new_for_tests(
56        keystore: Keystore,
57        external: Option<Keystore>,
58        path: Option<PathBuf>,
59    ) -> Self {
60        let mut config = SuiClientConfig::new(keystore)
61            .persisted(&path.unwrap_or(PathBuf::from("test_config.yaml")));
62        config.external_keys = external;
63        Self {
64            config,
65            request_timeout: None,
66            client: Arc::new(Default::default()),
67            max_concurrent_requests: None,
68            env_override: None,
69        }
70    }
71
72    pub fn with_request_timeout(mut self, request_timeout: std::time::Duration) -> Self {
73        self.request_timeout = Some(request_timeout);
74        self
75    }
76
77    pub fn with_max_concurrent_requests(mut self, max_concurrent_requests: u64) -> Self {
78        self.max_concurrent_requests = Some(max_concurrent_requests);
79        self
80    }
81
82    pub fn with_env_override(mut self, env_override: String) -> Self {
83        self.env_override = Some(env_override);
84        self
85    }
86
87    pub fn get_addresses(&self) -> Vec<SuiAddress> {
88        self.config.keystore.addresses()
89    }
90
91    pub fn get_env_override(&self) -> Option<String> {
92        self.env_override.clone()
93    }
94
95    pub fn get_identity_address(
96        &mut self,
97        input: Option<KeyIdentity>,
98    ) -> Result<SuiAddress, anyhow::Error> {
99        if let Some(key_identity) = input {
100            if let Ok(address) = self.config.keystore.get_by_identity(&key_identity) {
101                return Ok(address);
102            }
103            if let Some(address) = self
104                .config
105                .external_keys
106                .as_ref()
107                .and_then(|external_keys| external_keys.get_by_identity(&key_identity).ok())
108            {
109                return Ok(address);
110            }
111
112            Err(anyhow!(
113                "No address found for the provided key identity: {key_identity}"
114            ))
115        } else {
116            self.active_address()
117        }
118    }
119
120    pub async fn get_client(&self) -> Result<SuiClient, anyhow::Error> {
121        let read = self.client.read().await;
122
123        Ok(if let Some(client) = read.as_ref() {
124            client.clone()
125        } else {
126            drop(read);
127            let client = self
128                .get_active_env()?
129                .create_rpc_client(self.request_timeout, self.max_concurrent_requests)
130                .await?;
131
132            self.client.write().await.insert(client).clone()
133        })
134    }
135
136    pub async fn load_or_cache_chain_id(
141        &self,
142        client: &SuiClient,
143    ) -> Result<String, anyhow::Error> {
144        self.internal_load_or_cache_chain_id(client, false).await
145    }
146
147    pub async fn cache_chain_id(&self, client: &SuiClient) -> Result<String, anyhow::Error> {
150        self.internal_load_or_cache_chain_id(client, true).await
151    }
152
153    async fn internal_load_or_cache_chain_id(
154        &self,
155        client: &SuiClient,
156        force_recache: bool,
157    ) -> Result<String, anyhow::Error> {
158        let env = self.get_active_env()?;
159        if !force_recache && env.chain_id.is_some() {
160            let chain_id = env.chain_id.as_ref().unwrap();
161            info!("Found cached chain ID for env {}: {}", env.alias, chain_id);
162            return Ok(chain_id.clone());
163        }
164        let chain_id = client.read_api().get_chain_identifier().await?;
165        let path = self.config.path();
166        let mut config_result = SuiClientConfig::load_with_lock(path)?;
167
168        config_result.update_env_chain_id(&env.alias, chain_id.clone())?;
169        config_result.save_with_lock(path)?;
170        Ok(chain_id)
171    }
172
173    pub fn get_active_env(&self) -> Result<&SuiEnv, anyhow::Error> {
174        if self.env_override.is_some() {
175            self.config.get_env(&self.env_override).ok_or_else(|| {
176                anyhow!(
177                    "Environment configuration not found for env [{}]",
178                    self.env_override.as_deref().unwrap_or("None")
179                )
180            })
181        } else {
182            self.config.get_active_env()
183        }
184    }
185
186    pub fn active_address(&mut self) -> Result<SuiAddress, anyhow::Error> {
188        if self.config.keystore.entries().is_empty() {
189            return Err(anyhow!(
190                "No managed addresses. Create new address with `new-address` command."
191            ));
192        }
193
194        self.config.active_address = Some(
197            self.config
198                .active_address
199                .unwrap_or(*self.config.keystore.addresses().first().unwrap()),
200        );
201
202        Ok(self.config.active_address.unwrap())
203    }
204
205    pub async fn get_object_ref(&self, object_id: ObjectID) -> Result<ObjectRef, anyhow::Error> {
207        let client = self.get_client().await?;
208        Ok(client
209            .read_api()
210            .get_object_with_options(object_id, SuiObjectDataOptions::new())
211            .await?
212            .into_object()?
213            .object_ref())
214    }
215
216    pub async fn get_full_object_ref(
218        &self,
219        object_id: ObjectID,
220    ) -> Result<FullObjectRef, anyhow::Error> {
221        let client = self.get_client().await?;
222        let object = client
223            .read_api()
224            .get_object_with_options(object_id, SuiObjectDataOptions::new().with_owner())
225            .await?
226            .into_object()?;
227        let object_ref = object.object_ref();
228        let owner = object
229            .owner
230            .expect("Owner should be present if `with_owner` is set");
231        Ok(FullObjectRef::from_object_ref_and_owner(object_ref, &owner))
232    }
233
234    pub async fn gas_objects(
236        &self,
237        address: SuiAddress,
238    ) -> Result<Vec<(u64, SuiObjectData)>, anyhow::Error> {
239        let client = self.get_client().await?;
240
241        let mut objects: Vec<SuiObjectResponse> = Vec::new();
242        let mut cursor = None;
243        loop {
244            let response = client
245                .read_api()
246                .get_owned_objects(
247                    address,
248                    Some(SuiObjectResponseQuery::new(
249                        Some(SuiObjectDataFilter::StructType(GasCoin::type_())),
250                        Some(SuiObjectDataOptions::full_content()),
251                    )),
252                    cursor,
253                    None,
254                )
255                .await?;
256
257            objects.extend(response.data);
258
259            if response.has_next_page {
260                cursor = response.next_cursor;
261            } else {
262                break;
263            }
264        }
265
266        let mut values_objects = Vec::new();
268
269        for object in objects {
270            let o = object.data;
271            if let Some(o) = o {
272                let gas_coin = GasCoin::try_from(&o)?;
273                values_objects.push((gas_coin.value(), o.clone()));
274            }
275        }
276
277        Ok(values_objects)
278    }
279
280    pub async fn get_object_owner(&self, id: &ObjectID) -> Result<SuiAddress, anyhow::Error> {
281        let client = self.get_client().await?;
282        let object = client
283            .read_api()
284            .get_object_with_options(*id, SuiObjectDataOptions::new().with_owner())
285            .await?
286            .into_object()?;
287        Ok(object
288            .owner
289            .ok_or_else(|| anyhow!("Owner field is None"))?
290            .get_owner_address()?)
291    }
292
293    pub async fn try_get_object_owner(
294        &self,
295        id: &Option<ObjectID>,
296    ) -> Result<Option<SuiAddress>, anyhow::Error> {
297        if let Some(id) = id {
298            Ok(Some(self.get_object_owner(id).await?))
299        } else {
300            Ok(None)
301        }
302    }
303
304    pub async fn infer_sender(&mut self, gas: &[ObjectID]) -> Result<SuiAddress, anyhow::Error> {
307        if gas.is_empty() {
308            return self.active_address();
309        }
310
311        let owners = future::try_join_all(gas.iter().map(|id| self.get_object_owner(id))).await?;
313
314        let owner = owners.first().copied().unwrap();
316
317        ensure!(
318            owners.iter().all(|o| o == &owner),
319            "Cannot infer sender, not all gas objects have the same owner."
320        );
321
322        Ok(owner)
323    }
324
325    pub async fn gas_for_owner_budget(
327        &self,
328        address: SuiAddress,
329        budget: u64,
330        forbidden_gas_objects: BTreeSet<ObjectID>,
331    ) -> Result<(u64, SuiObjectData), anyhow::Error> {
332        for o in self.gas_objects(address).await? {
333            if o.0 >= budget && !forbidden_gas_objects.contains(&o.1.object_id) {
334                return Ok((o.0, o.1));
335            }
336        }
337        Err(anyhow!(
338            "No non-argument gas objects found for this address with value >= budget {budget}. Run sui client gas to check for gas objects."
339        ))
340    }
341
342    pub async fn get_all_gas_objects_owned_by_address(
343        &self,
344        address: SuiAddress,
345    ) -> anyhow::Result<Vec<ObjectRef>> {
346        self.get_gas_objects_owned_by_address(address, None).await
347    }
348
349    pub async fn get_gas_objects_owned_by_address(
350        &self,
351        address: SuiAddress,
352        limit: Option<usize>,
353    ) -> anyhow::Result<Vec<ObjectRef>> {
354        let client = self.get_client().await?;
355        let results: Vec<_> = client
356            .read_api()
357            .get_owned_objects(
358                address,
359                Some(SuiObjectResponseQuery::new(
360                    Some(SuiObjectDataFilter::StructType(GasCoin::type_())),
361                    Some(SuiObjectDataOptions::full_content()),
362                )),
363                None,
364                limit,
365            )
366            .await?
367            .data
368            .into_iter()
369            .filter_map(|r| r.data.map(|o| o.object_ref()))
370            .collect();
371        Ok(results)
372    }
373
374    pub async fn get_one_gas_object_owned_by_address(
377        &self,
378        address: SuiAddress,
379    ) -> anyhow::Result<Option<ObjectRef>> {
380        Ok(self
381            .get_gas_objects_owned_by_address(address, Some(1))
382            .await?
383            .pop())
384    }
385
386    pub async fn get_one_account(&self) -> anyhow::Result<(SuiAddress, Vec<ObjectRef>)> {
388        let address = self.get_addresses().pop().unwrap();
389        Ok((
390            address,
391            self.get_all_gas_objects_owned_by_address(address).await?,
392        ))
393    }
394
395    pub async fn get_one_gas_object(&self) -> anyhow::Result<Option<(SuiAddress, ObjectRef)>> {
397        for address in self.get_addresses() {
398            if let Some(gas_object) = self.get_one_gas_object_owned_by_address(address).await? {
399                return Ok(Some((address, gas_object)));
400            }
401        }
402        Ok(None)
403    }
404
405    pub async fn get_all_accounts_and_gas_objects(
407        &self,
408    ) -> anyhow::Result<Vec<(SuiAddress, Vec<ObjectRef>)>> {
409        let mut result = vec![];
410        for address in self.get_addresses() {
411            let objects = self
412                .gas_objects(address)
413                .await?
414                .into_iter()
415                .map(|(_, o)| o.object_ref())
416                .collect();
417            result.push((address, objects));
418        }
419        Ok(result)
420    }
421
422    pub async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
423        let client = self.get_client().await?;
424        let gas_price = client.governance_api().get_reference_gas_price().await?;
425        Ok(gas_price)
426    }
427
428    pub async fn add_account(&mut self, alias: Option<String>, keypair: SuiKeyPair) {
430        self.config.keystore.import(alias, keypair).await.unwrap();
431    }
432
433    pub fn get_keystore_by_identity(
434        &self,
435        key_identity: &KeyIdentity,
436    ) -> Result<&Keystore, anyhow::Error> {
437        if self.config.keystore.get_by_identity(key_identity).is_ok() {
438            return Ok(&self.config.keystore);
439        }
440
441        if let Some(external_keys) = self.config.external_keys.as_ref()
442            && external_keys.get_by_identity(key_identity).is_ok()
443        {
444            return Ok(external_keys);
445        }
446
447        Err(anyhow!(
448            "No keystore found for the provided key identity: {key_identity}"
449        ))
450    }
451
452    pub fn get_keystore_by_identity_mut(
453        &mut self,
454        key_identity: &KeyIdentity,
455    ) -> Result<&mut Keystore, anyhow::Error> {
456        if self.config.keystore.get_by_identity(key_identity).is_ok() {
457            return Ok(&mut self.config.keystore);
458        }
459
460        if let Some(external_keys) = self.config.external_keys.as_mut()
461            && external_keys.get_by_identity(key_identity).is_ok()
462        {
463            return Ok(external_keys);
464        }
465
466        Err(anyhow!(
467            "No keystore found for the provided key identity: {key_identity}"
468        ))
469    }
470
471    pub async fn sign_secure(
472        &self,
473        key_identity: &KeyIdentity,
474        data: &TransactionData,
475        intent: Intent,
476    ) -> Result<Signature, anyhow::Error> {
477        let keystore = self.get_keystore_by_identity(key_identity)?;
478        let sig = keystore.sign_secure(&data.sender(), data, intent).await?;
479        Ok(sig)
480    }
481
482    pub async fn sign_transaction(&self, data: &TransactionData) -> Transaction {
484        let sig = self
485            .config
486            .keystore
487            .sign_secure(&data.sender(), data, Intent::sui_transaction())
488            .await
489            .unwrap();
490        Transaction::from_data(data.clone(), vec![sig])
492    }
493
494    pub async fn execute_transaction_must_succeed(
497        &self,
498        tx: Transaction,
499    ) -> SuiTransactionBlockResponse {
500        tracing::debug!("Executing transaction: {:?}", tx);
501        let response = self.execute_transaction_may_fail(tx).await.unwrap();
502        assert!(
503            response.status_ok().unwrap(),
504            "Transaction failed: {:?}",
505            response
506        );
507        response
508    }
509
510    pub async fn execute_transaction_may_fail(
514        &self,
515        tx: Transaction,
516    ) -> anyhow::Result<SuiTransactionBlockResponse> {
517        let client = self.get_client().await?;
518        Ok(client
519            .quorum_driver_api()
520            .execute_transaction_block(
521                tx,
522                SuiTransactionBlockResponseOptions::new()
523                    .with_effects()
524                    .with_input()
525                    .with_events()
526                    .with_object_changes()
527                    .with_balance_changes(),
528                Some(sui_types::quorum_driver_types::ExecuteTransactionRequestType::WaitForLocalExecution),
529            )
530            .await?)
531    }
532}