1use std::collections::{BTreeMap, HashMap, HashSet};
5
6use async_trait::async_trait;
7use sui_types::balance_change::derive_balance_changes;
8use tokio::sync::RwLock;
9
10use sui_json_rpc_types::BalanceChange;
11use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
12use sui_types::digests::ObjectDigest;
13use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
14use sui_types::object::{Object, Owner};
15use sui_types::storage::WriteKind;
16use sui_types::transaction::InputObjectKind;
17use tracing::instrument;
18
19#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
20pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
21 object_provider: &P,
22 effects: &TransactionEffects,
23 input_objs: Vec<InputObjectKind>,
24 mocked_coin: Option<ObjectID>,
25) -> Result<Vec<BalanceChange>, E> {
26 let all_mutated = effects
27 .all_changed_objects()
28 .into_iter()
29 .filter_map(|((id, version, digest), _, _)| {
30 if matches!(mocked_coin, Some(coin) if id == coin) {
31 return None;
32 }
33 Some((id, version, Some(digest)))
34 })
35 .collect::<Vec<_>>();
36
37 let input_objs_to_digest = input_objs
38 .iter()
39 .filter_map(|k| match k {
40 InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
41 InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
42 })
43 .collect::<HashMap<ObjectID, ObjectDigest>>();
44 let unwrapped_then_deleted = effects
45 .unwrapped_then_deleted()
46 .iter()
47 .map(|e| e.0)
48 .collect::<HashSet<_>>();
49
50 let modified_at_version = effects
51 .modified_at_versions()
52 .into_iter()
53 .filter_map(|(id, version)| {
54 if matches!(mocked_coin, Some(coin) if id == coin) {
55 return None;
56 }
57 if unwrapped_then_deleted.contains(&id) {
59 return None;
60 }
61 Some((id, version, input_objs_to_digest.get(&id).cloned()))
62 })
63 .collect::<Vec<_>>();
64 let input_coins = fetch_coins(object_provider, &modified_at_version).await?;
65 let mutated_coins = fetch_coins(object_provider, &all_mutated).await?;
66 Ok(
67 derive_balance_changes(effects, &input_coins, &mutated_coins)
68 .into_iter()
69 .map(|change| BalanceChange {
70 owner: Owner::AddressOwner(change.address),
71 coin_type: change.coin_type,
72 amount: change.amount,
73 })
74 .collect(),
75 )
76}
77
78#[instrument(skip_all)]
79async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
80 object_provider: &P,
81 objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
82) -> Result<Vec<Object>, E> {
83 let mut coins = vec![];
84 for (id, version, digest_opt) in objects {
85 let o = object_provider.get_object(id, version).await?;
87 if let Some(type_) = o.type_()
88 && type_.is_coin()
89 {
90 if let Some(digest) = digest_opt {
91 assert_eq!(
93 *digest,
94 o.digest(),
95 "Object digest mismatch--got bad data from object_provider?"
96 )
97 }
98 coins.push(o);
99 }
100 }
101 Ok(coins)
102}
103
/// Source of (possibly historical) object versions, used by the balance-change computation in
/// this module.
#[async_trait]
pub trait ObjectProvider {
    /// Error type returned by lookups.
    type Error;
    /// Fetches object `id` at exactly `version`.
    async fn get_object(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Object, Self::Error>;
    /// Looks up object `id` at the highest available version that is `<= version`.
    ///
    /// Returns `Ok(None)` when no such version is known. Presumably implementations follow the
    /// test mock in this file, which returns the max stored version `<= version`; confirm for
    /// production providers.
    async fn find_object_lt_or_eq_version(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Option<Object>, Self::Error>;
}
118
/// An [`ObjectProvider`] that memoizes lookups made against an inner provider `P`.
pub struct ObjectProviderCache<P> {
    /// Objects keyed by their exact `(id, version)`.
    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
    /// Maps a `find_object_lt_or_eq_version` request `(id, requested_version)` to the version it
    /// resolved to. Seeding an object at version `v` records `(id, v) -> v`.
    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
    /// The underlying provider consulted on cache misses.
    provider: P,
}
124
125impl<P> ObjectProviderCache<P> {
126 pub fn new(provider: P) -> Self {
127 Self {
128 object_cache: Default::default(),
129 last_version_cache: Default::default(),
130 provider,
131 }
132 }
133
134 pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
135 let object_cache = self.object_cache.get_mut();
136 let last_version_cache = self.last_version_cache.get_mut();
137
138 for object in objects {
139 let object_id = object.id();
140 let version = object.version();
141
142 let key = (object_id, version);
143 object_cache.insert(key, object.clone());
144
145 match last_version_cache.get_mut(&key) {
146 Some(existing_seq_number) => {
147 if version > *existing_seq_number {
148 *existing_seq_number = version
149 }
150 }
151 None => {
152 last_version_cache.insert(key, version);
153 }
154 }
155 }
156 }
157
158 pub fn new_with_cache(
159 provider: P,
160 written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
161 ) -> Self {
162 let mut object_cache = BTreeMap::new();
163 let mut last_version_cache = BTreeMap::new();
164
165 for (object_id, (object_ref, object, _)) in written_objects {
166 let key = (object_id, object_ref.1);
167 object_cache.insert(key, object.clone());
168
169 match last_version_cache.get_mut(&key) {
170 Some(existing_seq_number) => {
171 if object_ref.1 > *existing_seq_number {
172 *existing_seq_number = object_ref.1
173 }
174 }
175 None => {
176 last_version_cache.insert(key, object_ref.1);
177 }
178 }
179 }
180
181 Self {
182 object_cache: RwLock::new(object_cache),
183 last_version_cache: RwLock::new(last_version_cache),
184 provider,
185 }
186 }
187}
188
189#[async_trait]
190impl<P, E> ObjectProvider for ObjectProviderCache<P>
191where
192 P: ObjectProvider<Error = E> + Sync + Send,
193 E: Sync + Send,
194{
195 type Error = P::Error;
196
197 async fn get_object(
198 &self,
199 id: &ObjectID,
200 version: &SequenceNumber,
201 ) -> Result<Object, Self::Error> {
202 if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
203 return Ok(o.clone());
204 }
205 let o = self.provider.get_object(id, version).await?;
206 self.object_cache
207 .write()
208 .await
209 .insert((*id, *version), o.clone());
210 Ok(o)
211 }
212
213 async fn find_object_lt_or_eq_version(
214 &self,
215 id: &ObjectID,
216 version: &SequenceNumber,
217 ) -> Result<Option<Object>, Self::Error> {
218 if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
219 return Ok(self.get_object(id, version).await.ok());
220 }
221 if let Some(o) = self
222 .provider
223 .find_object_lt_or_eq_version(id, version)
224 .await?
225 {
226 self.object_cache
227 .write()
228 .await
229 .insert((*id, o.version()), o.clone());
230 self.last_version_cache
231 .write()
232 .await
233 .insert((*id, *version), o.version());
234 Ok(Some(o))
235 } else {
236 Ok(None)
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    //! Tests for `get_balance_changes_from_effect` on failed transactions. Gas is paid either
    //! from a gas coin, from an accumulator (address-balance) write, or from both.

    use super::*;
    use std::collections::BTreeMap;
    use sui_types::accumulator_root::AccumulatorValue as AccumulatorValueRoot;
    use sui_types::balance::Balance;
    use sui_types::base_types::{ObjectID, SequenceNumber, SuiAddress};
    use sui_types::digests::TransactionDigest;
    use sui_types::effects::{
        AccumulatorAddress, AccumulatorOperation, AccumulatorValue, AccumulatorWriteV1,
        EffectsObjectChange,
    };
    use sui_types::execution_status::{ExecutionFailureStatus, ExecutionStatus};
    use sui_types::gas::GasCostSummary;
    use sui_types::gas_coin::GAS;
    use sui_types::object::MoveObject;

    /// In-memory `ObjectProvider` holding objects keyed by `(id, version)`.
    struct MockObjectProvider {
        objects: HashMap<(ObjectID, SequenceNumber), Object>,
    }

    impl MockObjectProvider {
        fn new() -> Self {
            Self {
                objects: HashMap::new(),
            }
        }

        /// Stores `obj` under its own id and version.
        fn insert(&mut self, obj: Object) {
            self.objects.insert((obj.id(), obj.version()), obj);
        }
    }

    #[async_trait]
    impl ObjectProvider for MockObjectProvider {
        type Error = anyhow::Error;
        // Exact `(id, version)` lookup; a missing entry is reported as an error.
        async fn get_object(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Object, Self::Error> {
            self.objects
                .get(&(*id, *version))
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Object not found: {} v{}", id, version))
        }
        // Linear scan for the highest stored version of `id` that is `<= version`.
        async fn find_object_lt_or_eq_version(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Option<Object>, Self::Error> {
            let result = self
                .objects
                .iter()
                .filter(|((oid, v), _)| oid == id && v <= version)
                .max_by_key(|((_, v), _)| *v)
                .map(|(_, obj)| obj.clone());
            Ok(result)
        }
    }

    /// Builds effects for a transaction that failed with `InsufficientGas`, paid from one gas
    /// coin owned by `gas_owner`.
    ///
    /// The coin goes from `1_000_000 + gas_deduction` at version 1 to `1_000_000` at version 2.
    /// The returned provider holds both versions of the coin.
    fn create_failed_effects_with_gas_coin(
        gas_owner: SuiAddress,
        gas_deduction: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let gas_id = ObjectID::random();
        let old_version = SequenceNumber::from_u64(1);
        let lamport_version = SequenceNumber::from_u64(2);
        let initial_value = 1_000_000 + gas_deduction;
        let final_value = 1_000_000u64;

        let input_obj = Object::new_move(
            MoveObject::new_gas_coin(old_version, gas_id, initial_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );
        let output_obj = Object::new_move(
            MoveObject::new_gas_coin(lamport_version, gas_id, final_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );

        // The gas coin's change: input at `old_version` (with its digest), output at
        // `lamport_version`.
        let mut changed_objects = BTreeMap::new();
        changed_objects.insert(
            gas_id,
            EffectsObjectChange::new(
                Some((
                    (old_version, input_obj.digest()),
                    Owner::AddressOwner(gas_owner),
                )),
                Some(&output_obj),
                false,
                false,
            ),
        );

        let effects = TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailureStatus::InsufficientGas, None),
            0,
            GasCostSummary::new(gas_deduction, 0, 0, 0),
            vec![],
            std::collections::BTreeSet::new(),
            TransactionDigest::random(),
            lamport_version,
            changed_objects,
            Some(gas_id),
            None,
            vec![],
        );

        let mut provider = MockObjectProvider::new();
        provider.insert(input_obj);
        provider.insert(output_obj);

        (effects, provider)
    }

    /// Type tag of `Balance<0x2::sui::SUI>`.
    fn sui_balance_type() -> move_core_types::language_storage::TypeTag {
        Balance::type_tag("0x2::sui::SUI".parse().unwrap())
    }

    /// Builds a `Split` accumulator write of `amount` against `address`'s SUI balance.
    ///
    /// Returns it keyed by the accumulator field's object id. The tests below expect such a
    /// write to be reported as a balance change of `-amount`.
    fn create_accumulator_write(
        address: SuiAddress,
        amount: u64,
    ) -> (ObjectID, EffectsObjectChange) {
        let balance_type = sui_balance_type();
        let obj_id = *AccumulatorValueRoot::get_field_id(address, &balance_type)
            .unwrap()
            .inner();
        let write = AccumulatorWriteV1 {
            address: AccumulatorAddress::new(address, balance_type),
            operation: AccumulatorOperation::Split,
            value: AccumulatorValue::Integer(amount),
        };
        (
            obj_id,
            EffectsObjectChange::new_from_accumulator_write(write),
        )
    }

    /// Builds effects for a failed (`InsufficientGas`) transaction whose only change is an
    /// accumulator `Split` of `amount` from `address`. There is no gas coin object.
    fn create_failed_effects_with_accumulator(
        address: SuiAddress,
        amount: u64,
    ) -> TransactionEffects {
        let (obj_id, change) = create_accumulator_write(address, amount);
        let mut changed_objects = BTreeMap::new();
        changed_objects.insert(obj_id, change);

        TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailureStatus::InsufficientGas, None),
            0,
            GasCostSummary::new(amount, 0, 0, 0),
            vec![],
            std::collections::BTreeSet::new(),
            TransactionDigest::random(),
            SequenceNumber::new(),
            changed_objects,
            None,
            None,
            vec![],
        )
    }

    /// Combines the gas-coin fixture above with an extra accumulator `Split` of `acc_amount`
    /// from `acc_address`.
    ///
    /// Only the gas coin's two versions are placed in the provider.
    fn create_failed_effects_with_gas_coin_and_accumulator(
        gas_owner: SuiAddress,
        gas_deduction: u64,
        acc_address: SuiAddress,
        acc_amount: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let gas_id = ObjectID::random();
        let old_version = SequenceNumber::from_u64(1);
        let lamport_version = SequenceNumber::from_u64(2);
        let initial_value = 1_000_000 + gas_deduction;
        let final_value = 1_000_000u64;

        let input_obj = Object::new_move(
            MoveObject::new_gas_coin(old_version, gas_id, initial_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );
        let output_obj = Object::new_move(
            MoveObject::new_gas_coin(lamport_version, gas_id, final_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );

        let mut changed_objects = BTreeMap::new();
        changed_objects.insert(
            gas_id,
            EffectsObjectChange::new(
                Some((
                    (old_version, input_obj.digest()),
                    Owner::AddressOwner(gas_owner),
                )),
                Some(&output_obj),
                false,
                false,
            ),
        );

        let (acc_obj_id, acc_change) = create_accumulator_write(acc_address, acc_amount);
        changed_objects.insert(acc_obj_id, acc_change);

        let effects = TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailureStatus::InsufficientGas, None),
            0,
            GasCostSummary::new(gas_deduction, 0, 0, 0),
            vec![],
            std::collections::BTreeSet::new(),
            TransactionDigest::random(),
            lamport_version,
            changed_objects,
            Some(gas_id),
            None,
            vec![],
        );

        let mut provider = MockObjectProvider::new();
        provider.insert(input_obj);
        provider.insert(output_obj);

        (effects, provider)
    }

    // The gas coin's value drop is reported as a single negative SUI change for its owner.
    #[tokio::test]
    async fn test_failed_txn_coin_gas_balance_change() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 1000);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(gas_owner));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1000);
    }

    // Gas paid via an accumulator write is reported even though the provider is empty. This
    // relies on the accumulator change needing no object fetch.
    #[tokio::test]
    async fn test_failed_txn_address_balance_gas_balance_change() {
        let address = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(address, 500);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].amount, -500);
    }

    // A gas coin whose value did not change yields no balance change at all.
    #[tokio::test]
    async fn test_failed_txn_zero_coin_gas_returns_empty() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 0);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    // No changed objects and zero gas: nothing to report.
    #[tokio::test]
    async fn test_failed_txn_zero_gas_no_objects_returns_empty() {
        let effects = TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailureStatus::InsufficientGas, None),
            0,
            GasCostSummary::new(0, 0, 0, 0),
            vec![],
            std::collections::BTreeSet::new(),
            TransactionDigest::random(),
            SequenceNumber::new(),
            BTreeMap::new(),
            None,
            None,
            vec![],
        );
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    // The change is attributed to the address whose accumulator was debited (here a sponsor).
    #[tokio::test]
    async fn test_failed_txn_sponsored_address_balance_gas() {
        let sponsor = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(sponsor, 750);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(sponsor));
        assert_eq!(result[0].amount, -750);
    }

    // Coin gas and an accumulator debit on different addresses produce two separate changes.
    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_different_addresses() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let acc_address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(gas_owner, 1000, acc_address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        let gas_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(gas_owner))
            .unwrap();
        let acc_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(acc_address))
            .unwrap();
        assert_eq!(gas_change.amount, -1000);
        assert_eq!(acc_change.amount, -200);
    }

    // Coin gas and an accumulator debit on the same address merge into one SUI change.
    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_same_address() {
        let address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(address, 1000, address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1200);
    }
}