1use std::collections::{BTreeMap, HashMap, HashSet};
5
6use async_trait::async_trait;
7use sui_types::balance_change::derive_balance_changes;
8use tokio::sync::RwLock;
9
10use sui_json_rpc_types::BalanceChange;
11use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
12use sui_types::digests::ObjectDigest;
13use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
14use sui_types::object::{Object, Owner};
15use sui_types::storage::WriteKind;
16use sui_types::transaction::InputObjectKind;
17use tracing::instrument;
18
19#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
20pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
21 object_provider: &P,
22 effects: &TransactionEffects,
23 input_objs: Vec<InputObjectKind>,
24 mocked_coin: Option<ObjectID>,
25) -> Result<Vec<BalanceChange>, E> {
26 let all_mutated = effects
27 .all_changed_objects()
28 .into_iter()
29 .filter_map(|((id, version, digest), _, _)| {
30 if matches!(mocked_coin, Some(coin) if id == coin) {
31 return None;
32 }
33 Some((id, version, Some(digest)))
34 })
35 .collect::<Vec<_>>();
36
37 let input_objs_to_digest = input_objs
38 .iter()
39 .filter_map(|k| match k {
40 InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
41 InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
42 })
43 .collect::<HashMap<ObjectID, ObjectDigest>>();
44 let unwrapped_then_deleted = effects
45 .unwrapped_then_deleted()
46 .iter()
47 .map(|e| e.0)
48 .collect::<HashSet<_>>();
49
50 let modified_at_version = effects
51 .modified_at_versions()
52 .into_iter()
53 .filter_map(|(id, version)| {
54 if matches!(mocked_coin, Some(coin) if id == coin) {
55 return None;
56 }
57 if unwrapped_then_deleted.contains(&id) {
59 return None;
60 }
61 Some((id, version, input_objs_to_digest.get(&id).cloned()))
62 })
63 .collect::<Vec<_>>();
64 let input_coins = fetch_coins(object_provider, &modified_at_version).await?;
65 let mutated_coins = fetch_coins(object_provider, &all_mutated).await?;
66 Ok(
67 derive_balance_changes(effects, &input_coins, &mutated_coins)
68 .into_iter()
69 .map(|change| BalanceChange {
70 owner: Owner::AddressOwner(change.address),
71 coin_type: change.coin_type,
72 amount: change.amount,
73 })
74 .collect(),
75 )
76}
77
78#[instrument(skip_all)]
79async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
80 object_provider: &P,
81 objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
82) -> Result<Vec<Object>, E> {
83 let mut coins = vec![];
84 for (id, version, digest_opt) in objects {
85 let o = object_provider.get_object(id, version).await?;
87 if let Some(type_) = o.type_()
88 && type_.is_coin()
89 {
90 if let Some(digest) = digest_opt {
91 assert_eq!(
93 *digest,
94 o.digest(),
95 "Object digest mismatch--got bad data from object_provider?"
96 )
97 }
98 coins.push(o);
99 }
100 }
101 Ok(coins)
102}
103
/// Source of versioned objects used when computing balance changes.
#[async_trait]
pub trait ObjectProvider {
    /// Error returned when an object cannot be loaded.
    type Error;
    /// Returns object `id` at exactly `version`.
    async fn get_object(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Object, Self::Error>;
    /// Returns object `id` at the highest known version that is `<= version`, or `Ok(None)`
    /// if no such version exists.
    async fn find_object_lt_or_eq_version(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Option<Object>, Self::Error>;
}
118
/// Read-through cache in front of another [`ObjectProvider`].
///
/// Objects fetched on a cache miss are stored, and the cache can also be pre-seeded with
/// known objects.
pub struct ObjectProviderCache<P> {
    /// Objects keyed by their exact `(id, version)`.
    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
    /// Answers to `find_object_lt_or_eq_version`: `(id, queried version)` maps to the
    /// version of the object that query resolved to.
    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
    /// The wrapped provider, consulted on cache misses.
    provider: P,
}
124
125impl<P> ObjectProviderCache<P> {
126 pub fn new(provider: P) -> Self {
127 Self {
128 object_cache: Default::default(),
129 last_version_cache: Default::default(),
130 provider,
131 }
132 }
133
134 pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
135 let object_cache = self.object_cache.get_mut();
136 let last_version_cache = self.last_version_cache.get_mut();
137
138 for object in objects {
139 let object_id = object.id();
140 let version = object.version();
141
142 let key = (object_id, version);
143 object_cache.insert(key, object.clone());
144
145 match last_version_cache.get_mut(&key) {
146 Some(existing_seq_number) => {
147 if version > *existing_seq_number {
148 *existing_seq_number = version
149 }
150 }
151 None => {
152 last_version_cache.insert(key, version);
153 }
154 }
155 }
156 }
157
158 pub fn new_with_cache(
159 provider: P,
160 written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
161 ) -> Self {
162 let mut object_cache = BTreeMap::new();
163 let mut last_version_cache = BTreeMap::new();
164
165 for (object_id, (object_ref, object, _)) in written_objects {
166 let key = (object_id, object_ref.1);
167 object_cache.insert(key, object.clone());
168
169 match last_version_cache.get_mut(&key) {
170 Some(existing_seq_number) => {
171 if object_ref.1 > *existing_seq_number {
172 *existing_seq_number = object_ref.1
173 }
174 }
175 None => {
176 last_version_cache.insert(key, object_ref.1);
177 }
178 }
179 }
180
181 Self {
182 object_cache: RwLock::new(object_cache),
183 last_version_cache: RwLock::new(last_version_cache),
184 provider,
185 }
186 }
187}
188
189#[async_trait]
190impl<P, E> ObjectProvider for ObjectProviderCache<P>
191where
192 P: ObjectProvider<Error = E> + Sync + Send,
193 E: Sync + Send,
194{
195 type Error = P::Error;
196
197 async fn get_object(
198 &self,
199 id: &ObjectID,
200 version: &SequenceNumber,
201 ) -> Result<Object, Self::Error> {
202 if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
203 return Ok(o.clone());
204 }
205 let o = self.provider.get_object(id, version).await?;
206 self.object_cache
207 .write()
208 .await
209 .insert((*id, *version), o.clone());
210 Ok(o)
211 }
212
213 async fn find_object_lt_or_eq_version(
214 &self,
215 id: &ObjectID,
216 version: &SequenceNumber,
217 ) -> Result<Option<Object>, Self::Error> {
218 if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
219 return Ok(self.get_object(id, version).await.ok());
220 }
221 if let Some(o) = self
222 .provider
223 .find_object_lt_or_eq_version(id, version)
224 .await?
225 {
226 self.object_cache
227 .write()
228 .await
229 .insert((*id, o.version()), o.clone());
230 self.last_version_cache
231 .write()
232 .await
233 .insert((*id, *version), o.version());
234 Ok(Some(o))
235 } else {
236 Ok(None)
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use sui_types::accumulator_root::AccumulatorValue as AccumulatorValueRoot;
    use sui_types::balance::Balance;
    use sui_types::base_types::{ObjectID, SequenceNumber, SuiAddress};
    use sui_types::digests::TransactionDigest;
    use sui_types::effects::{
        AccumulatorAddress, AccumulatorOperation, AccumulatorValue, AccumulatorWriteV1,
        EffectsObjectChange,
    };
    use sui_types::execution_status::{ExecutionFailure, ExecutionFailureStatus, ExecutionStatus};
    use sui_types::gas::GasCostSummary;
    use sui_types::gas_coin::GAS;
    use sui_types::object::MoveObject;

    /// In-memory `ObjectProvider` backed by a map keyed on `(id, version)`.
    struct MockObjectProvider {
        objects: HashMap<(ObjectID, SequenceNumber), Object>,
    }

    impl MockObjectProvider {
        fn new() -> Self {
            Self {
                objects: HashMap::new(),
            }
        }

        fn insert(&mut self, obj: Object) {
            self.objects.insert((obj.id(), obj.version()), obj);
        }
    }

    #[async_trait]
    impl ObjectProvider for MockObjectProvider {
        type Error = anyhow::Error;
        async fn get_object(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Object, Self::Error> {
            self.objects
                .get(&(*id, *version))
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Object not found: {} v{}", id, version))
        }
        async fn find_object_lt_or_eq_version(
            &self,
            id: &ObjectID,
            version: &SequenceNumber,
        ) -> Result<Option<Object>, Self::Error> {
            let result = self
                .objects
                .iter()
                .filter(|((oid, v), _)| oid == id && v <= version)
                .max_by_key(|((_, v), _)| *v)
                .map(|(_, obj)| obj.clone());
            Ok(result)
        }
    }

    /// Builds effects for a transaction that failed with `InsufficientGas`.
    ///
    /// `gas_charged` is the first argument of `GasCostSummary::new`; every other cost
    /// field is zero.
    fn failed_effects(
        gas_charged: u64,
        lamport_version: SequenceNumber,
        changed_objects: BTreeMap<ObjectID, EffectsObjectChange>,
        gas_object: Option<ObjectID>,
    ) -> TransactionEffects {
        TransactionEffects::new_from_execution_v2(
            ExecutionStatus::new_failure(ExecutionFailure::new(
                ExecutionFailureStatus::InsufficientGas,
                None,
            )),
            0,
            GasCostSummary::new(gas_charged, 0, 0, 0),
            vec![],
            TransactionDigest::random(),
            lamport_version,
            changed_objects,
            gas_object,
            None,
            vec![],
        )
    }

    /// A gas coin debited by some amount, with its effects entry and a provider that
    /// holds both the input and output versions.
    struct GasCoinFixture {
        gas_id: ObjectID,
        lamport_version: SequenceNumber,
        change: EffectsObjectChange,
        provider: MockObjectProvider,
    }

    /// Creates a gas coin owned by `gas_owner` that goes from `1_000_000 + gas_deduction`
    /// at version 1 to `1_000_000` at version 2.
    fn gas_coin_fixture(gas_owner: SuiAddress, gas_deduction: u64) -> GasCoinFixture {
        let gas_id = ObjectID::random();
        let old_version = SequenceNumber::from_u64(1);
        let lamport_version = SequenceNumber::from_u64(2);
        let initial_value = 1_000_000 + gas_deduction;
        let final_value = 1_000_000u64;

        let input_obj = Object::new_move(
            MoveObject::new_gas_coin(old_version, gas_id, initial_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );
        let output_obj = Object::new_move(
            MoveObject::new_gas_coin(lamport_version, gas_id, final_value),
            Owner::AddressOwner(gas_owner),
            TransactionDigest::random(),
        );

        let change = EffectsObjectChange::new(
            Some((
                (old_version, input_obj.digest()),
                Owner::AddressOwner(gas_owner),
            )),
            Some(&output_obj),
            false,
            false,
        );

        let mut provider = MockObjectProvider::new();
        provider.insert(input_obj);
        provider.insert(output_obj);

        GasCoinFixture {
            gas_id,
            lamport_version,
            change,
            provider,
        }
    }

    fn create_failed_effects_with_gas_coin(
        gas_owner: SuiAddress,
        gas_deduction: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let gas = gas_coin_fixture(gas_owner, gas_deduction);
        let changed_objects = BTreeMap::from([(gas.gas_id, gas.change)]);
        let effects = failed_effects(
            gas_deduction,
            gas.lamport_version,
            changed_objects,
            Some(gas.gas_id),
        );
        (effects, gas.provider)
    }

    fn sui_balance_type() -> move_core_types::language_storage::TypeTag {
        Balance::type_tag("0x2::sui::SUI".parse().unwrap())
    }

    /// Builds an accumulator `Split` of `amount` SUI for `address`, keyed by the
    /// accumulator's field object ID.
    fn create_accumulator_write(
        address: SuiAddress,
        amount: u64,
    ) -> (ObjectID, EffectsObjectChange) {
        let balance_type = sui_balance_type();
        let obj_id = *AccumulatorValueRoot::get_field_id(address, &balance_type)
            .unwrap()
            .inner();
        let write = AccumulatorWriteV1 {
            address: AccumulatorAddress::new(address, balance_type),
            operation: AccumulatorOperation::Split,
            value: AccumulatorValue::Integer(amount),
        };
        (
            obj_id,
            EffectsObjectChange::new_from_accumulator_write(write),
        )
    }

    fn create_failed_effects_with_accumulator(
        address: SuiAddress,
        amount: u64,
    ) -> TransactionEffects {
        let (obj_id, change) = create_accumulator_write(address, amount);
        failed_effects(
            amount,
            SequenceNumber::new(),
            BTreeMap::from([(obj_id, change)]),
            None,
        )
    }

    fn create_failed_effects_with_gas_coin_and_accumulator(
        gas_owner: SuiAddress,
        gas_deduction: u64,
        acc_address: SuiAddress,
        acc_amount: u64,
    ) -> (TransactionEffects, MockObjectProvider) {
        let gas = gas_coin_fixture(gas_owner, gas_deduction);
        let (acc_obj_id, acc_change) = create_accumulator_write(acc_address, acc_amount);
        let changed_objects =
            BTreeMap::from([(gas.gas_id, gas.change), (acc_obj_id, acc_change)]);
        let effects = failed_effects(
            gas_deduction,
            gas.lamport_version,
            changed_objects,
            Some(gas.gas_id),
        );
        (effects, gas.provider)
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_balance_change() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 1000);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(gas_owner));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1000);
    }

    #[tokio::test]
    async fn test_failed_txn_address_balance_gas_balance_change() {
        let address = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(address, 500);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].amount, -500);
    }

    #[tokio::test]
    async fn test_failed_txn_zero_coin_gas_returns_empty() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let (effects, provider) = create_failed_effects_with_gas_coin(gas_owner, 0);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_failed_txn_zero_gas_no_objects_returns_empty() {
        let effects = failed_effects(0, SequenceNumber::new(), BTreeMap::new(), None);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    // NOTE(review): this builds the same kind of effects as
    // `test_failed_txn_address_balance_gas_balance_change` (only the amount differs); no
    // distinct sender is modelled, so it does not yet exercise sponsorship specifically.
    #[tokio::test]
    async fn test_failed_txn_sponsored_address_balance_gas() {
        let sponsor = SuiAddress::random_for_testing_only();
        let effects = create_failed_effects_with_accumulator(sponsor, 750);
        let provider = MockObjectProvider::new();

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(sponsor));
        assert_eq!(result[0].amount, -750);
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_different_addresses() {
        let gas_owner = SuiAddress::random_for_testing_only();
        let acc_address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(gas_owner, 1000, acc_address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        let gas_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(gas_owner))
            .unwrap();
        let acc_change = result
            .iter()
            .find(|c| c.owner == Owner::AddressOwner(acc_address))
            .unwrap();
        assert_eq!(gas_change.amount, -1000);
        assert_eq!(acc_change.amount, -200);
    }

    #[tokio::test]
    async fn test_failed_txn_coin_gas_and_accumulator_same_address() {
        let address = SuiAddress::random_for_testing_only();
        let (effects, provider) =
            create_failed_effects_with_gas_coin_and_accumulator(address, 1000, address, 200);

        let result = get_balance_changes_from_effect(&provider, &effects, vec![], None)
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].owner, Owner::AddressOwner(address));
        assert_eq!(result[0].coin_type, GAS::type_tag());
        assert_eq!(result[0].amount, -1200);
    }
}