sui_sdk/
wallet_context.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
3
4use crate::SuiClient;
5use crate::sui_client_config::{SuiClientConfig, SuiEnv};
6use anyhow::{anyhow, ensure};
7use futures::future;
8use shared_crypto::intent::Intent;
9use std::collections::BTreeSet;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use sui_config::{Config, PersistedConfig};
13use sui_json_rpc_types::{
14    SuiObjectData, SuiObjectDataFilter, SuiObjectDataOptions, SuiObjectResponse,
15    SuiObjectResponseQuery, SuiTransactionBlockResponse, SuiTransactionBlockResponseOptions,
16};
17use sui_keys::key_identity::KeyIdentity;
18use sui_keys::keystore::{AccountKeystore, Keystore};
19use sui_types::base_types::{FullObjectRef, ObjectID, ObjectRef, SuiAddress};
20use sui_types::crypto::{Signature, SuiKeyPair};
21
22use sui_types::gas_coin::GasCoin;
23use sui_types::transaction::{Transaction, TransactionData, TransactionDataAPI};
24use tokio::sync::RwLock;
25use tracing::info;
26
27pub struct WalletContext {
28    pub config: PersistedConfig<SuiClientConfig>,
29    request_timeout: Option<std::time::Duration>,
30    client: Arc<RwLock<Option<SuiClient>>>,
31    max_concurrent_requests: Option<u64>,
32    env_override: Option<String>,
33}
34
35impl WalletContext {
36    pub fn new(config_path: &Path) -> Result<Self, anyhow::Error> {
37        let config: SuiClientConfig = PersistedConfig::read(config_path).map_err(|err| {
38            anyhow!(
39                "Cannot open wallet config file at {:?}. Err: {err}",
40                config_path
41            )
42        })?;
43
44        let config = config.persisted(config_path);
45        let context = Self {
46            config,
47            request_timeout: None,
48            client: Default::default(),
49            max_concurrent_requests: None,
50            env_override: None,
51        };
52        Ok(context)
53    }
54
55    pub fn new_for_tests(
56        keystore: Keystore,
57        external: Option<Keystore>,
58        path: Option<PathBuf>,
59    ) -> Self {
60        let mut config = SuiClientConfig::new(keystore)
61            .persisted(&path.unwrap_or(PathBuf::from("test_config.yaml")));
62        config.external_keys = external;
63        Self {
64            config,
65            request_timeout: None,
66            client: Arc::new(Default::default()),
67            max_concurrent_requests: None,
68            env_override: None,
69        }
70    }
71
72    pub fn with_request_timeout(mut self, request_timeout: std::time::Duration) -> Self {
73        self.request_timeout = Some(request_timeout);
74        self
75    }
76
77    pub fn with_max_concurrent_requests(mut self, max_concurrent_requests: u64) -> Self {
78        self.max_concurrent_requests = Some(max_concurrent_requests);
79        self
80    }
81
82    pub fn with_env_override(mut self, env_override: String) -> Self {
83        self.env_override = Some(env_override);
84        self
85    }
86
87    pub fn get_addresses(&self) -> Vec<SuiAddress> {
88        self.config.keystore.addresses()
89    }
90
91    pub fn get_env_override(&self) -> Option<String> {
92        self.env_override.clone()
93    }
94
95    pub fn get_identity_address(
96        &mut self,
97        input: Option<KeyIdentity>,
98    ) -> Result<SuiAddress, anyhow::Error> {
99        if let Some(key_identity) = input {
100            if let Ok(address) = self.config.keystore.get_by_identity(&key_identity) {
101                return Ok(address);
102            }
103            if let Some(address) = self
104                .config
105                .external_keys
106                .as_ref()
107                .and_then(|external_keys| external_keys.get_by_identity(&key_identity).ok())
108            {
109                return Ok(address);
110            }
111
112            Err(anyhow!(
113                "No address found for the provided key identity: {key_identity}"
114            ))
115        } else {
116            self.active_address()
117        }
118    }
119
120    pub async fn get_client(&self) -> Result<SuiClient, anyhow::Error> {
121        let read = self.client.read().await;
122
123        Ok(if let Some(client) = read.as_ref() {
124            client.clone()
125        } else {
126            drop(read);
127            let client = self
128                .get_active_env()?
129                .create_rpc_client(self.request_timeout, self.max_concurrent_requests)
130                .await?;
131
132            self.client.write().await.insert(client).clone()
133        })
134    }
135
136    /// Load the chain ID corresponding to the active environment, or fetch and cache it if not
137    /// present.
138    ///
139    /// The chain ID is cached in the `client.yaml` file to avoid redundant network requests.
140    pub async fn load_or_cache_chain_id(
141        &self,
142        client: &SuiClient,
143    ) -> Result<String, anyhow::Error> {
144        self.internal_load_or_cache_chain_id(client, false).await
145    }
146
147    /// Try to load the cached chain ID for the active environment.
148    pub async fn try_load_chain_id_from_cache(
149        &self,
150        env: Option<String>,
151    ) -> Result<String, anyhow::Error> {
152        let env = if let Some(env) = env {
153            self.config
154                .get_env(&Some(env.to_string()))
155                .ok_or_else(|| anyhow!("Environment configuration not found for env [{}]", env))?
156        } else {
157            self.get_active_env()?
158        };
159        if let Some(chain_id) = &env.chain_id {
160            Ok(chain_id.clone())
161        } else {
162            Err(anyhow!(
163                "No cached chain ID found for env {}. Please pass `-e env_name` to your command",
164                env.alias
165            ))
166        }
167    }
168
169    /// Cache (or recache) chain ID for the active environment by fetching it from the
170    /// network
171    pub async fn cache_chain_id(&self, client: &SuiClient) -> Result<String, anyhow::Error> {
172        self.internal_load_or_cache_chain_id(client, true).await
173    }
174
175    async fn internal_load_or_cache_chain_id(
176        &self,
177        client: &SuiClient,
178        force_recache: bool,
179    ) -> Result<String, anyhow::Error> {
180        let env = self.get_active_env()?;
181        if !force_recache && env.chain_id.is_some() {
182            let chain_id = env.chain_id.as_ref().unwrap();
183            info!("Found cached chain ID for env {}: {}", env.alias, chain_id);
184            return Ok(chain_id.clone());
185        }
186        let chain_id = client.read_api().get_chain_identifier().await?;
187        let path = self.config.path();
188        let mut config_result = SuiClientConfig::load_with_lock(path)?;
189
190        config_result.update_env_chain_id(&env.alias, chain_id.clone())?;
191        config_result.save_with_lock(path)?;
192        Ok(chain_id)
193    }
194
195    pub fn get_active_env(&self) -> Result<&SuiEnv, anyhow::Error> {
196        if self.env_override.is_some() {
197            self.config.get_env(&self.env_override).ok_or_else(|| {
198                anyhow!(
199                    "Environment configuration not found for env [{}]",
200                    self.env_override.as_deref().unwrap_or("None")
201                )
202            })
203        } else {
204            self.config.get_active_env()
205        }
206    }
207
208    // TODO: Ger rid of mut
209    pub fn active_address(&mut self) -> Result<SuiAddress, anyhow::Error> {
210        if self.config.keystore.entries().is_empty() {
211            return Err(anyhow!(
212                "No managed addresses. Create new address with `new-address` command."
213            ));
214        }
215
216        // Ok to unwrap because we checked that config addresses not empty
217        // Set it if not exists
218        self.config.active_address = Some(
219            self.config
220                .active_address
221                .unwrap_or(*self.config.keystore.addresses().first().unwrap()),
222        );
223
224        Ok(self.config.active_address.unwrap())
225    }
226
227    /// Get the latest object reference given a object id
228    pub async fn get_object_ref(&self, object_id: ObjectID) -> Result<ObjectRef, anyhow::Error> {
229        let client = self.get_client().await?;
230        Ok(client
231            .read_api()
232            .get_object_with_options(object_id, SuiObjectDataOptions::new())
233            .await?
234            .into_object()?
235            .object_ref())
236    }
237
238    /// Get the latest full object reference given a object id
239    pub async fn get_full_object_ref(
240        &self,
241        object_id: ObjectID,
242    ) -> Result<FullObjectRef, anyhow::Error> {
243        let client = self.get_client().await?;
244        let object = client
245            .read_api()
246            .get_object_with_options(object_id, SuiObjectDataOptions::new().with_owner())
247            .await?
248            .into_object()?;
249        let object_ref = object.object_ref();
250        let owner = object
251            .owner
252            .expect("Owner should be present if `with_owner` is set");
253        Ok(FullObjectRef::from_object_ref_and_owner(object_ref, &owner))
254    }
255
256    /// Get all the gas objects (and conveniently, gas amounts) for the address
257    pub async fn gas_objects(
258        &self,
259        address: SuiAddress,
260    ) -> Result<Vec<(u64, SuiObjectData)>, anyhow::Error> {
261        let client = self.get_client().await?;
262
263        let mut objects: Vec<SuiObjectResponse> = Vec::new();
264        let mut cursor = None;
265        loop {
266            let response = client
267                .read_api()
268                .get_owned_objects(
269                    address,
270                    Some(SuiObjectResponseQuery::new(
271                        Some(SuiObjectDataFilter::StructType(GasCoin::type_())),
272                        Some(SuiObjectDataOptions::full_content()),
273                    )),
274                    cursor,
275                    None,
276                )
277                .await?;
278
279            objects.extend(response.data);
280
281            if response.has_next_page {
282                cursor = response.next_cursor;
283            } else {
284                break;
285            }
286        }
287
288        // TODO: We should ideally fetch the objects from local cache
289        let mut values_objects = Vec::new();
290
291        for object in objects {
292            let o = object.data;
293            if let Some(o) = o {
294                let gas_coin = GasCoin::try_from(&o)?;
295                values_objects.push((gas_coin.value(), o.clone()));
296            }
297        }
298
299        Ok(values_objects)
300    }
301
302    pub async fn get_object_owner(&self, id: &ObjectID) -> Result<SuiAddress, anyhow::Error> {
303        let client = self.get_client().await?;
304        let object = client
305            .read_api()
306            .get_object_with_options(*id, SuiObjectDataOptions::new().with_owner())
307            .await?
308            .into_object()?;
309        Ok(object
310            .owner
311            .ok_or_else(|| anyhow!("Owner field is None"))?
312            .get_owner_address()?)
313    }
314
315    pub async fn try_get_object_owner(
316        &self,
317        id: &Option<ObjectID>,
318    ) -> Result<Option<SuiAddress>, anyhow::Error> {
319        if let Some(id) = id {
320            Ok(Some(self.get_object_owner(id).await?))
321        } else {
322            Ok(None)
323        }
324    }
325
326    /// Infer the sender of a transaction based on the gas objects provided. If no gas objects are
327    /// provided, assume the active address is the sender.
328    pub async fn infer_sender(&mut self, gas: &[ObjectID]) -> Result<SuiAddress, anyhow::Error> {
329        if gas.is_empty() {
330            return self.active_address();
331        }
332
333        // Find the owners of all supplied object IDs
334        let owners = future::try_join_all(gas.iter().map(|id| self.get_object_owner(id))).await?;
335
336        // SAFETY `gas` is non-empty.
337        let owner = owners.first().copied().unwrap();
338
339        ensure!(
340            owners.iter().all(|o| o == &owner),
341            "Cannot infer sender, not all gas objects have the same owner."
342        );
343
344        Ok(owner)
345    }
346
347    /// Find a gas object which fits the budget
348    pub async fn gas_for_owner_budget(
349        &self,
350        address: SuiAddress,
351        budget: u64,
352        forbidden_gas_objects: BTreeSet<ObjectID>,
353    ) -> Result<(u64, SuiObjectData), anyhow::Error> {
354        for o in self.gas_objects(address).await? {
355            if o.0 >= budget && !forbidden_gas_objects.contains(&o.1.object_id) {
356                return Ok((o.0, o.1));
357            }
358        }
359        Err(anyhow!(
360            "No non-argument gas objects found for this address with value >= budget {budget}. Run sui client gas to check for gas objects."
361        ))
362    }
363
364    pub async fn get_all_gas_objects_owned_by_address(
365        &self,
366        address: SuiAddress,
367    ) -> anyhow::Result<Vec<ObjectRef>> {
368        self.get_gas_objects_owned_by_address(address, None).await
369    }
370
371    pub async fn get_gas_objects_owned_by_address(
372        &self,
373        address: SuiAddress,
374        limit: Option<usize>,
375    ) -> anyhow::Result<Vec<ObjectRef>> {
376        let client = self.get_client().await?;
377        let results: Vec<_> = client
378            .read_api()
379            .get_owned_objects(
380                address,
381                Some(SuiObjectResponseQuery::new(
382                    Some(SuiObjectDataFilter::StructType(GasCoin::type_())),
383                    Some(SuiObjectDataOptions::full_content()),
384                )),
385                None,
386                limit,
387            )
388            .await?
389            .data
390            .into_iter()
391            .filter_map(|r| r.data.map(|o| o.object_ref()))
392            .collect();
393        Ok(results)
394    }
395
396    /// Given an address, return one gas object owned by this address.
397    /// The actual implementation just returns the first one returned by the read api.
398    pub async fn get_one_gas_object_owned_by_address(
399        &self,
400        address: SuiAddress,
401    ) -> anyhow::Result<Option<ObjectRef>> {
402        Ok(self
403            .get_gas_objects_owned_by_address(address, Some(1))
404            .await?
405            .pop())
406    }
407
408    /// Returns one address and all gas objects owned by that address.
409    pub async fn get_one_account(&self) -> anyhow::Result<(SuiAddress, Vec<ObjectRef>)> {
410        let address = self.get_addresses().pop().unwrap();
411        Ok((
412            address,
413            self.get_all_gas_objects_owned_by_address(address).await?,
414        ))
415    }
416
417    /// Return a gas object owned by an arbitrary address managed by the wallet.
418    pub async fn get_one_gas_object(&self) -> anyhow::Result<Option<(SuiAddress, ObjectRef)>> {
419        for address in self.get_addresses() {
420            if let Some(gas_object) = self.get_one_gas_object_owned_by_address(address).await? {
421                return Ok(Some((address, gas_object)));
422            }
423        }
424        Ok(None)
425    }
426
427    /// Returns all the account addresses managed by the wallet and their owned gas objects.
428    pub async fn get_all_accounts_and_gas_objects(
429        &self,
430    ) -> anyhow::Result<Vec<(SuiAddress, Vec<ObjectRef>)>> {
431        let mut result = vec![];
432        for address in self.get_addresses() {
433            let objects = self
434                .gas_objects(address)
435                .await?
436                .into_iter()
437                .map(|(_, o)| o.object_ref())
438                .collect();
439            result.push((address, objects));
440        }
441        Ok(result)
442    }
443
444    pub async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
445        let client = self.get_client().await?;
446        let gas_price = client.governance_api().get_reference_gas_price().await?;
447        Ok(gas_price)
448    }
449
450    /// Add an account
451    pub async fn add_account(&mut self, alias: Option<String>, keypair: SuiKeyPair) {
452        self.config.keystore.import(alias, keypair).await.unwrap();
453    }
454
455    pub fn get_keystore_by_identity(
456        &self,
457        key_identity: &KeyIdentity,
458    ) -> Result<&Keystore, anyhow::Error> {
459        if self.config.keystore.get_by_identity(key_identity).is_ok() {
460            return Ok(&self.config.keystore);
461        }
462
463        if let Some(external_keys) = self.config.external_keys.as_ref()
464            && external_keys.get_by_identity(key_identity).is_ok()
465        {
466            return Ok(external_keys);
467        }
468
469        Err(anyhow!(
470            "No keystore found for the provided key identity: {key_identity}"
471        ))
472    }
473
474    pub fn get_keystore_by_identity_mut(
475        &mut self,
476        key_identity: &KeyIdentity,
477    ) -> Result<&mut Keystore, anyhow::Error> {
478        if self.config.keystore.get_by_identity(key_identity).is_ok() {
479            return Ok(&mut self.config.keystore);
480        }
481
482        if let Some(external_keys) = self.config.external_keys.as_mut()
483            && external_keys.get_by_identity(key_identity).is_ok()
484        {
485            return Ok(external_keys);
486        }
487
488        Err(anyhow!(
489            "No keystore found for the provided key identity: {key_identity}"
490        ))
491    }
492
493    pub async fn sign_secure(
494        &self,
495        key_identity: &KeyIdentity,
496        data: &TransactionData,
497        intent: Intent,
498    ) -> Result<Signature, anyhow::Error> {
499        let keystore = self.get_keystore_by_identity(key_identity)?;
500        let sig = keystore.sign_secure(&data.sender(), data, intent).await?;
501        Ok(sig)
502    }
503
504    /// Sign a transaction with a key currently managed by the WalletContext
505    pub async fn sign_transaction(&self, data: &TransactionData) -> Transaction {
506        let sig = self
507            .config
508            .keystore
509            .sign_secure(&data.sender(), data, Intent::sui_transaction())
510            .await
511            .unwrap();
512        // TODO: To support sponsored transaction, we should also look at the gas owner.
513        Transaction::from_data(data.clone(), vec![sig])
514    }
515
516    /// Execute a transaction and wait for it to be locally executed on the fullnode.
517    /// Also expects the effects status to be ExecutionStatus::Success.
518    pub async fn execute_transaction_must_succeed(
519        &self,
520        tx: Transaction,
521    ) -> SuiTransactionBlockResponse {
522        tracing::debug!("Executing transaction: {:?}", tx);
523        let response = self.execute_transaction_may_fail(tx).await.unwrap();
524        assert!(
525            response.status_ok().unwrap(),
526            "Transaction failed: {:?}",
527            response
528        );
529        response
530    }
531
532    /// Execute a transaction and wait for it to be locally executed on the fullnode.
533    /// The transaction execution is not guaranteed to succeed and may fail. This is usually only
534    /// needed in non-test environment or the caller is explicitly testing some failure behavior.
535    pub async fn execute_transaction_may_fail(
536        &self,
537        tx: Transaction,
538    ) -> anyhow::Result<SuiTransactionBlockResponse> {
539        let client = self.get_client().await?;
540        Ok(client
541            .quorum_driver_api()
542            .execute_transaction_block(
543                tx,
544                SuiTransactionBlockResponseOptions::new()
545                    .with_effects()
546                    .with_input()
547                    .with_events()
548                    .with_object_changes()
549                    .with_balance_changes(),
550                Some(sui_types::quorum_driver_types::ExecuteTransactionRequestType::WaitForLocalExecution),
551            )
552            .await?)
553    }
554}