1use std::collections::{BTreeMap, HashMap, HashSet};
5use std::ops::Neg;
6
7use async_trait::async_trait;
8use sui_types::balance_change::derive_balance_changes;
9use tokio::sync::RwLock;
10
11use sui_json_rpc_types::BalanceChange;
12use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
13use sui_types::digests::ObjectDigest;
14use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
15use sui_types::execution_status::ExecutionStatus;
16use sui_types::gas_coin::GAS;
17use sui_types::object::Object;
18use sui_types::storage::WriteKind;
19use sui_types::transaction::InputObjectKind;
20use tracing::instrument;
21
22#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
23pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
24 object_provider: &P,
25 effects: &TransactionEffects,
26 input_objs: Vec<InputObjectKind>,
27 mocked_coin: Option<ObjectID>,
28) -> Result<Vec<BalanceChange>, E> {
29 let (_, gas_owner) = effects.gas_object();
30
31 if effects.status() != &ExecutionStatus::Success {
33 return Ok(vec![BalanceChange {
34 owner: gas_owner,
35 coin_type: GAS::type_tag(),
36 amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
37 }]);
38 }
39
40 let all_mutated = effects
41 .all_changed_objects()
42 .into_iter()
43 .filter_map(|((id, version, digest), _, _)| {
44 if matches!(mocked_coin, Some(coin) if id == coin) {
45 return None;
46 }
47 Some((id, version, Some(digest)))
48 })
49 .collect::<Vec<_>>();
50
51 let input_objs_to_digest = input_objs
52 .iter()
53 .filter_map(|k| match k {
54 InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
55 InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
56 })
57 .collect::<HashMap<ObjectID, ObjectDigest>>();
58 let unwrapped_then_deleted = effects
59 .unwrapped_then_deleted()
60 .iter()
61 .map(|e| e.0)
62 .collect::<HashSet<_>>();
63
64 let modified_at_version = effects
65 .modified_at_versions()
66 .into_iter()
67 .filter_map(|(id, version)| {
68 if matches!(mocked_coin, Some(coin) if id == coin) {
69 return None;
70 }
71 if unwrapped_then_deleted.contains(&id) {
73 return None;
74 }
75 Some((id, version, input_objs_to_digest.get(&id).cloned()))
76 })
77 .collect::<Vec<_>>();
78 let input_coins = fetch_coins(object_provider, &modified_at_version).await?;
79 let mutated_coins = fetch_coins(object_provider, &all_mutated).await?;
80 Ok(
81 derive_balance_changes(effects, &input_coins, &mutated_coins)
82 .into_iter()
83 .map(|change| BalanceChange {
84 owner: sui_types::object::Owner::AddressOwner(change.address),
85 coin_type: change.coin_type,
86 amount: change.amount,
87 })
88 .collect(),
89 )
90}
91
92#[instrument(skip_all)]
93async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
94 object_provider: &P,
95 objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
96) -> Result<Vec<Object>, E> {
97 let mut coins = vec![];
98 for (id, version, digest_opt) in objects {
99 let o = object_provider.get_object(id, version).await?;
101 if let Some(type_) = o.type_()
102 && type_.is_coin()
103 {
104 if let Some(digest) = digest_opt {
105 assert_eq!(
107 *digest,
108 o.digest(),
109 "Object digest mismatch--got bad data from object_provider?"
110 )
111 }
112 coins.push(o);
113 }
114 }
115 Ok(coins)
116}
117
118#[async_trait]
119pub trait ObjectProvider {
120 type Error;
121 async fn get_object(
122 &self,
123 id: &ObjectID,
124 version: &SequenceNumber,
125 ) -> Result<Object, Self::Error>;
126 async fn find_object_lt_or_eq_version(
127 &self,
128 id: &ObjectID,
129 version: &SequenceNumber,
130 ) -> Result<Option<Object>, Self::Error>;
131}
132
133pub struct ObjectProviderCache<P> {
134 object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
135 last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
136 provider: P,
137}
138
139impl<P> ObjectProviderCache<P> {
140 pub fn new(provider: P) -> Self {
141 Self {
142 object_cache: Default::default(),
143 last_version_cache: Default::default(),
144 provider,
145 }
146 }
147
148 pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
149 let object_cache = self.object_cache.get_mut();
150 let last_version_cache = self.last_version_cache.get_mut();
151
152 for object in objects {
153 let object_id = object.id();
154 let version = object.version();
155
156 let key = (object_id, version);
157 object_cache.insert(key, object.clone());
158
159 match last_version_cache.get_mut(&key) {
160 Some(existing_seq_number) => {
161 if version > *existing_seq_number {
162 *existing_seq_number = version
163 }
164 }
165 None => {
166 last_version_cache.insert(key, version);
167 }
168 }
169 }
170 }
171
172 pub fn new_with_cache(
173 provider: P,
174 written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
175 ) -> Self {
176 let mut object_cache = BTreeMap::new();
177 let mut last_version_cache = BTreeMap::new();
178
179 for (object_id, (object_ref, object, _)) in written_objects {
180 let key = (object_id, object_ref.1);
181 object_cache.insert(key, object.clone());
182
183 match last_version_cache.get_mut(&key) {
184 Some(existing_seq_number) => {
185 if object_ref.1 > *existing_seq_number {
186 *existing_seq_number = object_ref.1
187 }
188 }
189 None => {
190 last_version_cache.insert(key, object_ref.1);
191 }
192 }
193 }
194
195 Self {
196 object_cache: RwLock::new(object_cache),
197 last_version_cache: RwLock::new(last_version_cache),
198 provider,
199 }
200 }
201}
202
203#[async_trait]
204impl<P, E> ObjectProvider for ObjectProviderCache<P>
205where
206 P: ObjectProvider<Error = E> + Sync + Send,
207 E: Sync + Send,
208{
209 type Error = P::Error;
210
211 async fn get_object(
212 &self,
213 id: &ObjectID,
214 version: &SequenceNumber,
215 ) -> Result<Object, Self::Error> {
216 if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
217 return Ok(o.clone());
218 }
219 let o = self.provider.get_object(id, version).await?;
220 self.object_cache
221 .write()
222 .await
223 .insert((*id, *version), o.clone());
224 Ok(o)
225 }
226
227 async fn find_object_lt_or_eq_version(
228 &self,
229 id: &ObjectID,
230 version: &SequenceNumber,
231 ) -> Result<Option<Object>, Self::Error> {
232 if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
233 return Ok(self.get_object(id, version).await.ok());
234 }
235 if let Some(o) = self
236 .provider
237 .find_object_lt_or_eq_version(id, version)
238 .await?
239 {
240 self.object_cache
241 .write()
242 .await
243 .insert((*id, o.version()), o.clone());
244 self.last_version_cache
245 .write()
246 .await
247 .insert((*id, *version), o.version());
248 Ok(Some(o))
249 } else {
250 Ok(None)
251 }
252 }
253}