sui_sdk/
wallet_context.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::sui_client_config::{SuiClientConfig, SuiEnv};
5use anyhow::{anyhow, ensure};
6use futures::future;
7use futures::stream::TryStreamExt;
8use shared_crypto::intent::Intent;
9use std::collections::BTreeSet;
10use std::path::{Path, PathBuf};
11use sui_config::{Config, PersistedConfig};
12use sui_keys::key_identity::KeyIdentity;
13use sui_keys::keystore::{AccountKeystore, Keystore};
14use sui_rpc_api::client::ExecutedTransaction;
15use sui_types::base_types::{FullObjectRef, ObjectID, ObjectRef, SuiAddress};
16use sui_types::crypto::{Signature, SuiKeyPair};
17use sui_types::effects::TransactionEffectsAPI;
18use sui_types::object::Object;
19
20use std::sync::OnceLock;
21use sui_rpc_api::Client;
22use sui_types::gas_coin::GasCoin;
23use sui_types::transaction::{Transaction, TransactionData, TransactionDataAPI};
24use tracing::info;
25
/// Client-side wallet state: the persisted client configuration (keystores, environments,
/// active address) plus optional settings and a lazily created gRPC client.
pub struct WalletContext {
    /// The client configuration, together with the file path it was loaded from (the chain-id
    /// cache is saved back to that same path).
    pub config: PersistedConfig<SuiClientConfig>,
    // NOTE(review): set by `with_request_timeout` but never read in this file — confirm it is
    // meant to be applied when the gRPC client is created.
    request_timeout: Option<std::time::Duration>,
    /// gRPC client for the active environment; created on first use and then reused.
    grpc: OnceLock<Client>,
    // NOTE(review): set by `with_max_concurrent_requests` but never read in this file — confirm
    // it is meant to be applied when the gRPC client is created.
    max_concurrent_requests: Option<u64>,
    /// Environment alias that, when set, is used instead of the config's active environment.
    env_override: Option<String>,
}
33
34impl WalletContext {
35    pub fn new(config_path: &Path) -> Result<Self, anyhow::Error> {
36        let config: SuiClientConfig = PersistedConfig::read(config_path).map_err(|err| {
37            anyhow!(
38                "Cannot open wallet config file at {:?}. Err: {err}",
39                config_path
40            )
41        })?;
42
43        let config = config.persisted(config_path);
44        let context = Self {
45            config,
46            request_timeout: None,
47            grpc: OnceLock::new(),
48            max_concurrent_requests: None,
49            env_override: None,
50        };
51        Ok(context)
52    }
53
54    pub fn new_for_tests(
55        keystore: Keystore,
56        external: Option<Keystore>,
57        path: Option<PathBuf>,
58    ) -> Self {
59        let mut config = SuiClientConfig::new(keystore)
60            .persisted(&path.unwrap_or(PathBuf::from("test_config.yaml")));
61        config.external_keys = external;
62        Self {
63            config,
64            request_timeout: None,
65            grpc: OnceLock::new(),
66            max_concurrent_requests: None,
67            env_override: None,
68        }
69    }
70
71    pub fn with_request_timeout(mut self, request_timeout: std::time::Duration) -> Self {
72        self.request_timeout = Some(request_timeout);
73        self
74    }
75
76    pub fn with_max_concurrent_requests(mut self, max_concurrent_requests: u64) -> Self {
77        self.max_concurrent_requests = Some(max_concurrent_requests);
78        self
79    }
80
81    pub fn with_env_override(mut self, env_override: String) -> Self {
82        self.env_override = Some(env_override);
83        self
84    }
85
86    pub fn get_addresses(&self) -> Vec<SuiAddress> {
87        self.config.keystore.addresses()
88    }
89
90    pub fn get_env_override(&self) -> Option<String> {
91        self.env_override.clone()
92    }
93
94    pub fn get_identity_address(
95        &mut self,
96        input: Option<KeyIdentity>,
97    ) -> Result<SuiAddress, anyhow::Error> {
98        if let Some(key_identity) = input {
99            if let Ok(address) = self.config.keystore.get_by_identity(&key_identity) {
100                return Ok(address);
101            }
102            if let Some(address) = self
103                .config
104                .external_keys
105                .as_ref()
106                .and_then(|external_keys| external_keys.get_by_identity(&key_identity).ok())
107            {
108                return Ok(address);
109            }
110
111            Err(anyhow!(
112                "No address found for the provided key identity: {key_identity}"
113            ))
114        } else {
115            self.active_address()
116        }
117    }
118
119    pub fn grpc_client(&self) -> Result<Client, anyhow::Error> {
120        if let Some(client) = self.grpc.get() {
121            Ok(client.clone())
122        } else {
123            let client = self.get_active_env()?.create_grpc_client()?;
124            Ok(self.grpc.get_or_init(move || client).clone())
125        }
126    }
127
128    /// Load the chain ID corresponding to the active environment, or fetch and cache it if not
129    /// present.
130    ///
131    /// The chain ID is cached in the `client.yaml` file to avoid redundant network requests.
132    pub async fn load_or_cache_chain_id(&self) -> Result<String, anyhow::Error> {
133        self.internal_load_or_cache_chain_id(false).await
134    }
135
136    /// Try to load the cached chain ID for the active environment.
137    pub async fn try_load_chain_id_from_cache(
138        &self,
139        env: Option<String>,
140    ) -> Result<String, anyhow::Error> {
141        let env = if let Some(env) = env {
142            self.config
143                .get_env(&Some(env.to_string()))
144                .ok_or_else(|| anyhow!("Environment configuration not found for env [{}]", env))?
145        } else {
146            self.get_active_env()?
147        };
148        if let Some(chain_id) = &env.chain_id {
149            Ok(chain_id.clone())
150        } else {
151            Err(anyhow!(
152                "No cached chain ID found for env {}. Please pass `-e env_name` to your command",
153                env.alias
154            ))
155        }
156    }
157
158    /// Cache (or recache) chain ID for the active environment by fetching it from the
159    /// network
160    pub async fn cache_chain_id(&self) -> Result<String, anyhow::Error> {
161        self.internal_load_or_cache_chain_id(true).await
162    }
163
164    async fn internal_load_or_cache_chain_id(
165        &self,
166        force_recache: bool,
167    ) -> Result<String, anyhow::Error> {
168        let env = self.get_active_env()?;
169        if !force_recache && env.chain_id.is_some() {
170            let chain_id = env.chain_id.as_ref().unwrap();
171            info!("Found cached chain ID for env {}: {}", env.alias, chain_id);
172            return Ok(chain_id.clone());
173        }
174        let chain_id = self.grpc_client()?.get_chain_identifier().await?;
175        let path = self.config.path();
176        let mut config_result = SuiClientConfig::load_with_lock(path)?;
177
178        config_result.update_env_chain_id(&env.alias, chain_id.to_string())?;
179        config_result.save_with_lock(path)?;
180        Ok(chain_id.to_string())
181    }
182
183    pub fn get_active_env(&self) -> Result<&SuiEnv, anyhow::Error> {
184        if self.env_override.is_some() {
185            self.config.get_env(&self.env_override).ok_or_else(|| {
186                anyhow!(
187                    "Environment configuration not found for env [{}]",
188                    self.env_override.as_deref().unwrap_or("None")
189                )
190            })
191        } else {
192            self.config.get_active_env()
193        }
194    }
195
196    // TODO: Ger rid of mut
197    pub fn active_address(&mut self) -> Result<SuiAddress, anyhow::Error> {
198        if self.config.keystore.entries().is_empty() {
199            return Err(anyhow!(
200                "No managed addresses. Create new address with `new-address` command."
201            ));
202        }
203
204        // Ok to unwrap because we checked that config addresses not empty
205        // Set it if not exists
206        self.config.active_address = Some(
207            self.config
208                .active_address
209                .unwrap_or(*self.config.keystore.addresses().first().unwrap()),
210        );
211
212        Ok(self.config.active_address.unwrap())
213    }
214
215    /// Get the latest object reference given a object id
216    pub async fn get_object_ref(&self, object_id: ObjectID) -> Result<ObjectRef, anyhow::Error> {
217        Ok(self
218            .grpc_client()?
219            .get_object(object_id)
220            .await?
221            .compute_object_reference())
222    }
223
224    /// Get the latest full object reference given a object id
225    pub async fn get_full_object_ref(
226        &self,
227        object_id: ObjectID,
228    ) -> Result<FullObjectRef, anyhow::Error> {
229        Ok(self
230            .grpc_client()?
231            .get_object(object_id)
232            .await?
233            .compute_full_object_reference())
234    }
235
236    /// Get all the gas objects (and conveniently, gas amounts) for the address
237    pub async fn gas_objects(
238        &self,
239        owner: SuiAddress,
240    ) -> Result<Vec<(u64, Object)>, anyhow::Error> {
241        let client = self.grpc_client()?;
242
243        client
244            .list_owned_objects(owner, Some(GasCoin::type_()))
245            .map_err(Into::into)
246            .and_then(|object| async move {
247                let gas_coin = GasCoin::try_from(&object)?;
248
249                Ok((gas_coin.value(), object))
250            })
251            .try_collect()
252            .await
253    }
254
255    pub async fn get_object_owner(&self, id: &ObjectID) -> Result<SuiAddress, anyhow::Error> {
256        self.grpc_client()?
257            .get_object(*id)
258            .await?
259            .owner()
260            .get_owner_address()
261            .map_err(Into::into)
262    }
263
264    pub async fn try_get_object_owner(
265        &self,
266        id: &Option<ObjectID>,
267    ) -> Result<Option<SuiAddress>, anyhow::Error> {
268        if let Some(id) = id {
269            Ok(Some(self.get_object_owner(id).await?))
270        } else {
271            Ok(None)
272        }
273    }
274
275    /// Infer the sender of a transaction based on the gas objects provided. If no gas objects are
276    /// provided, assume the active address is the sender.
277    pub async fn infer_sender(&mut self, gas: &[ObjectID]) -> Result<SuiAddress, anyhow::Error> {
278        if gas.is_empty() {
279            return self.active_address();
280        }
281
282        // Find the owners of all supplied object IDs
283        let owners = future::try_join_all(gas.iter().map(|id| self.get_object_owner(id))).await?;
284
285        // SAFETY `gas` is non-empty.
286        let owner = owners.first().copied().unwrap();
287
288        ensure!(
289            owners.iter().all(|o| o == &owner),
290            "Cannot infer sender, not all gas objects have the same owner."
291        );
292
293        Ok(owner)
294    }
295
296    /// Find a gas object which fits the budget
297    pub async fn gas_for_owner_budget(
298        &self,
299        address: SuiAddress,
300        budget: u64,
301        forbidden_gas_objects: BTreeSet<ObjectID>,
302    ) -> Result<(u64, Object), anyhow::Error> {
303        for o in self.gas_objects(address).await? {
304            if o.0 >= budget && !forbidden_gas_objects.contains(&o.1.id()) {
305                return Ok((o.0, o.1));
306            }
307        }
308        Err(anyhow!(
309            "No non-argument gas objects found for this address with value >= budget {budget}. Run sui client gas to check for gas objects."
310        ))
311    }
312
313    pub async fn get_all_gas_objects_owned_by_address(
314        &self,
315        address: SuiAddress,
316    ) -> anyhow::Result<Vec<ObjectRef>> {
317        self.get_gas_objects_owned_by_address(address, None).await
318    }
319
320    pub async fn get_gas_objects_owned_by_address(
321        &self,
322        owner: SuiAddress,
323        page_size: Option<u32>,
324    ) -> anyhow::Result<Vec<ObjectRef>> {
325        let page = self
326            .grpc_client()?
327            .get_owned_objects(owner, Some(GasCoin::type_()), page_size, None)
328            .await?;
329
330        Ok(page
331            .items
332            .into_iter()
333            .map(|o| o.compute_object_reference())
334            .collect())
335    }
336
337    /// Given an address, return one gas object owned by this address.
338    /// The actual implementation just returns the first one returned by the read api.
339    pub async fn get_one_gas_object_owned_by_address(
340        &self,
341        address: SuiAddress,
342    ) -> anyhow::Result<Option<ObjectRef>> {
343        Ok(self
344            .get_gas_objects_owned_by_address(address, Some(1))
345            .await?
346            .pop())
347    }
348
349    /// Returns one address and all gas objects owned by that address.
350    pub async fn get_one_account(&self) -> anyhow::Result<(SuiAddress, Vec<ObjectRef>)> {
351        let address = self.get_addresses().pop().unwrap();
352        Ok((
353            address,
354            self.get_all_gas_objects_owned_by_address(address).await?,
355        ))
356    }
357
358    /// Return a gas object owned by an arbitrary address managed by the wallet.
359    pub async fn get_one_gas_object(&self) -> anyhow::Result<Option<(SuiAddress, ObjectRef)>> {
360        for address in self.get_addresses() {
361            if let Some(gas_object) = self.get_one_gas_object_owned_by_address(address).await? {
362                return Ok(Some((address, gas_object)));
363            }
364        }
365        Ok(None)
366    }
367
368    /// Returns all the account addresses managed by the wallet and their owned gas objects.
369    pub async fn get_all_accounts_and_gas_objects(
370        &self,
371    ) -> anyhow::Result<Vec<(SuiAddress, Vec<ObjectRef>)>> {
372        let mut result = vec![];
373        for address in self.get_addresses() {
374            let objects = self
375                .gas_objects(address)
376                .await?
377                .into_iter()
378                .map(|(_, o)| o.compute_object_reference())
379                .collect();
380            result.push((address, objects));
381        }
382        Ok(result)
383    }
384
385    pub async fn get_reference_gas_price(&self) -> Result<u64, anyhow::Error> {
386        self.grpc_client()?
387            .get_reference_gas_price()
388            .await
389            .map_err(Into::into)
390    }
391
392    /// Add an account
393    pub async fn add_account(&mut self, alias: Option<String>, keypair: SuiKeyPair) {
394        self.config.keystore.import(alias, keypair).await.unwrap();
395    }
396
397    pub fn get_keystore_by_identity(
398        &self,
399        key_identity: &KeyIdentity,
400    ) -> Result<&Keystore, anyhow::Error> {
401        if self.config.keystore.get_by_identity(key_identity).is_ok() {
402            return Ok(&self.config.keystore);
403        }
404
405        if let Some(external_keys) = self.config.external_keys.as_ref()
406            && external_keys.get_by_identity(key_identity).is_ok()
407        {
408            return Ok(external_keys);
409        }
410
411        Err(anyhow!(
412            "No keystore found for the provided key identity: {key_identity}"
413        ))
414    }
415
416    pub fn get_keystore_by_identity_mut(
417        &mut self,
418        key_identity: &KeyIdentity,
419    ) -> Result<&mut Keystore, anyhow::Error> {
420        if self.config.keystore.get_by_identity(key_identity).is_ok() {
421            return Ok(&mut self.config.keystore);
422        }
423
424        if let Some(external_keys) = self.config.external_keys.as_mut()
425            && external_keys.get_by_identity(key_identity).is_ok()
426        {
427            return Ok(external_keys);
428        }
429
430        Err(anyhow!(
431            "No keystore found for the provided key identity: {key_identity}"
432        ))
433    }
434
435    pub async fn sign_secure(
436        &self,
437        key_identity: &KeyIdentity,
438        data: &TransactionData,
439        intent: Intent,
440    ) -> Result<Signature, anyhow::Error> {
441        let keystore = self.get_keystore_by_identity(key_identity)?;
442        let sig = keystore.sign_secure(&data.sender(), data, intent).await?;
443        Ok(sig)
444    }
445
446    /// Sign a transaction with a key currently managed by the WalletContext
447    pub async fn sign_transaction(&self, data: &TransactionData) -> Transaction {
448        let sig = self
449            .config
450            .keystore
451            .sign_secure(&data.sender(), data, Intent::sui_transaction())
452            .await
453            .unwrap();
454        // TODO: To support sponsored transaction, we should also look at the gas owner.
455        Transaction::from_data(data.clone(), vec![sig])
456    }
457
458    /// Execute a transaction and wait for it to be locally executed on the fullnode.
459    /// Also expects the effects status to be ExecutionStatus::Success.
460    pub async fn execute_transaction_must_succeed(&self, tx: Transaction) -> ExecutedTransaction {
461        tracing::debug!("Executing transaction: {:?}", tx);
462        let response = self.execute_transaction_may_fail(tx).await.unwrap();
463        assert!(
464            response.effects.status().is_ok(),
465            "Transaction failed: {:?}",
466            response
467        );
468        response
469    }
470
471    /// Execute a transaction and wait for it to be locally executed on the fullnode.
472    /// The transaction execution is not guaranteed to succeed and may fail. This is usually only
473    /// needed in non-test environment or the caller is explicitly testing some failure behavior.
474    pub async fn execute_transaction_may_fail(
475        &self,
476        tx: Transaction,
477    ) -> anyhow::Result<ExecutedTransaction> {
478        self.grpc_client()?
479            .execute_transaction_and_wait_for_checkpoint(&tx)
480            .await
481            .map_err(Into::into)
482    }
483}