// sui_indexer_alt_framework/ingestion/regulator.rs
use std::collections::HashMap;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::info;
/// Spawns a task that feeds checkpoint sequence numbers from `checkpoints` into
/// `checkpoint_tx`, throttled by the progress that subscribers report on `ingest_hi_rx`.
///
/// Each message on `ingest_hi_rx` is a `(name, hi)` pair. It records the latest watermark
/// reported by the subscriber called `name`; a newer report for the same name replaces the
/// older one. A checkpoint `cp` is only sent while
/// `cp <= min(hi over all reporting subscribers) + buffer_size`. Until the first report
/// arrives, checkpoints are sent without any bound.
///
/// The task stops (and the returned handle resolves) when:
/// - `checkpoints` is exhausted,
/// - `cancel` is cancelled, or
/// - the receiving side of `checkpoint_tx` has been dropped.
///
/// If every sender of `ingest_hi_rx` is dropped, the bound stays at its last value. From then
/// on only cancellation (or one of the other conditions above) ends a halted regulator.
pub(super) fn regulator<I>(
    checkpoints: I,
    buffer_size: usize,
    mut ingest_hi_rx: mpsc::UnboundedReceiver<(&'static str, u64)>,
    checkpoint_tx: mpsc::Sender<u64>,
    cancel: CancellationToken,
) -> JoinHandle<()>
where
    I: IntoIterator<Item = u64> + Send + Sync + 'static,
    I::IntoIter: Send + Sync + 'static,
{
    tokio::spawn(async move {
        // Inclusive upper bound on checkpoints that may be sent. `None` means no subscriber has
        // reported yet, so sending is unbounded.
        let mut ingest_hi: Option<u64> = None;
        let mut subscribers_hi: HashMap<&'static str, u64> = HashMap::new();
        let mut checkpoints = checkpoints.into_iter().peekable();

        info!("Starting ingestion regulator");

        loop {
            // Peek rather than pop. The checkpoint is only consumed after a successful send, so a
            // send abandoned because another `select!` branch won is retried next iteration.
            let Some(&cp) = checkpoints.peek() else {
                info!("Checkpoints done, stopping regulator");
                break;
            };

            tokio::select! {
                _ = cancel.cancelled() => {
                    info!("Shutdown received, stopping regulator");
                    break;
                }

                // Once all senders are gone, `recv` yields `None`. The pattern then fails and
                // this branch is disabled for the current `select!`, so the bound stays put.
                Some((name, hi)) = ingest_hi_rx.recv() => {
                    subscribers_hi.insert(name, hi);
                    // Saturate so that a very large watermark cannot overflow (a panic in debug
                    // builds, or a wrapped, far-too-small bound in release builds).
                    ingest_hi = subscribers_hi
                        .values()
                        .copied()
                        .min()
                        .map(|hi| hi.saturating_add(buffer_size as u64));
                }

                res = checkpoint_tx.send(cp), if ingest_hi.is_none_or(|hi| cp <= hi) => {
                    if res.is_ok() {
                        checkpoints.next();
                    } else {
                        info!("Checkpoint channel closed, stopping regulator");
                        break;
                    }
                }
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::{error::Elapsed, timeout};

    use super::*;

    /// Waits up to one second for the next message on `rx`.
    ///
    /// Panics if nothing arrives in time. Returns `None` if the channel is closed.
    async fn expect_recv(rx: &mut mpsc::Receiver<u64>) -> Option<u64> {
        timeout(Duration::from_secs(1), rx.recv()).await.unwrap()
    }

    /// Asserts that no message arrives on `rx` within one second.
    ///
    /// Panics if a message (or channel closure) is observed before the timeout.
    async fn expect_timeout(rx: &mut mpsc::Receiver<u64>) -> Elapsed {
        timeout(Duration::from_secs(1), rx.recv())
            .await
            .unwrap_err()
    }

    /// A finite source is drained in order, after which the regulator exits and closes the
    /// checkpoint channel.
    #[tokio::test]
    async fn finite_list_of_checkpoints() {
        let (_, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        let cps = 0..5;
        let h_regulator = regulator(cps, 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..5 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        h_regulator.await.unwrap();

        // The regulator owned the only sender, so the channel must now be closed.
        assert_eq!(None, expect_recv(&mut cp_rx).await);
    }

    /// Dropping the checkpoint receiver stops an otherwise unbounded regulator.
    #[tokio::test]
    async fn shutdown_on_sender_closed() {
        let (_, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        let h_regulator = regulator(0.., 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..5 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        drop(cp_rx);
        h_regulator.await.unwrap();
    }

    /// Cancelling the token stops an otherwise unbounded regulator.
    #[tokio::test]
    async fn shutdown_on_cancel() {
        let (_, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        let h_regulator = regulator(0.., 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..5 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        cancel.cancel();
        h_regulator.await.unwrap();
    }

    /// With no buffer, the regulator sends exactly up to the reported watermark and then halts.
    #[tokio::test]
    async fn halted() {
        let (hi_tx, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        hi_tx.send(("test", 4)).unwrap();

        let h_regulator = regulator(0.., 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..=4 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        expect_timeout(&mut cp_rx).await;
        cancel.cancel();
        h_regulator.await.unwrap();
    }

    /// The buffer size extends the bound past the reported watermark.
    #[tokio::test]
    async fn halted_buffered() {
        let (hi_tx, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        hi_tx.send(("test", 2)).unwrap();

        let h_regulator = regulator(0.., 2, hi_rx, cp_tx, cancel.clone());

        for i in 0..=4 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        expect_timeout(&mut cp_rx).await;
        cancel.cancel();
        h_regulator.await.unwrap();
    }

    /// Raising the watermark resumes a halted regulator.
    #[tokio::test]
    async fn resumption() {
        let (hi_tx, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        hi_tx.send(("test", 2)).unwrap();

        let h_regulator = regulator(0.., 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..=2 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        expect_timeout(&mut cp_rx).await;
        hi_tx.send(("test", 4)).unwrap();

        for i in 3..=4 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        expect_timeout(&mut cp_rx).await;
        cancel.cancel();
        h_regulator.await.unwrap();
    }

    /// The bound follows the slowest subscriber; only advancing the minimum lets more through.
    #[tokio::test]
    async fn multiple_subscribers() {
        let (hi_tx, hi_rx) = mpsc::unbounded_channel();
        let (cp_tx, mut cp_rx) = mpsc::channel(1);
        let cancel = CancellationToken::new();

        hi_tx.send(("a", 2)).unwrap();
        hi_tx.send(("b", 3)).unwrap();

        let cps = 0..10;
        let h_regulator = regulator(cps, 0, hi_rx, cp_tx, cancel.clone());

        for i in 0..=2 {
            assert_eq!(Some(i), expect_recv(&mut cp_rx).await);
        }

        expect_timeout(&mut cp_rx).await;

        // Advancing the faster subscriber alone does not move the bound.
        hi_tx.send(("b", 4)).unwrap();
        expect_timeout(&mut cp_rx).await;

        hi_tx.send(("a", 3)).unwrap();
        assert_eq!(Some(3), expect_recv(&mut cp_rx).await);
        expect_timeout(&mut cp_rx).await;

        hi_tx.send(("a", 4)).unwrap();
        assert_eq!(Some(4), expect_recv(&mut cp_rx).await);

        // Now "b" (at 4) is the slowest, so "a" moving ahead releases nothing.
        hi_tx.send(("a", 5)).unwrap();
        expect_timeout(&mut cp_rx).await;

        cancel.cancel();
        h_regulator.await.unwrap();
    }
}