use std::{collections::BTreeMap, sync::Arc};

use tokio::{
    sync::mpsc,
    task::JoinHandle,
    time::{interval, MissedTickBehavior},
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};

use crate::{
    metrics::{CheckpointLagMetricReporter, IndexerMetrics},
    pipeline::{CommitterConfig, IndexedCheckpoint, WatermarkPart},
};

use super::{BatchedRows, Handler};

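/// A checkpoint's worth of rows that are waiting to be sent to the committer, along with the
/// watermark information needed to account for them once they have been written.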
struct PendingCheckpoint<H: Handler> {
    /// Rows from this checkpoint that have not yet been added to a batch.
    values: Vec<H::Value>,

    /// The portion of this checkpoint's watermark accounted for by the rows still pending.
    watermark: WatermarkPart,
}

impl<H: Handler> PendingCheckpoint<H> {
    /// Whether there are any rows left to batch from this checkpoint.
    fn is_empty(&self) -> bool {
        let empty = self.values.is_empty();
        debug_assert!(!empty || self.watermark.batch_rows == 0);
        empty
    }

    /// Moves as many rows as possible from this checkpoint into `batch` without exceeding the
    /// maximum chunk size, and accounts for the moved rows in the batch's watermark.
    fn batch_into(&mut self, batch: &mut BatchedRows<H>) {
        let max_chunk_rows = super::max_chunk_rows::<H>();
        if batch.values.len() + self.values.len() > max_chunk_rows {
            // Only some of this checkpoint's rows will fit: split off the rows that don't fit so
            // they remain pending, and move the rows that do fit into the batch.
            let mut for_batch = self.values.split_off(max_chunk_rows - batch.values.len());

            std::mem::swap(&mut self.values, &mut for_batch);
            batch.watermark.push(self.watermark.take(for_batch.len()));
            batch.values.extend(for_batch);
        } else {
            batch.watermark.push(self.watermark.take(self.values.len()));
            batch.values.extend(std::mem::take(&mut self.values));
        }
    }
}

impl<H: Handler> From<IndexedCheckpoint<H>> for PendingCheckpoint<H> {
    fn from(indexed: IndexedCheckpoint<H>) -> Self {
        Self {
            watermark: WatermarkPart {
                watermark: indexed.watermark,
                batch_rows: indexed.values.len(),
                total_rows: indexed.values.len(),
            },
            values: indexed.values,
        }
    }
}

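/// The collector task is responsible for gathering rows into batches to be written to the store.
///
/// Processed checkpoints arrive over `rx` (not necessarily in checkpoint order) and are buffered
/// in a `pending` map keyed by checkpoint sequence number. On every tick of its polling interval,
/// the task drains rows from the lowest pending checkpoints into a batch of at most
/// `max_chunk_rows` rows and sends it over `tx` to the committer.
///
/// Batches are also gathered eagerly once at least `H::MIN_EAGER_ROWS` rows are pending, and the
/// task stops accepting new checkpoints while `H::MAX_PENDING_ROWS` or more rows are pending,
/// which applies backpressure to the processor. The task shuts down when `cancel` is triggered,
/// when the committer's channel closes, or when its own input channel closes and all pending
/// rows have been drained.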
pub(super) fn collector<H: Handler + 'static>(
    config: CommitterConfig,
    mut rx: mpsc::Receiver<IndexedCheckpoint<H>>,
    tx: mpsc::Sender<BatchedRows<H>>,
    metrics: Arc<IndexerMetrics>,
    cancel: CancellationToken,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        // Timer that drives batch gathering. If a tick is missed (because gathering or sending
        // took too long), subsequent ticks are delayed rather than fired in a burst to catch up.
        let mut poll = interval(config.collect_interval());
        poll.set_missed_tick_behavior(MissedTickBehavior::Delay);

        let checkpoint_lag_reporter = CheckpointLagMetricReporter::new_for_pipeline::<H>(
            &metrics.collected_checkpoint_timestamp_lag,
            &metrics.latest_collected_checkpoint_timestamp_lag_ms,
            &metrics.latest_collected_checkpoint,
        );

        // Checkpoints waiting to be batched, keyed by checkpoint sequence number, along with a
        // running count of the rows they contain.
        let mut pending: BTreeMap<u64, PendingCheckpoint<H>> = BTreeMap::new();
        let mut pending_rows = 0;

        info!(pipeline = H::NAME, "Starting collector");

        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    info!(pipeline = H::NAME, "Shutdown received, stopping collector");
                    break;
                }

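                // On every poll, gather rows from the oldest pending checkpoints into a batch
                // and send it to the committer.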
                _ = poll.tick() => {
                    let guard = metrics
                        .collector_gather_latency
                        .with_label_values(&[H::NAME])
                        .start_timer();

                    let mut batch = BatchedRows::new();
                    while !batch.is_full() {
                        let Some(mut entry) = pending.first_entry() else {
                            break;
                        };

                        let indexed = entry.get_mut();
                        indexed.batch_into(&mut batch);
                        if indexed.is_empty() {
                            checkpoint_lag_reporter.report_lag(
                                indexed.watermark.checkpoint(),
                                indexed.watermark.timestamp_ms(),
                            );
                            entry.remove();
                        }
                    }

                    pending_rows -= batch.len();
                    let elapsed = guard.stop_and_record();
                    debug!(
                        pipeline = H::NAME,
                        elapsed_ms = elapsed * 1000.0,
                        rows = batch.len(),
                        pending_rows = pending_rows,
                        "Gathered batch",
                    );

                    metrics
                        .total_collector_batches_created
                        .with_label_values(&[H::NAME])
                        .inc();

                    metrics
                        .collector_batch_size
                        .with_label_values(&[H::NAME])
                        .observe(batch.len() as f64);

                    if tx.send(batch).await.is_err() {
                        info!(pipeline = H::NAME, "Committer closed channel, stopping collector");
                        break;
                    }

                    // If rows are still pending, gather another batch as soon as possible.
                    // Otherwise, check whether the upstream processor is done and it is time to
                    // wind down.
                    if pending_rows > 0 {
                        poll.reset_immediately();
                    } else if rx.is_closed() && rx.is_empty() {
                        info!(
                            pipeline = H::NAME,
                            "Processor closed channel, pending rows empty, stopping collector",
                        );
                        break;
                    }
                }

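                // Accept a new checkpoint from the processor, but only while the pending buffer
                // is below the handler's limit -- this applies backpressure upstream.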
                Some(indexed) = rx.recv(), if pending_rows < H::MAX_PENDING_ROWS => {
                    metrics
                        .total_collector_rows_received
                        .with_label_values(&[H::NAME])
                        .inc_by(indexed.len() as u64);
                    metrics
                        .total_collector_checkpoints_received
                        .with_label_values(&[H::NAME])
                        .inc();

                    pending_rows += indexed.len();
                    pending.insert(indexed.checkpoint(), indexed.into());

                    if pending_rows >= H::MIN_EAGER_ROWS {
                        poll.reset_immediately()
                    }
                }
            }
        }
    })
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use sui_pg_db::{Connection, Db};
    use tokio::sync::mpsc;

    use crate::{
        metrics::tests::test_metrics,
        pipeline::{concurrent::max_chunk_rows, Processor},
        types::full_checkpoint_content::CheckpointData,
        FieldCount,
    };

    use super::*;

    #[derive(Clone)]
    struct Entry;

    impl FieldCount for Entry {
        // A non-trivial field count, so that `max_chunk_rows::<TestHandler>()` yields a
        // reasonable batch size for the tests below.
        const FIELD_COUNT: usize = 32;
    }

    struct TestHandler;

    #[async_trait]
    impl Processor for TestHandler {
        type Value = Entry;
        const NAME: &'static str = "test_handler";
        const FANOUT: usize = 1;

        async fn process(
            &self,
            _checkpoint: &Arc<CheckpointData>,
        ) -> anyhow::Result<Vec<Self::Value>> {
            Ok(vec![])
        }
    }

    #[async_trait]
    impl Handler for TestHandler {
        type Store = Db;

        const MIN_EAGER_ROWS: usize = 10;
        const MAX_PENDING_ROWS: usize = 10000;

        // The collector never calls `commit` directly, but the trait requires an implementation.
        async fn commit<'a>(
            _values: &[Self::Value],
            _conn: &mut Connection<'a>,
        ) -> anyhow::Result<usize> {
            tokio::time::sleep(Duration::from_millis(1000)).await;
            Ok(0)
        }
    }

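    /// Asserts that no batch arrives on `rx` within `duration`.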
    async fn expect_timeout(rx: &mut mpsc::Receiver<BatchedRows<TestHandler>>, duration: Duration) {
        match tokio::time::timeout(duration, rx.recv()).await {
            Err(_) => (),
            Ok(_) => panic!("Expected timeout but received data instead"),
        }
    }

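    /// Receives the next batch from `rx`, panicking if the channel closes or `timeout` elapses
    /// first.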
    async fn recv_with_timeout(
        rx: &mut mpsc::Receiver<BatchedRows<TestHandler>>,
        timeout: Duration,
    ) -> BatchedRows<TestHandler> {
        match tokio::time::timeout(timeout, rx.recv()).await {
            Ok(Some(batch)) => batch,
            Ok(None) => panic!("Collector channel was closed unexpectedly"),
            Err(_) => panic!("Test timed out waiting for batch from collector"),
        }
    }

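    // Rows from multiple checkpoints are gathered into batches of at most `max_chunk_rows`,
    // oldest checkpoint first, with any overflow spilling into the next batch.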
    #[tokio::test]
    async fn test_collector_batches_data() {
        let (processor_tx, processor_rx) = mpsc::channel(10);
        let (collector_tx, mut collector_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();

        let _collector = collector::<TestHandler>(
            CommitterConfig::default(),
            processor_rx,
            collector_tx,
            test_metrics(),
            cancel.clone(),
        );

        let max_chunk_rows = max_chunk_rows::<TestHandler>();
        let part1_length = max_chunk_rows / 2;
        let part2_length = max_chunk_rows - part1_length - 1;

        // Three checkpoints totalling one more row than fits in a single batch.
        let test_data = vec![
            IndexedCheckpoint::new(0, 1, 10, 1000, vec![Entry; part1_length]),
            IndexedCheckpoint::new(0, 2, 20, 2000, vec![Entry; part2_length]),
            IndexedCheckpoint::new(0, 3, 30, 3000, vec![Entry, Entry]),
        ];

        for data in test_data {
            processor_tx.send(data).await.unwrap();
        }

        let batch1 = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(batch1.len(), max_chunk_rows);

        let batch2 = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(batch2.len(), 1);

        let batch3 = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(batch3.len(), 0);

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_collector_shutdown() {
        let (processor_tx, processor_rx) = mpsc::channel(10);
        let (collector_tx, mut collector_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();

        let collector = collector::<TestHandler>(
            CommitterConfig::default(),
            processor_rx,
            collector_tx,
            test_metrics(),
            cancel.clone(),
        );

        processor_tx
            .send(IndexedCheckpoint::new(0, 1, 10, 1000, vec![Entry, Entry]))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(200)).await;

        let batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(batch.len(), 2);

        // Dropping the sender closes the processor channel, which should cause the collector to
        // stop once all pending rows have been drained.
        drop(processor_tx);

        let _ = tokio::time::timeout(Duration::from_millis(500), collector)
            .await
            .expect("collector did not shutdown");

        cancel.cancel();
    }

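    // Once the collector has at least MAX_PENDING_ROWS rows buffered, it stops pulling new
    // checkpoints from the processor, so the processor's channel eventually fills up.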
    #[tokio::test]
    async fn test_collector_respects_max_pending() {
        let processor_channel_size = 5;
        let collector_channel_size = 2;
        let (processor_tx, processor_rx) = mpsc::channel(processor_channel_size);
        let (collector_tx, _collector_rx) = mpsc::channel(collector_channel_size);

        let metrics = test_metrics();
        let cancel = CancellationToken::new();

        let _collector = collector::<TestHandler>(
            CommitterConfig::default(),
            processor_rx,
            collector_tx,
            metrics.clone(),
            cancel.clone(),
        );

        // A single checkpoint with enough rows to keep at least MAX_PENDING_ROWS pending, even
        // after the collector has filled every batch that its (unread) output channel can hold.
        let data = IndexedCheckpoint::new(
            0,
            1,
            10,
            1000,
            vec![
                Entry;
                TestHandler::MAX_PENDING_ROWS
                    + max_chunk_rows::<TestHandler>() * collector_channel_size
            ],
        );
        processor_tx.send(data).await.unwrap();

        tokio::time::sleep(Duration::from_millis(200)).await;

        // The collector should no longer be receiving, so these sends only fill up the processor
        // channel.
        for _ in 0..processor_channel_size {
            let more_data = IndexedCheckpoint::new(0, 2, 11, 1000, vec![Entry]);
            processor_tx.send(more_data).await.unwrap();
        }

        let even_more_data = IndexedCheckpoint::new(0, 3, 12, 1000, vec![Entry]);

        // With the channel full, a further send should fail immediately.
        let send_result = processor_tx.try_send(even_more_data);
        assert!(matches!(
            send_result,
            Err(mpsc::error::TrySendError::Full(_))
        ));

        cancel.cancel();
    }

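    // Rows below MIN_EAGER_ROWS accumulate across checkpoints; as soon as the threshold is
    // reached, a batch is gathered without waiting for the (long) poll interval.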
    #[tokio::test]
    async fn test_collector_accumulates_across_checkpoints_until_eager_threshold() {
        let (processor_tx, processor_rx) = mpsc::channel(10);
        let (collector_tx, mut collector_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();

        // Use a long collection interval so that, after the initial tick, batches are only
        // produced by eager polling.
        let config = CommitterConfig {
            collect_interval_ms: 60_000,
            ..CommitterConfig::default()
        };
        let _collector = collector::<TestHandler>(
            config,
            processor_rx,
            collector_tx,
            test_metrics(),
            cancel.clone(),
        );

        let start_time = std::time::Instant::now();

        // The interval's first tick fires immediately, producing an empty batch.
        let initial_batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(initial_batch.len(), 0);

        // One row short of the eager threshold: no batch should be produced yet.
        let below_threshold =
            IndexedCheckpoint::new(0, 1, 10, 1000, vec![Entry; TestHandler::MIN_EAGER_ROWS - 1]);
        processor_tx.send(below_threshold).await.unwrap();

        expect_timeout(&mut collector_rx, Duration::from_secs(1)).await;

        // One more row brings the pending total up to MIN_EAGER_ROWS, triggering an eager batch.
        let threshold_trigger = IndexedCheckpoint::new(0, 2, 20, 2000, vec![Entry; 1]);
        processor_tx.send(threshold_trigger).await.unwrap();

        let eager_batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(eager_batch.len(), TestHandler::MIN_EAGER_ROWS);

        // The batch must have arrived well before the 60s poll interval could have fired.
        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(10));

        cancel.cancel();
    }

    #[tokio::test]
    async fn test_immediate_batch_on_min_eager_rows() {
        let (processor_tx, processor_rx) = mpsc::channel(10);
        let (collector_tx, mut collector_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();

        let config = CommitterConfig {
            collect_interval_ms: 60_000,
            ..CommitterConfig::default()
        };
        let _collector = collector::<TestHandler>(
            config,
            processor_rx,
            collector_tx,
            test_metrics(),
            cancel.clone(),
        );

        let initial_batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(initial_batch.len(), 0);
        expect_timeout(&mut collector_rx, Duration::from_secs(1)).await;

        let start_time = std::time::Instant::now();

        let exact_threshold =
            IndexedCheckpoint::new(0, 1, 10, 1000, vec![Entry; TestHandler::MIN_EAGER_ROWS]);
        processor_tx.send(exact_threshold).await.unwrap();

        let batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(batch.len(), TestHandler::MIN_EAGER_ROWS);

        let elapsed = start_time.elapsed();
        assert!(elapsed < Duration::from_secs(10));

        cancel.cancel();
    }

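    // Below the eager threshold, rows are only flushed when the poll interval fires.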
    #[tokio::test]
    async fn test_collector_waits_for_timer_when_below_eager_threshold() {
        let (processor_tx, processor_rx) = mpsc::channel(10);
        let (collector_tx, mut collector_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();

        let config = CommitterConfig {
            collect_interval_ms: 3000,
            ..CommitterConfig::default()
        };
        let _collector = collector::<TestHandler>(
            config,
            processor_rx,
            collector_tx,
            test_metrics(),
            cancel.clone(),
        );

        let initial_batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(1)).await;
        assert_eq!(initial_batch.len(), 0);

        let below_threshold =
            IndexedCheckpoint::new(0, 1, 10, 1000, vec![Entry; TestHandler::MIN_EAGER_ROWS - 1]);
        processor_tx.send(below_threshold).await.unwrap();

        expect_timeout(&mut collector_rx, Duration::from_secs(1)).await;

        let timer_batch = recv_with_timeout(&mut collector_rx, Duration::from_secs(4)).await;
        assert_eq!(timer_batch.len(), TestHandler::MIN_EAGER_ROWS - 1);

        cancel.cancel();
    }
}