1use std::{
5 collections::BTreeSet,
6 fmt::Debug,
7 sync::{
8 Arc,
9 atomic::{AtomicU32, Ordering},
10 },
11};
12
13use async_trait::async_trait;
14use consensus_types::block::{BlockRef, Round};
15use mysten_metrics::{
16 monitored_mpsc::{Receiver, Sender, WeakSender, channel},
17 monitored_scope, spawn_logged_monitored_task,
18};
19use parking_lot::RwLock;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::warn;
23
24use crate::{
25 BlockAPI as _,
26 block::VerifiedBlock,
27 commit::CertifiedCommits,
28 context::Context,
29 core::Core,
30 core_thread::CoreError::Shutdown,
31 dag_state::DagState,
32 error::{ConsensusError, ConsensusResult},
33};
34
/// Capacity of the bounded command channel created in
/// `ChannelCoreThreadDispatcher::start` and consumed by the core thread.
const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 2000;
36
37enum CoreThreadCommand {
38 AddBlocks(Vec<VerifiedBlock>, oneshot::Sender<BTreeSet<BlockRef>>),
40 CheckBlockRefs(Vec<BlockRef>, oneshot::Sender<BTreeSet<BlockRef>>),
42 AddCertifiedCommits(CertifiedCommits, oneshot::Sender<BTreeSet<BlockRef>>),
46 NewBlock(Round, oneshot::Sender<()>, bool),
50 GetMissing(oneshot::Sender<BTreeSet<BlockRef>>),
52}
53
54#[derive(Error, Debug)]
55pub enum CoreError {
56 #[error("Core thread shutdown: {0}")]
57 Shutdown(String),
58}
59
60#[async_trait]
63pub trait CoreThreadDispatcher: Sync + Send + 'static {
64 async fn add_blocks(&self, blocks: Vec<VerifiedBlock>)
65 -> Result<BTreeSet<BlockRef>, CoreError>;
66
67 async fn check_block_refs(
68 &self,
69 block_refs: Vec<BlockRef>,
70 ) -> Result<BTreeSet<BlockRef>, CoreError>;
71
72 async fn add_certified_commits(
73 &self,
74 commits: CertifiedCommits,
75 ) -> Result<BTreeSet<BlockRef>, CoreError>;
76
77 async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError>;
78
79 async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError>;
80
81 fn set_propagation_delay(&self, delay: Round) -> Result<(), CoreError>;
84
85 fn set_last_known_proposed_round(&self, round: Round) -> Result<(), CoreError>;
86
87 fn highest_received_rounds(&self) -> Vec<Round>;
89}
90
91pub(crate) struct CoreThreadHandle {
92 sender: Sender<CoreThreadCommand>,
93 join_handle: tokio::task::JoinHandle<()>,
94}
95
96impl CoreThreadHandle {
97 pub async fn stop(self) {
98 drop(self.sender);
100 self.join_handle.await.ok();
101 }
102}
103
104struct CoreThread {
105 core: Core,
106 receiver: Receiver<CoreThreadCommand>,
107 rx_propagation_delay: watch::Receiver<Round>,
108 rx_last_known_proposed_round: watch::Receiver<Round>,
109 context: Arc<Context>,
110}
111
112impl CoreThread {
113 pub async fn run(mut self) -> ConsensusResult<()> {
114 tracing::debug!("Started core thread");
115
116 loop {
117 tokio::select! {
118 command = self.receiver.recv() => {
119 let Some(command) = command else {
120 break;
121 };
122 self.context.metrics.node_metrics.core_lock_dequeued.inc();
123 match command {
124 CoreThreadCommand::AddBlocks(blocks, sender) => {
125 let _scope = monitored_scope("CoreThread::loop::add_blocks");
126 let missing_block_refs = self.core.add_blocks(blocks)?;
127 sender.send(missing_block_refs).ok();
128 }
129 CoreThreadCommand::CheckBlockRefs(block_refs, sender) => {
130 let _scope = monitored_scope("CoreThread::loop::check_block_refs");
131 let missing_block_refs = self.core.check_block_refs(block_refs)?;
132 sender.send(missing_block_refs).ok();
133 }
134 CoreThreadCommand::AddCertifiedCommits(commits, sender) => {
135 let _scope = monitored_scope("CoreThread::loop::add_certified_commits");
136 let missing_block_refs = self.core.add_certified_commits(commits)?;
137 sender.send(missing_block_refs).ok();
138 }
139 CoreThreadCommand::NewBlock(round, sender, force) => {
140 let _scope = monitored_scope("CoreThread::loop::new_block");
141 self.core.new_block(round, force)?;
142 sender.send(()).ok();
143 }
144 CoreThreadCommand::GetMissing(sender) => {
145 let _scope = monitored_scope("CoreThread::loop::get_missing");
146 sender.send(self.core.get_missing_blocks()).ok();
147 }
148 }
149 }
150 _ = self.rx_last_known_proposed_round.changed() => {
151 let _scope = monitored_scope("CoreThread::loop::set_last_known_proposed_round");
152 let round = *self.rx_last_known_proposed_round.borrow();
153 self.core.set_last_known_proposed_round(round);
154 self.core.new_block(round + 1, true)?;
155 }
156 _ = self.rx_propagation_delay.changed() => {
157 let _scope = monitored_scope("CoreThread::loop::set_propagation_delay");
158 let should_propose_before = self.core.should_propose();
159 let propagation_delay = *self.rx_propagation_delay.borrow();
160 self.core.set_propagation_delay(
161 propagation_delay
162 );
163 if !should_propose_before && self.core.should_propose() {
164 self.core.new_block(Round::MAX, true)?;
167 }
168 }
169 }
170 }
171
172 Ok(())
173 }
174}
175
176#[derive(Clone)]
177pub(crate) struct ChannelCoreThreadDispatcher {
178 context: Arc<Context>,
179 sender: WeakSender<CoreThreadCommand>,
180 tx_propagation_delay: Arc<watch::Sender<Round>>,
181 tx_last_known_proposed_round: Arc<watch::Sender<Round>>,
182 highest_received_rounds: Arc<Vec<AtomicU32>>,
183}
184
185impl ChannelCoreThreadDispatcher {
186 pub(crate) fn start(
187 context: Arc<Context>,
188 dag_state: &RwLock<DagState>,
189 core: Core,
190 ) -> (Self, CoreThreadHandle) {
191 let highest_received_rounds = {
193 let dag_state = dag_state.read();
194
195 context
196 .committee
197 .authorities()
198 .map(|(index, _)| {
199 AtomicU32::new(dag_state.get_last_block_for_authority(index).round())
200 })
201 .collect()
202 };
203
204 let (sender, receiver) =
205 channel("consensus_core_commands", CORE_THREAD_COMMANDS_CHANNEL_SIZE);
206 let (tx_propagation_delay, mut rx_propagation_delay) = watch::channel(0);
207 let (tx_last_known_proposed_round, mut rx_last_known_proposed_round) = watch::channel(0);
208 rx_propagation_delay.mark_unchanged();
209 rx_last_known_proposed_round.mark_unchanged();
210 let core_thread = CoreThread {
211 core,
212 receiver,
213 rx_propagation_delay,
214 rx_last_known_proposed_round,
215 context: context.clone(),
216 };
217
218 let join_handle = spawn_logged_monitored_task!(
219 async move {
220 if let Err(err) = core_thread.run().await
221 && !matches!(err, ConsensusError::Shutdown)
222 {
223 panic!("Fatal error occurred: {err}");
224 }
225 },
226 "ConsensusCoreThread"
227 );
228
229 let dispatcher = ChannelCoreThreadDispatcher {
232 context,
233 sender: sender.downgrade(),
234 tx_propagation_delay: Arc::new(tx_propagation_delay),
235 tx_last_known_proposed_round: Arc::new(tx_last_known_proposed_round),
236 highest_received_rounds: Arc::new(highest_received_rounds),
237 };
238 let handle = CoreThreadHandle {
239 join_handle,
240 sender,
241 };
242 (dispatcher, handle)
243 }
244
245 async fn send(&self, command: CoreThreadCommand) {
246 self.context.metrics.node_metrics.core_lock_enqueued.inc();
247 if let Some(sender) = self.sender.upgrade()
248 && let Err(err) = sender.send(command).await
249 {
250 warn!(
251 "Couldn't send command to core thread, probably is shutting down: {}",
252 err
253 );
254 }
255 }
256}
257
258#[async_trait]
259impl CoreThreadDispatcher for ChannelCoreThreadDispatcher {
260 async fn add_blocks(
261 &self,
262 blocks: Vec<VerifiedBlock>,
263 ) -> Result<BTreeSet<BlockRef>, CoreError> {
264 for block in &blocks {
265 self.highest_received_rounds[block.author()].fetch_max(block.round(), Ordering::AcqRel);
266 }
267 let (sender, receiver) = oneshot::channel();
268 self.send(CoreThreadCommand::AddBlocks(blocks.clone(), sender))
269 .await;
270 let missing_block_refs = receiver.await.map_err(|e| Shutdown(e.to_string()))?;
271
272 Ok(missing_block_refs)
273 }
274
275 async fn check_block_refs(
276 &self,
277 block_refs: Vec<BlockRef>,
278 ) -> Result<BTreeSet<BlockRef>, CoreError> {
279 let (sender, receiver) = oneshot::channel();
280 self.send(CoreThreadCommand::CheckBlockRefs(
281 block_refs.clone(),
282 sender,
283 ))
284 .await;
285 let missing_block_refs = receiver.await.map_err(|e| Shutdown(e.to_string()))?;
286
287 Ok(missing_block_refs)
288 }
289
290 async fn add_certified_commits(
291 &self,
292 commits: CertifiedCommits,
293 ) -> Result<BTreeSet<BlockRef>, CoreError> {
294 for commit in commits.commits() {
295 for block in commit.blocks() {
296 self.highest_received_rounds[block.author()]
297 .fetch_max(block.round(), Ordering::AcqRel);
298 }
299 }
300 let (sender, receiver) = oneshot::channel();
301 self.send(CoreThreadCommand::AddCertifiedCommits(commits, sender))
302 .await;
303 let missing_block_refs = receiver.await.map_err(|e| Shutdown(e.to_string()))?;
304 Ok(missing_block_refs)
305 }
306
307 async fn new_block(&self, round: Round, force: bool) -> Result<(), CoreError> {
308 let (sender, receiver) = oneshot::channel();
309 self.send(CoreThreadCommand::NewBlock(round, sender, force))
310 .await;
311 receiver.await.map_err(|e| Shutdown(e.to_string()))
312 }
313
314 async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError> {
315 let (sender, receiver) = oneshot::channel();
316 self.send(CoreThreadCommand::GetMissing(sender)).await;
317 receiver.await.map_err(|e| Shutdown(e.to_string()))
318 }
319
320 fn set_propagation_delay(&self, propagation_delay: Round) -> Result<(), CoreError> {
321 self.tx_propagation_delay
322 .send(propagation_delay)
323 .map_err(|e| Shutdown(e.to_string()))
324 }
325
326 fn set_last_known_proposed_round(&self, round: Round) -> Result<(), CoreError> {
327 self.tx_last_known_proposed_round
328 .send(round)
329 .map_err(|e| Shutdown(e.to_string()))
330 }
331
332 fn highest_received_rounds(&self) -> Vec<Round> {
333 self.highest_received_rounds
334 .iter()
335 .map(|round| round.load(Ordering::Relaxed))
336 .collect()
337 }
338}
339
/// Test double for [`CoreThreadDispatcher`] that records calls instead of running a core.
#[cfg(test)]
#[derive(Default)]
pub(crate) struct MockCoreThreadDispatcher {
    /// Blocks passed to `add_blocks`, in call order.
    add_blocks: parking_lot::Mutex<Vec<VerifiedBlock>>,
    /// Refs returned (and then cleared) by the next `get_missing_blocks` call.
    missing_blocks: parking_lot::Mutex<BTreeSet<BlockRef>>,
    /// Rounds passed to `set_last_known_proposed_round`, in call order.
    last_known_proposed_round: parking_lot::Mutex<Vec<Round>>,
}
348
#[cfg(test)]
impl MockCoreThreadDispatcher {
    /// Takes every block recorded by `add_blocks`, leaving the record empty.
    pub(crate) async fn get_add_blocks(&self) -> Vec<VerifiedBlock> {
        std::mem::take(&mut *self.add_blocks.lock())
    }

    /// Adds `block_refs` to the set returned by the next `get_missing_blocks` call.
    pub(crate) async fn stub_missing_blocks(&self, block_refs: BTreeSet<BlockRef>) {
        self.missing_blocks.lock().extend(block_refs);
    }

    /// Returns a copy of every round passed to `set_last_known_proposed_round`.
    pub(crate) async fn get_last_own_proposed_round(&self) -> Vec<Round> {
        self.last_known_proposed_round.lock().clone()
    }
}
369
#[cfg(test)]
#[async_trait]
impl CoreThreadDispatcher for MockCoreThreadDispatcher {
    /// Records `blocks` and reports nothing missing.
    async fn add_blocks(
        &self,
        blocks: Vec<VerifiedBlock>,
    ) -> Result<BTreeSet<BlockRef>, CoreError> {
        self.add_blocks.lock().extend(blocks);
        Ok(BTreeSet::new())
    }

    /// Always reports nothing missing.
    async fn check_block_refs(
        &self,
        _block_refs: Vec<BlockRef>,
    ) -> Result<BTreeSet<BlockRef>, CoreError> {
        Ok(BTreeSet::new())
    }

    /// Not supported by the mock.
    async fn add_certified_commits(
        &self,
        _commits: CertifiedCommits,
    ) -> Result<BTreeSet<BlockRef>, CoreError> {
        todo!()
    }

    /// No-op.
    async fn new_block(&self, _round: Round, _force: bool) -> Result<(), CoreError> {
        Ok(())
    }

    /// Returns the stubbed missing refs and clears them.
    async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError> {
        Ok(std::mem::take(&mut *self.missing_blocks.lock()))
    }

    /// Not supported by the mock.
    fn set_propagation_delay(&self, _propagation_delay: Round) -> Result<(), CoreError> {
        todo!()
    }

    /// Records `round`.
    fn set_last_known_proposed_round(&self, round: Round) -> Result<(), CoreError> {
        self.last_known_proposed_round.lock().push(round);
        Ok(())
    }

    /// Not supported by the mock.
    fn highest_received_rounds(&self) -> Vec<Round> {
        todo!()
    }
}
421
#[cfg(test)]
mod test {
    use mysten_metrics::monitored_mpsc;
    use parking_lot::RwLock;

    use super::*;
    use crate::{
        CommitConsumerArgs,
        block_manager::BlockManager,
        block_verifier::NoopBlockVerifier,
        commit_observer::CommitObserver,
        context::Context,
        core::CoreSignals,
        dag_state::DagState,
        leader_schedule::LeaderSchedule,
        round_tracker::PeerRoundTracker,
        storage::mem_store::MemStore,
        transaction::{TransactionClient, TransactionConsumer},
        transaction_certifier::TransactionCertifier,
    };

    /// Starts a real core thread, checks that cloned dispatchers can reach it, and that
    /// every clone gets `CoreError::Shutdown` once the handle has stopped the thread.
    #[tokio::test]
    async fn test_core_thread() {
        telemetry_subscribers::init_for_testing();
        let (context, mut key_pairs) = Context::new_for_test(4);
        let context = Arc::new(context);
        let store = Arc::new(MemStore::new());
        let dag_state = Arc::new(RwLock::new(DagState::new(context.clone(), store.clone())));
        let block_manager = BlockManager::new(context.clone(), dag_state.clone());
        let (_transaction_client, tx_receiver) = TransactionClient::new(context.clone());
        let transaction_consumer = TransactionConsumer::new(tx_receiver, context.clone());
        let (blocks_sender, _blocks_receiver) =
            monitored_mpsc::unbounded_channel("consensus_block_output");
        let transaction_certifier = TransactionCertifier::new(
            context.clone(),
            Arc::new(NoopBlockVerifier {}),
            dag_state.clone(),
            blocks_sender,
        );
        let (signals, signal_receivers) = CoreSignals::new(context.clone());
        let _block_receiver = signal_receivers.block_broadcast_receiver();
        let (commit_consumer, _commit_receiver, _transaction_receiver) =
            CommitConsumerArgs::new(0, 0);
        // One schedule shared by the commit observer and `Core`, so both operate on the
        // same instance.
        let leader_schedule = Arc::new(LeaderSchedule::from_store(
            context.clone(),
            dag_state.clone(),
        ));
        let commit_observer = CommitObserver::new(
            context.clone(),
            commit_consumer,
            dag_state.clone(),
            transaction_certifier.clone(),
            leader_schedule.clone(),
        )
        .await;
        let round_tracker = Arc::new(RwLock::new(PeerRoundTracker::new(context.clone())));
        let core = Core::new(
            context.clone(),
            leader_schedule,
            transaction_consumer,
            transaction_certifier,
            block_manager,
            commit_observer,
            signals,
            key_pairs.remove(context.own_index.value()).1,
            dag_state.clone(),
            false,
            round_tracker,
        );

        let (core_dispatcher, handle) =
            ChannelCoreThreadDispatcher::start(context, &dag_state, core);

        let dispatcher_1 = core_dispatcher.clone();
        let dispatcher_2 = core_dispatcher.clone();

        // While the thread runs, every clone reaches it.
        assert!(dispatcher_1.add_blocks(vec![]).await.is_ok());
        assert!(dispatcher_2.add_blocks(vec![]).await.is_ok());

        handle.stop().await;

        // After stopping, the weak senders cannot be upgraded and calls fail.
        assert!(matches!(
            dispatcher_1.add_blocks(vec![]).await,
            Err(CoreError::Shutdown(_))
        ));
        assert!(matches!(
            dispatcher_2.add_blocks(vec![]).await,
            Err(CoreError::Shutdown(_))
        ));
    }
}