sui_json_rpc/
balance_changes.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::{BTreeMap, HashMap, HashSet};
5use std::ops::Neg;
6
7use async_trait::async_trait;
8use sui_types::balance_change::derive_balance_changes;
9use tokio::sync::RwLock;
10
11use sui_json_rpc_types::BalanceChange;
12use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
13use sui_types::digests::ObjectDigest;
14use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
15use sui_types::execution_status::ExecutionStatus;
16use sui_types::gas_coin::GAS;
17use sui_types::object::Object;
18use sui_types::storage::WriteKind;
19use sui_types::transaction::InputObjectKind;
20use tracing::instrument;
21
22#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
23pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
24    object_provider: &P,
25    effects: &TransactionEffects,
26    input_objs: Vec<InputObjectKind>,
27    mocked_coin: Option<ObjectID>,
28) -> Result<Vec<BalanceChange>, E> {
29    let (_, gas_owner) = effects.gas_object();
30
31    // Only charge gas when tx fails, skip all object parsing
32    if effects.status() != &ExecutionStatus::Success {
33        return Ok(vec![BalanceChange {
34            owner: gas_owner,
35            coin_type: GAS::type_tag(),
36            amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
37        }]);
38    }
39
40    let all_mutated = effects
41        .all_changed_objects()
42        .into_iter()
43        .filter_map(|((id, version, digest), _, _)| {
44            if matches!(mocked_coin, Some(coin) if id == coin) {
45                return None;
46            }
47            Some((id, version, Some(digest)))
48        })
49        .collect::<Vec<_>>();
50
51    let input_objs_to_digest = input_objs
52        .iter()
53        .filter_map(|k| match k {
54            InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
55            InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
56        })
57        .collect::<HashMap<ObjectID, ObjectDigest>>();
58    let unwrapped_then_deleted = effects
59        .unwrapped_then_deleted()
60        .iter()
61        .map(|e| e.0)
62        .collect::<HashSet<_>>();
63
64    let modified_at_version = effects
65        .modified_at_versions()
66        .into_iter()
67        .filter_map(|(id, version)| {
68            if matches!(mocked_coin, Some(coin) if id == coin) {
69                return None;
70            }
71            // We won't be able to get dynamic object from object provider today
72            if unwrapped_then_deleted.contains(&id) {
73                return None;
74            }
75            Some((id, version, input_objs_to_digest.get(&id).cloned()))
76        })
77        .collect::<Vec<_>>();
78    let input_coins = fetch_coins(object_provider, &modified_at_version).await?;
79    let mutated_coins = fetch_coins(object_provider, &all_mutated).await?;
80    Ok(
81        derive_balance_changes(effects, &input_coins, &mutated_coins)
82            .into_iter()
83            .map(|change| BalanceChange {
84                owner: sui_types::object::Owner::AddressOwner(change.address),
85                coin_type: change.coin_type,
86                amount: change.amount,
87            })
88            .collect(),
89    )
90}
91
92#[instrument(skip_all)]
93async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
94    object_provider: &P,
95    objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
96) -> Result<Vec<Object>, E> {
97    let mut coins = vec![];
98    for (id, version, digest_opt) in objects {
99        // TODO: use multi get object
100        let o = object_provider.get_object(id, version).await?;
101        if let Some(type_) = o.type_()
102            && type_.is_coin()
103        {
104            if let Some(digest) = digest_opt {
105                // TODO: can we return Err here instead?
106                assert_eq!(
107                    *digest,
108                    o.digest(),
109                    "Object digest mismatch--got bad data from object_provider?"
110                )
111            }
112            coins.push(o);
113        }
114    }
115    Ok(coins)
116}
117
/// Asynchronous source of versioned objects used when computing balance changes.
#[async_trait]
pub trait ObjectProvider {
    /// Error produced when a lookup fails.
    type Error;
    /// Fetches the object identified by `id` at `version`.
    async fn get_object(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Object, Self::Error>;
    /// Looks up an object for `id` relative to `version`; `Ok(None)` means the
    /// provider has no object to return.
    ///
    /// NOTE(review): judging by the name, this presumably yields the newest object
    /// whose version is less than or equal to `version` — confirm with implementors.
    async fn find_object_lt_or_eq_version(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Option<Object>, Self::Error>;
}
132
/// Caching wrapper around an inner provider `P`: lookups are served from the
/// in-memory maps when possible and fall back to `provider` otherwise.
pub struct ObjectProviderCache<P> {
    /// Cached objects, keyed by `(object id, object version)`.
    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
    /// Maps `(object id, requested version)` to the version a lookup resolved to.
    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
    /// Inner provider consulted on cache misses.
    provider: P,
}
138
139impl<P> ObjectProviderCache<P> {
140    pub fn new(provider: P) -> Self {
141        Self {
142            object_cache: Default::default(),
143            last_version_cache: Default::default(),
144            provider,
145        }
146    }
147
148    pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
149        let object_cache = self.object_cache.get_mut();
150        let last_version_cache = self.last_version_cache.get_mut();
151
152        for object in objects {
153            let object_id = object.id();
154            let version = object.version();
155
156            let key = (object_id, version);
157            object_cache.insert(key, object.clone());
158
159            match last_version_cache.get_mut(&key) {
160                Some(existing_seq_number) => {
161                    if version > *existing_seq_number {
162                        *existing_seq_number = version
163                    }
164                }
165                None => {
166                    last_version_cache.insert(key, version);
167                }
168            }
169        }
170    }
171
172    pub fn new_with_cache(
173        provider: P,
174        written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
175    ) -> Self {
176        let mut object_cache = BTreeMap::new();
177        let mut last_version_cache = BTreeMap::new();
178
179        for (object_id, (object_ref, object, _)) in written_objects {
180            let key = (object_id, object_ref.1);
181            object_cache.insert(key, object.clone());
182
183            match last_version_cache.get_mut(&key) {
184                Some(existing_seq_number) => {
185                    if object_ref.1 > *existing_seq_number {
186                        *existing_seq_number = object_ref.1
187                    }
188                }
189                None => {
190                    last_version_cache.insert(key, object_ref.1);
191                }
192            }
193        }
194
195        Self {
196            object_cache: RwLock::new(object_cache),
197            last_version_cache: RwLock::new(last_version_cache),
198            provider,
199        }
200    }
201}
202
203#[async_trait]
204impl<P, E> ObjectProvider for ObjectProviderCache<P>
205where
206    P: ObjectProvider<Error = E> + Sync + Send,
207    E: Sync + Send,
208{
209    type Error = P::Error;
210
211    async fn get_object(
212        &self,
213        id: &ObjectID,
214        version: &SequenceNumber,
215    ) -> Result<Object, Self::Error> {
216        if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
217            return Ok(o.clone());
218        }
219        let o = self.provider.get_object(id, version).await?;
220        self.object_cache
221            .write()
222            .await
223            .insert((*id, *version), o.clone());
224        Ok(o)
225    }
226
227    async fn find_object_lt_or_eq_version(
228        &self,
229        id: &ObjectID,
230        version: &SequenceNumber,
231    ) -> Result<Option<Object>, Self::Error> {
232        if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
233            return Ok(self.get_object(id, version).await.ok());
234        }
235        if let Some(o) = self
236            .provider
237            .find_object_lt_or_eq_version(id, version)
238            .await?
239        {
240            self.object_cache
241                .write()
242                .await
243                .insert((*id, o.version()), o.clone());
244            self.last_version_cache
245                .write()
246                .await
247                .insert((*id, *version), o.version());
248            Ok(Some(o))
249        } else {
250            Ok(None)
251        }
252    }
253}