sui_json_rpc/
balance_changes.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::{BTreeMap, HashMap, HashSet};
5use std::ops::Neg;
6
7use async_trait::async_trait;
8use move_core_types::language_storage::TypeTag;
9use tokio::sync::RwLock;
10
11use sui_json_rpc_types::BalanceChange;
12use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
13use sui_types::coin::Coin;
14use sui_types::digests::ObjectDigest;
15use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
16use sui_types::execution_status::ExecutionStatus;
17use sui_types::gas_coin::GAS;
18use sui_types::object::{Object, Owner};
19use sui_types::storage::WriteKind;
20use sui_types::transaction::InputObjectKind;
21use tracing::instrument;
22
23#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
24pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
25    object_provider: &P,
26    effects: &TransactionEffects,
27    input_objs: Vec<InputObjectKind>,
28    mocked_coin: Option<ObjectID>,
29) -> Result<Vec<BalanceChange>, E> {
30    let (_, gas_owner) = effects.gas_object();
31
32    // Only charge gas when tx fails, skip all object parsing
33    if effects.status() != &ExecutionStatus::Success {
34        return Ok(vec![BalanceChange {
35            owner: gas_owner,
36            coin_type: GAS::type_tag(),
37            amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
38        }]);
39    }
40
41    let all_mutated = effects
42        .all_changed_objects()
43        .into_iter()
44        .filter_map(|((id, version, digest), _, _)| {
45            if matches!(mocked_coin, Some(coin) if id == coin) {
46                return None;
47            }
48            Some((id, version, Some(digest)))
49        })
50        .collect::<Vec<_>>();
51
52    let input_objs_to_digest = input_objs
53        .iter()
54        .filter_map(|k| match k {
55            InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
56            InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
57        })
58        .collect::<HashMap<ObjectID, ObjectDigest>>();
59    let unwrapped_then_deleted = effects
60        .unwrapped_then_deleted()
61        .iter()
62        .map(|e| e.0)
63        .collect::<HashSet<_>>();
64    get_balance_changes(
65        object_provider,
66        &effects
67            .modified_at_versions()
68            .into_iter()
69            .filter_map(|(id, version)| {
70                if matches!(mocked_coin, Some(coin) if id == coin) {
71                    return None;
72                }
73                // We won't be able to get dynamic object from object provider today
74                if unwrapped_then_deleted.contains(&id) {
75                    return None;
76                }
77                Some((id, version, input_objs_to_digest.get(&id).cloned()))
78            })
79            .collect::<Vec<_>>(),
80        &all_mutated,
81    )
82    .await
83}
84
85#[instrument(skip_all)]
86pub async fn get_balance_changes<P: ObjectProvider<Error = E>, E>(
87    object_provider: &P,
88    modified_at_version: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
89    all_mutated: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
90) -> Result<Vec<BalanceChange>, E> {
91    // 1. subtract all input coins
92    let balances = fetch_coins(object_provider, modified_at_version)
93        .await?
94        .into_iter()
95        .fold(
96            BTreeMap::<_, i128>::new(),
97            |mut acc, (owner, type_, amount)| {
98                *acc.entry((owner, type_)).or_default() -= amount as i128;
99                acc
100            },
101        );
102    // 2. add all mutated coins
103    let balances = fetch_coins(object_provider, all_mutated)
104        .await?
105        .into_iter()
106        .fold(balances, |mut acc, (owner, type_, amount)| {
107            *acc.entry((owner, type_)).or_default() += amount as i128;
108            acc
109        });
110
111    Ok(balances
112        .into_iter()
113        .filter_map(|((owner, coin_type), amount)| {
114            if amount == 0 {
115                return None;
116            }
117            Some(BalanceChange {
118                owner,
119                coin_type,
120                amount,
121            })
122        })
123        .collect())
124}
125
126#[instrument(skip_all)]
127async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
128    object_provider: &P,
129    objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
130) -> Result<Vec<(Owner, TypeTag, u64)>, E> {
131    let mut all_mutated_coins = vec![];
132    for (id, version, digest_opt) in objects {
133        // TODO: use multi get object
134        let o = object_provider.get_object(id, version).await?;
135        if let Some(type_) = o.type_()
136            && type_.is_coin()
137        {
138            if let Some(digest) = digest_opt {
139                // TODO: can we return Err here instead?
140                assert_eq!(
141                    *digest,
142                    o.digest(),
143                    "Object digest mismatch--got bad data from object_provider?"
144                )
145            }
146            let [coin_type]: [TypeTag; 1] = type_.clone().into_type_params().try_into().unwrap();
147            all_mutated_coins.push((
148                o.owner.clone(),
149                coin_type,
150                // we know this is a coin, safe to unwrap
151                Coin::extract_balance_if_coin(&o).unwrap().unwrap().1,
152            ))
153        }
154    }
155    Ok(all_mutated_coins)
156}
157
158#[async_trait]
159pub trait ObjectProvider {
160    type Error;
161    async fn get_object(
162        &self,
163        id: &ObjectID,
164        version: &SequenceNumber,
165    ) -> Result<Object, Self::Error>;
166    async fn find_object_lt_or_eq_version(
167        &self,
168        id: &ObjectID,
169        version: &SequenceNumber,
170    ) -> Result<Option<Object>, Self::Error>;
171}
172
173pub struct ObjectProviderCache<P> {
174    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
175    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
176    provider: P,
177}
178
179impl<P> ObjectProviderCache<P> {
180    pub fn new(provider: P) -> Self {
181        Self {
182            object_cache: Default::default(),
183            last_version_cache: Default::default(),
184            provider,
185        }
186    }
187
188    pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
189        let object_cache = self.object_cache.get_mut();
190        let last_version_cache = self.last_version_cache.get_mut();
191
192        for object in objects {
193            let object_id = object.id();
194            let version = object.version();
195
196            let key = (object_id, version);
197            object_cache.insert(key, object.clone());
198
199            match last_version_cache.get_mut(&key) {
200                Some(existing_seq_number) => {
201                    if version > *existing_seq_number {
202                        *existing_seq_number = version
203                    }
204                }
205                None => {
206                    last_version_cache.insert(key, version);
207                }
208            }
209        }
210    }
211
212    pub fn new_with_cache(
213        provider: P,
214        written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
215    ) -> Self {
216        let mut object_cache = BTreeMap::new();
217        let mut last_version_cache = BTreeMap::new();
218
219        for (object_id, (object_ref, object, _)) in written_objects {
220            let key = (object_id, object_ref.1);
221            object_cache.insert(key, object.clone());
222
223            match last_version_cache.get_mut(&key) {
224                Some(existing_seq_number) => {
225                    if object_ref.1 > *existing_seq_number {
226                        *existing_seq_number = object_ref.1
227                    }
228                }
229                None => {
230                    last_version_cache.insert(key, object_ref.1);
231                }
232            }
233        }
234
235        Self {
236            object_cache: RwLock::new(object_cache),
237            last_version_cache: RwLock::new(last_version_cache),
238            provider,
239        }
240    }
241}
242
243#[async_trait]
244impl<P, E> ObjectProvider for ObjectProviderCache<P>
245where
246    P: ObjectProvider<Error = E> + Sync + Send,
247    E: Sync + Send,
248{
249    type Error = P::Error;
250
251    async fn get_object(
252        &self,
253        id: &ObjectID,
254        version: &SequenceNumber,
255    ) -> Result<Object, Self::Error> {
256        if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
257            return Ok(o.clone());
258        }
259        let o = self.provider.get_object(id, version).await?;
260        self.object_cache
261            .write()
262            .await
263            .insert((*id, *version), o.clone());
264        Ok(o)
265    }
266
267    async fn find_object_lt_or_eq_version(
268        &self,
269        id: &ObjectID,
270        version: &SequenceNumber,
271    ) -> Result<Option<Object>, Self::Error> {
272        if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
273            return Ok(self.get_object(id, version).await.ok());
274        }
275        if let Some(o) = self
276            .provider
277            .find_object_lt_or_eq_version(id, version)
278            .await?
279        {
280            self.object_cache
281                .write()
282                .await
283                .insert((*id, o.version()), o.clone());
284            self.last_version_cache
285                .write()
286                .await
287                .insert((*id, *version), o.version());
288            Ok(Some(o))
289        } else {
290            Ok(None)
291        }
292    }
293}