// sui_json_rpc/balance_changes.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::collections::{BTreeMap, HashMap, HashSet};

use async_trait::async_trait;
use sui_json_rpc_types::BalanceChange;
use sui_types::balance_change::derive_balance_changes;
use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
use sui_types::digests::ObjectDigest;
use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
use sui_types::object::{Object, Owner};
use sui_types::storage::WriteKind;
use sui_types::transaction::InputObjectKind;
use tokio::sync::RwLock;
use tracing::instrument;

19#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
20pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
21    object_provider: &P,
22    effects: &TransactionEffects,
23    input_objs: Vec<InputObjectKind>,
24    mocked_coin: Option<ObjectID>,
25) -> Result<Vec<BalanceChange>, E> {
26    let all_mutated = effects
27        .all_changed_objects()
28        .into_iter()
29        .filter_map(|((id, version, digest), _, _)| {
30            if matches!(mocked_coin, Some(coin) if id == coin) {
31                return None;
32            }
33            Some((id, version, Some(digest)))
34        })
35        .collect::<Vec<_>>();
36
37    let input_objs_to_digest = input_objs
38        .iter()
39        .filter_map(|k| match k {
40            InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
41            InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
42        })
43        .collect::<HashMap<ObjectID, ObjectDigest>>();
44    let unwrapped_then_deleted = effects
45        .unwrapped_then_deleted()
46        .iter()
47        .map(|e| e.0)
48        .collect::<HashSet<_>>();
49
50    let modified_at_version = effects
51        .modified_at_versions()
52        .into_iter()
53        .filter_map(|(id, version)| {
54            if matches!(mocked_coin, Some(coin) if id == coin) {
55                return None;
56            }
57            // We won't be able to get dynamic object from object provider today
58            if unwrapped_then_deleted.contains(&id) {
59                return None;
60            }
61            Some((id, version, input_objs_to_digest.get(&id).cloned()))
62        })
63        .collect::<Vec<_>>();
64    let input_coins = fetch_coins(object_provider, &modified_at_version).await?;
65    let mutated_coins = fetch_coins(object_provider, &all_mutated).await?;
66    Ok(
67        derive_balance_changes(effects, &input_coins, &mutated_coins)
68            .into_iter()
69            .map(|change| BalanceChange {
70                owner: Owner::AddressOwner(change.address),
71                coin_type: change.coin_type,
72                amount: change.amount,
73            })
74            .collect(),
75    )
76}
77
78#[instrument(skip_all)]
79async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
80    object_provider: &P,
81    objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
82) -> Result<Vec<Object>, E> {
83    let mut coins = vec![];
84    for (id, version, digest_opt) in objects {
85        // TODO: use multi get object
86        let o = object_provider.get_object(id, version).await?;
87        if let Some(type_) = o.type_()
88            && type_.is_coin()
89        {
90            if let Some(digest) = digest_opt {
91                // TODO: can we return Err here instead?
92                assert_eq!(
93                    *digest,
94                    o.digest(),
95                    "Object digest mismatch--got bad data from object_provider?"
96                )
97            }
98            coins.push(o);
99        }
100    }
101    Ok(coins)
102}
103
104#[async_trait]
105pub trait ObjectProvider {
106    type Error;
107    async fn get_object(
108        &self,
109        id: &ObjectID,
110        version: &SequenceNumber,
111    ) -> Result<Object, Self::Error>;
112    async fn find_object_lt_or_eq_version(
113        &self,
114        id: &ObjectID,
115        version: &SequenceNumber,
116    ) -> Result<Option<Object>, Self::Error>;
117}
118
119pub struct ObjectProviderCache<P> {
120    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
121    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
122    provider: P,
123}
124
125impl<P> ObjectProviderCache<P> {
126    pub fn new(provider: P) -> Self {
127        Self {
128            object_cache: Default::default(),
129            last_version_cache: Default::default(),
130            provider,
131        }
132    }
133
134    pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
135        let object_cache = self.object_cache.get_mut();
136        let last_version_cache = self.last_version_cache.get_mut();
137
138        for object in objects {
139            let object_id = object.id();
140            let version = object.version();
141
142            let key = (object_id, version);
143            object_cache.insert(key, object.clone());
144
145            match last_version_cache.get_mut(&key) {
146                Some(existing_seq_number) => {
147                    if version > *existing_seq_number {
148                        *existing_seq_number = version
149                    }
150                }
151                None => {
152                    last_version_cache.insert(key, version);
153                }
154            }
155        }
156    }
157
158    pub fn new_with_cache(
159        provider: P,
160        written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
161    ) -> Self {
162        let mut object_cache = BTreeMap::new();
163        let mut last_version_cache = BTreeMap::new();
164
165        for (object_id, (object_ref, object, _)) in written_objects {
166            let key = (object_id, object_ref.1);
167            object_cache.insert(key, object.clone());
168
169            match last_version_cache.get_mut(&key) {
170                Some(existing_seq_number) => {
171                    if object_ref.1 > *existing_seq_number {
172                        *existing_seq_number = object_ref.1
173                    }
174                }
175                None => {
176                    last_version_cache.insert(key, object_ref.1);
177                }
178            }
179        }
180
181        Self {
182            object_cache: RwLock::new(object_cache),
183            last_version_cache: RwLock::new(last_version_cache),
184            provider,
185        }
186    }
187}
188
189#[async_trait]
190impl<P, E> ObjectProvider for ObjectProviderCache<P>
191where
192    P: ObjectProvider<Error = E> + Sync + Send,
193    E: Sync + Send,
194{
195    type Error = P::Error;
196
197    async fn get_object(
198        &self,
199        id: &ObjectID,
200        version: &SequenceNumber,
201    ) -> Result<Object, Self::Error> {
202        if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
203            return Ok(o.clone());
204        }
205        let o = self.provider.get_object(id, version).await?;
206        self.object_cache
207            .write()
208            .await
209            .insert((*id, *version), o.clone());
210        Ok(o)
211    }
212
213    async fn find_object_lt_or_eq_version(
214        &self,
215        id: &ObjectID,
216        version: &SequenceNumber,
217    ) -> Result<Option<Object>, Self::Error> {
218        if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
219            return Ok(self.get_object(id, version).await.ok());
220        }
221        if let Some(o) = self
222            .provider
223            .find_object_lt_or_eq_version(id, version)
224            .await?
225        {
226            self.object_cache
227                .write()
228                .await
229                .insert((*id, o.version()), o.clone());
230            self.last_version_cache
231                .write()
232                .await
233                .insert((*id, *version), o.version());
234            Ok(Some(o))
235        } else {
236            Ok(None)
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    //! Balance-change tests for transactions that failed with `InsufficientGas`. In these
    //! fixtures the only recorded changes are a gas coin, an address-balance accumulator
    //! write, or both.

    use super::*;
    use std::collections::BTreeSet;
    use sui_types::accumulator_root::AccumulatorValue as AccumulatorValueRoot;
    use sui_types::balance::Balance;
    use sui_types::base_types::SuiAddress;
    use sui_types::digests::TransactionDigest;
    use sui_types::effects::{
        AccumulatorAddress, AccumulatorOperation, AccumulatorValue, AccumulatorWriteV1,
        EffectsObjectChange,
    };
    use sui_types::execution_status::{ExecutionFailureStatus, ExecutionStatus};
    use sui_types::gas::GasCostSummary;
    use sui_types::gas_coin::GAS;
    use sui_types::object::MoveObject;

    /// In-memory [`ObjectProvider`] keyed by `(id, version)`.
    struct MockObjectProvider {
        objects: HashMap<(ObjectID, SequenceNumber), Object>,
    }

    impl MockObjectProvider {
        fn new() -> Self {
            Self {
                objects: HashMap::new(),
            }
        }

        fn insert(&mut self, obj: Object) {
            self.objects.insert((obj.id(), obj.version()), obj);
        }
    }

    #[async_trait]
    impl ObjectProvider for MockObjectProvider {
        type Error = anyhow::Error;
        async fn get_object(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Object, Self::Error> {
            self.objects
                .get(&(*id, *version))
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Object not found: {} v{}", id, version))
        }
        async fn find_object_lt_or_eq_version(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Option<Object>, Self::Error> {
            let result = self
                .objects
                .iter()
                .filter(|((oid, v), _)| oid == id && v <= version)
                .max_by_key(|((_, v), _)| *v)
                .map(|(_, obj)| obj.clone());
            Ok(result)
        }
    }

    /// Builds effects for a transaction that failed with `InsufficientGas`, charging
    /// `computation_cost` and recording `changed_objects`.
    fn failed_effects(
        computation_cost: u64,
        lamport_version: SequenceNumber,
        changed_objects: BTreeMap<ObjectID, EffectsObjectChange>,
        gas_object: Option<ObjectID>,
    ) -> TransactionEffects {
        TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailureStatus::InsufficientGas, None),
            0,
            GasCostSummary::new(computation_cost, 0, 0, 0),
            vec![],
            BTreeSet::new(),
            TransactionDigest::random(),
            lamport_version,
            changed_objects,
            gas_object,
            None,
            vec![],
        )
    }

    /// Creates a gas coin owned by `gas_owner`.
    ///
    /// The coin drops from `1_000_000 + gas_deduction` at version 1 to `1_000_000` at
    /// `lamport_version`. Returns the coin id, its effects entry, and a provider holding both
    /// versions of the coin.
    fn gas_coin_change(
        gas_owner: SuiAddress,
        gas_deduction: u64,
        lamport_version: SequenceNumber,
    ) -> (ObjectID, EffectsObjectChange, MockObjectProvider) {
        let gas_id = ObjectID::random();
        let old_version = SequenceNumber::from_u64(1);
        let initial_value = 1_000_000 + gas_deduction;
        let final_value = 1_000_000u64;

        let input_obj = Object::new_move(
            MoveObject::new_gas_coin(old_version, gas_id, initial_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );
        let output_obj = Object::new_move(
            MoveObject::new_gas_coin(lamport_version, gas_id, final_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );
        let change = EffectsObjectChange::new(
            Some((
                (old_version, input_obj.digest()),
                Owner::AddressOwner(gas_owner),
            )),
            Some(&output_obj),
            false,
            false,
        );

        let mut provider = MockObjectProvider::new();
        provider.insert(input_obj);
        provider.insert(output_obj);
        (gas_id, change, provider)
    }

    fn create_failed_effects_with_gas_coin(
        gas_owner: SuiAddress,
        gas_deduction: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let lamport_version = SequenceNumber::from_u64(2);
        let (gas_id, gas_change, provider) =
            gas_coin_change(gas_owner, gas_deduction, lamport_version);
        let changed_objects = BTreeMap::from([(gas_id, gas_change)]);
        let effects = failed_effects(gas_deduction, lamport_version, changed_objects, Some(gas_id));
        (effects, provider)
    }

    fn sui_balance_type() -> move_core_types::language_storage::TypeTag {
        Balance::type_tag("0x2::sui::SUI".parse().unwrap())
    }

    /// Builds a `Split` accumulator write of `amount` on `address`'s SUI balance, keyed by
    /// the accumulator field id.
    fn create_accumulator_write(
        address: SuiAddress,
        amount: u64,
    ) -> (ObjectID, EffectsObjectChange) {
        let balance_type = sui_balance_type();
        let obj_id = *AccumulatorValueRoot::get_field_id(address, &balance_type)
            .unwrap()
            .inner();
        let write = AccumulatorWriteV1 {
            address: AccumulatorAddress::new(address, balance_type),
            operation: AccumulatorOperation::Split,
            value: AccumulatorValue::Integer(amount),
        };
        (
            obj_id,
            EffectsObjectChange::new_from_accumulator_write(write),
        )
    }

    fn create_failed_effects_with_accumulator(
        address: SuiAddress,
        amount: u64,
    ) -> TransactionEffects {
        let changed_objects = BTreeMap::from([create_accumulator_write(address, amount)]);
        failed_effects(amount, SequenceNumber::new(), changed_objects, None)
    }

    fn create_failed_effects_with_gas_coin_and_accumulator(
        gas_owner: SuiAddress,
        gas_deduction: u64,
        acc_address: SuiAddress,
        acc_amount: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let lamport_version = SequenceNumber::from_u64(2);
        let (gas_id, gas_change, provider) =
            gas_coin_change(gas_owner, gas_deduction, lamport_version);
        let (acc_obj_id, acc_change) = create_accumulator_write(acc_address, acc_amount);
        let changed_objects =
            BTreeMap::from([(gas_id, gas_change), (acc_obj_id, acc_change)]);
        let effects = failed_effects(gas_deduction, lamport_version, changed_objects, Some(gas_id));
        (effects, provider)
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_balance_change() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 1000);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(gas_owner));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1000);
    }

    #[tokio::test]
    async fn test_failed_txn_address_balance_gas_balance_change() {
        let address = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(address, 500);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -500);
    }

    #[tokio::test]
    async fn test_failed_txn_zero_coin_gas_returns_empty() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 0);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_failed_txn_zero_gas_no_objects_returns_empty() {
        let effects = failed_effects(0, SequenceNumber::new(), BTreeMap::new(), None);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_failed_txn_sponsored_address_balance_gas() {
        let sponsor = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(sponsor, 750);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(sponsor));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -750);
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_different_addresses() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let acc_address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(gas_owner, 1000, acc_address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        let gas_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(gas_owner))
            .unwrap();
        let acc_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(acc_address))
            .unwrap();
        assert_eq!(gas_change.coin_type, GAS::type_tag());
        assert_eq!(gas_change.amount, -1000);
        assert_eq!(acc_change.coin_type, GAS::type_tag());
        assert_eq!(acc_change.amount, -200);
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_same_address() {
        let address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(address, 1000, address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1200);
    }
}