1use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, Mutex};
6
7use crate::Task;
8
9#[derive(Debug, Clone)]
10pub enum ProgressSavingPolicy {
11 SaveAfterDuration(SaveAfterDurationPolicy),
12 OutOfOrderSaveAfterDuration(OutOfOrderSaveAfterDurationPolicy),
13}
14
15#[derive(Debug, Clone)]
16pub struct SaveAfterDurationPolicy {
17 duration: tokio::time::Duration,
18 last_save_time: Arc<Mutex<HashMap<String, Option<tokio::time::Instant>>>>,
19}
20
21impl SaveAfterDurationPolicy {
22 pub fn new(duration: tokio::time::Duration) -> Self {
23 Self {
24 duration,
25 last_save_time: Arc::new(Mutex::new(HashMap::new())),
26 }
27 }
28}
29
30#[derive(Debug, Clone)]
31pub struct OutOfOrderSaveAfterDurationPolicy {
32 duration: tokio::time::Duration,
33 last_save_time: Arc<Mutex<HashMap<String, Option<tokio::time::Instant>>>>,
34 seen: Arc<Mutex<HashMap<String, HashSet<u64>>>>,
35 next_to_fill: Arc<Mutex<HashMap<String, Option<u64>>>>,
36}
37
38impl OutOfOrderSaveAfterDurationPolicy {
39 pub fn new(duration: tokio::time::Duration) -> Self {
40 Self {
41 duration,
42 last_save_time: Arc::new(Mutex::new(HashMap::new())),
43 seen: Arc::new(Mutex::new(HashMap::new())),
44 next_to_fill: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47}
48
49impl ProgressSavingPolicy {
50 pub fn cache_progress(&mut self, task: &Task, heights: &[u64]) -> Option<u64> {
52 let task_name = task.task_name.clone();
53 let start_height = task.start_checkpoint;
54 let target_height = task.target_checkpoint;
55 match self {
56 ProgressSavingPolicy::SaveAfterDuration(policy) => {
57 let height = *heights.iter().max().unwrap();
58 let mut last_save_time_guard = policy.last_save_time.lock().unwrap();
59 let last_save_time = last_save_time_guard.entry(task_name).or_insert(None);
60 if height >= target_height {
61 *last_save_time = Some(tokio::time::Instant::now());
62 return Some(height);
63 }
64 if let Some(v) = last_save_time {
65 if v.elapsed() >= policy.duration {
66 *last_save_time = Some(tokio::time::Instant::now());
67 Some(height)
68 } else {
69 None
70 }
71 } else {
72 *last_save_time = Some(tokio::time::Instant::now());
74 None
75 }
76 }
77 ProgressSavingPolicy::OutOfOrderSaveAfterDuration(policy) => {
78 let mut next_to_fill = {
79 let mut next_to_fill_guard = policy.next_to_fill.lock().unwrap();
80 (*next_to_fill_guard
81 .entry(task_name.clone())
82 .or_insert(Some(start_height)))
83 .unwrap()
84 };
85 let old_next_to_fill = next_to_fill;
86 {
87 let mut seen_guard = policy.seen.lock().unwrap();
88 let seen = seen_guard
89 .entry(task_name.clone())
90 .or_insert(HashSet::new());
91 seen.extend(heights.iter().cloned());
92 while seen.remove(&next_to_fill) {
93 next_to_fill += 1;
94 }
95 }
96 if old_next_to_fill != next_to_fill {
98 policy
99 .next_to_fill
100 .lock()
101 .unwrap()
102 .insert(task_name.clone(), Some(next_to_fill));
103 }
104
105 let mut last_save_time_guard = policy.last_save_time.lock().unwrap();
106 let last_save_time = last_save_time_guard
107 .entry(task_name.clone())
108 .or_insert(None);
109
110 if next_to_fill > target_height {
112 *last_save_time = Some(tokio::time::Instant::now());
113 return Some(next_to_fill - 1);
114 }
115 if let Some(v) = last_save_time {
117 if v.elapsed() >= policy.duration && next_to_fill > start_height {
118 *last_save_time = Some(tokio::time::Instant::now());
119 Some(next_to_fill - 1)
120 } else {
121 None
122 }
123 } else {
124 *last_save_time = Some(tokio::time::Instant::now());
126 None
127 }
128 }
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// One scripted call: (task name, sleep for the policy duration first?, heights, expected).
    type Step = (&'static str, bool, &'static [u64], Option<u64>);

    /// Feeds `script` to `policy` in order and checks every value `cache_progress` returns.
    ///
    /// Sleeps for `pause` before each step whose flag is set. Every task starts at checkpoint 0
    /// and targets checkpoint 100.
    async fn replay(mut policy: ProgressSavingPolicy, pause: tokio::time::Duration, script: &[Step]) {
        for &(name, pause_first, heights, expected) in script {
            if pause_first {
                tokio::time::sleep(pause).await;
            }
            let task = new_test_task(name, 0, 100);
            assert_eq!(policy.cache_progress(&task, heights), expected);
        }
    }

    #[tokio::test]
    async fn test_save_after_duration_policy() {
        let duration = tokio::time::Duration::from_millis(100);
        let policy =
            ProgressSavingPolicy::SaveAfterDuration(SaveAfterDurationPolicy::new(duration));
        let script: [Step; 6] = [
            ("task1", false, &[1], None),
            ("task1", true, &[2], Some(2)),
            ("task1", true, &[3], Some(3)),
            ("task2", false, &[4], None),
            ("task2", true, &[5, 6], Some(6)),
            ("task2", true, &[8, 7], Some(8)),
        ];
        replay(policy, duration, &script).await;
    }

    #[tokio::test]
    async fn test_out_of_order_save_after_duration_policy() {
        let duration = tokio::time::Duration::from_millis(100);
        let policy = ProgressSavingPolicy::OutOfOrderSaveAfterDuration(
            OutOfOrderSaveAfterDurationPolicy::new(duration),
        );
        let script: [Step; 12] = [
            ("task1", false, &[0], None),
            ("task1", true, &[1], Some(1)),
            ("task1", false, &[3], None),
            ("task1", true, &[4], Some(1)),
            ("task1", true, &[2], Some(4)),
            ("task2", false, &[0], None),
            ("task2", true, &[1], Some(1)),
            ("task2", true, &[2], Some(2)),
            ("task2", false, &[3], None),
            ("task2", true, &[4], Some(4)),
            ("task2", false, &[6, 7, 8], None),
            ("task2", true, &[5, 9], Some(9)),
        ];
        replay(policy, duration, &script).await;
    }

    /// Builds a non-live `Task` named `name` spanning checkpoints `start`..`target`.
    fn new_test_task(name: &str, start: u64, target: u64) -> Task {
        Task {
            task_name: String::from(name),
            start_checkpoint: start,
            target_checkpoint: target,
            timestamp: 0,
            is_live_task: false,
        }
    }
}