1use std::collections::{BTreeMap, HashMap, HashSet};
5use std::ops::Neg;
6
7use async_trait::async_trait;
8use move_core_types::language_storage::TypeTag;
9use tokio::sync::RwLock;
10
11use sui_json_rpc_types::BalanceChange;
12use sui_types::base_types::{ObjectID, ObjectRef, SequenceNumber};
13use sui_types::coin::Coin;
14use sui_types::digests::ObjectDigest;
15use sui_types::effects::{TransactionEffects, TransactionEffectsAPI};
16use sui_types::execution_status::ExecutionStatus;
17use sui_types::gas_coin::GAS;
18use sui_types::object::{Object, Owner};
19use sui_types::storage::WriteKind;
20use sui_types::transaction::InputObjectKind;
21use tracing::instrument;
22
23#[instrument(skip_all, fields(transaction_digest = %effects.transaction_digest()))]
24pub async fn get_balance_changes_from_effect<P: ObjectProvider<Error = E>, E>(
25 object_provider: &P,
26 effects: &TransactionEffects,
27 input_objs: Vec<InputObjectKind>,
28 mocked_coin: Option<ObjectID>,
29) -> Result<Vec<BalanceChange>, E> {
30 let (_, gas_owner) = effects.gas_object();
31
32 if effects.status() != &ExecutionStatus::Success {
34 return Ok(vec![BalanceChange {
35 owner: gas_owner,
36 coin_type: GAS::type_tag(),
37 amount: effects.gas_cost_summary().net_gas_usage().neg() as i128,
38 }]);
39 }
40
41 let all_mutated = effects
42 .all_changed_objects()
43 .into_iter()
44 .filter_map(|((id, version, digest), _, _)| {
45 if matches!(mocked_coin, Some(coin) if id == coin) {
46 return None;
47 }
48 Some((id, version, Some(digest)))
49 })
50 .collect::<Vec<_>>();
51
52 let input_objs_to_digest = input_objs
53 .iter()
54 .filter_map(|k| match k {
55 InputObjectKind::ImmOrOwnedMoveObject(o) => Some((o.0, o.2)),
56 InputObjectKind::MovePackage(_) | InputObjectKind::SharedMoveObject { .. } => None,
57 })
58 .collect::<HashMap<ObjectID, ObjectDigest>>();
59 let unwrapped_then_deleted = effects
60 .unwrapped_then_deleted()
61 .iter()
62 .map(|e| e.0)
63 .collect::<HashSet<_>>();
64 get_balance_changes(
65 object_provider,
66 &effects
67 .modified_at_versions()
68 .into_iter()
69 .filter_map(|(id, version)| {
70 if matches!(mocked_coin, Some(coin) if id == coin) {
71 return None;
72 }
73 if unwrapped_then_deleted.contains(&id) {
75 return None;
76 }
77 Some((id, version, input_objs_to_digest.get(&id).cloned()))
78 })
79 .collect::<Vec<_>>(),
80 &all_mutated,
81 )
82 .await
83}
84
85#[instrument(skip_all)]
86pub async fn get_balance_changes<P: ObjectProvider<Error = E>, E>(
87 object_provider: &P,
88 modified_at_version: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
89 all_mutated: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
90) -> Result<Vec<BalanceChange>, E> {
91 let balances = fetch_coins(object_provider, modified_at_version)
93 .await?
94 .into_iter()
95 .fold(
96 BTreeMap::<_, i128>::new(),
97 |mut acc, (owner, type_, amount)| {
98 *acc.entry((owner, type_)).or_default() -= amount as i128;
99 acc
100 },
101 );
102 let balances = fetch_coins(object_provider, all_mutated)
104 .await?
105 .into_iter()
106 .fold(balances, |mut acc, (owner, type_, amount)| {
107 *acc.entry((owner, type_)).or_default() += amount as i128;
108 acc
109 });
110
111 Ok(balances
112 .into_iter()
113 .filter_map(|((owner, coin_type), amount)| {
114 if amount == 0 {
115 return None;
116 }
117 Some(BalanceChange {
118 owner,
119 coin_type,
120 amount,
121 })
122 })
123 .collect())
124}
125
126#[instrument(skip_all)]
127async fn fetch_coins<P: ObjectProvider<Error = E>, E>(
128 object_provider: &P,
129 objects: &[(ObjectID, SequenceNumber, Option<ObjectDigest>)],
130) -> Result<Vec<(Owner, TypeTag, u64)>, E> {
131 let mut all_mutated_coins = vec![];
132 for (id, version, digest_opt) in objects {
133 let o = object_provider.get_object(id, version).await?;
135 if let Some(type_) = o.type_()
136 && type_.is_coin()
137 {
138 if let Some(digest) = digest_opt {
139 assert_eq!(
141 *digest,
142 o.digest(),
143 "Object digest mismatch--got bad data from object_provider?"
144 )
145 }
146 let [coin_type]: [TypeTag; 1] = type_.clone().into_type_params().try_into().unwrap();
147 all_mutated_coins.push((
148 o.owner.clone(),
149 coin_type,
150 Coin::extract_balance_if_coin(&o).unwrap().unwrap().1,
152 ))
153 }
154 }
155 Ok(all_mutated_coins)
156}
157
/// Asynchronous source of object data, consulted when computing balance changes.
#[async_trait]
pub trait ObjectProvider {
    /// Error produced when a lookup fails.
    type Error;
    /// Returns object `id` at `version`.
    async fn get_object(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Object, Self::Error>;
    /// Returns object `id` at a version less than or equal to `version`, or `Ok(None)` if
    /// no such object is found.
    ///
    /// NOTE(review): presumably the *latest* version not exceeding `version`; the exact
    /// choice is up to implementors. Confirm.
    async fn find_object_lt_or_eq_version(
        &self,
        id: &ObjectID,
        version: &SequenceNumber,
    ) -> Result<Option<Object>, Self::Error>;
}
172
/// Wraps another object provider and memoizes its lookups in in-memory maps.
pub struct ObjectProviderCache<P> {
    /// Fetched or seeded objects, keyed by `(id, version)`.
    object_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), Object>>,
    /// Maps a queried `(id, version)` to the version of the object that an
    /// lower-or-equal-version lookup resolved to.
    last_version_cache: RwLock<BTreeMap<(ObjectID, SequenceNumber), SequenceNumber>>,
    /// The underlying provider consulted on cache misses.
    provider: P,
}
178
179impl<P> ObjectProviderCache<P> {
180 pub fn new(provider: P) -> Self {
181 Self {
182 object_cache: Default::default(),
183 last_version_cache: Default::default(),
184 provider,
185 }
186 }
187
188 pub fn insert_objects_into_cache(&mut self, objects: Vec<Object>) {
189 let object_cache = self.object_cache.get_mut();
190 let last_version_cache = self.last_version_cache.get_mut();
191
192 for object in objects {
193 let object_id = object.id();
194 let version = object.version();
195
196 let key = (object_id, version);
197 object_cache.insert(key, object.clone());
198
199 match last_version_cache.get_mut(&key) {
200 Some(existing_seq_number) => {
201 if version > *existing_seq_number {
202 *existing_seq_number = version
203 }
204 }
205 None => {
206 last_version_cache.insert(key, version);
207 }
208 }
209 }
210 }
211
212 pub fn new_with_cache(
213 provider: P,
214 written_objects: BTreeMap<ObjectID, (ObjectRef, Object, WriteKind)>,
215 ) -> Self {
216 let mut object_cache = BTreeMap::new();
217 let mut last_version_cache = BTreeMap::new();
218
219 for (object_id, (object_ref, object, _)) in written_objects {
220 let key = (object_id, object_ref.1);
221 object_cache.insert(key, object.clone());
222
223 match last_version_cache.get_mut(&key) {
224 Some(existing_seq_number) => {
225 if object_ref.1 > *existing_seq_number {
226 *existing_seq_number = object_ref.1
227 }
228 }
229 None => {
230 last_version_cache.insert(key, object_ref.1);
231 }
232 }
233 }
234
235 Self {
236 object_cache: RwLock::new(object_cache),
237 last_version_cache: RwLock::new(last_version_cache),
238 provider,
239 }
240 }
241}
242
243#[async_trait]
244impl<P, E> ObjectProvider for ObjectProviderCache<P>
245where
246 P: ObjectProvider<Error = E> + Sync + Send,
247 E: Sync + Send,
248{
249 type Error = P::Error;
250
251 async fn get_object(
252 &self,
253 id: &ObjectID,
254 version: &SequenceNumber,
255 ) -> Result<Object, Self::Error> {
256 if let Some(o) = self.object_cache.read().await.get(&(*id, *version)) {
257 return Ok(o.clone());
258 }
259 let o = self.provider.get_object(id, version).await?;
260 self.object_cache
261 .write()
262 .await
263 .insert((*id, *version), o.clone());
264 Ok(o)
265 }
266
267 async fn find_object_lt_or_eq_version(
268 &self,
269 id: &ObjectID,
270 version: &SequenceNumber,
271 ) -> Result<Option<Object>, Self::Error> {
272 if let Some(version) = self.last_version_cache.read().await.get(&(*id, *version)) {
273 return Ok(self.get_object(id, version).await.ok());
274 }
275 if let Some(o) = self
276 .provider
277 .find_object_lt_or_eq_version(id, version)
278 .await?
279 {
280 self.object_cache
281 .write()
282 .await
283 .insert((*id, o.version()), o.clone());
284 self.last_version_cache
285 .write()
286 .await
287 .insert((*id, *version), o.version());
288 Ok(Some(o))
289 } else {
290 Ok(None)
291 }
292 }
293}