1use std::{sync::Arc, time::Duration};
20
21use consensus_types::block::Round;
22use futures::stream::{FuturesUnordered, StreamExt as _};
23use mysten_common::sync::notify_once::NotifyOnce;
24use mysten_metrics::monitored_scope;
25use parking_lot::RwLock;
26use tokio::{task::JoinHandle, time::MissedTickBehavior};
27
28use crate::{
29 BlockAPI as _, context::Context, core_thread::CoreThreadDispatcher, dag_state::DagState,
30 network::ValidatorNetworkClient, round_tracker::RoundTracker,
31};
32
33pub(crate) struct RoundProberHandle {
35 prober_task: JoinHandle<()>,
36 shutdown_notify: Arc<NotifyOnce>,
37}
38
39impl RoundProberHandle {
40 pub(crate) async fn stop(self) {
41 let _ = self.shutdown_notify.notify();
42 if let Err(e) = self.prober_task.await
44 && e.is_panic()
45 {
46 std::panic::resume_unwind(e.into_panic());
47 }
48 }
49}
50
51pub(crate) struct RoundProber<C: ValidatorNetworkClient> {
52 context: Arc<Context>,
53 core_thread_dispatcher: Arc<dyn CoreThreadDispatcher>,
54 round_tracker: Arc<RwLock<RoundTracker>>,
55 dag_state: Arc<RwLock<DagState>>,
56 network_client: Arc<C>,
57 shutdown_notify: Arc<NotifyOnce>,
58}
59
60impl<C: ValidatorNetworkClient> RoundProber<C> {
61 pub(crate) fn new(
62 context: Arc<Context>,
63 core_thread_dispatcher: Arc<dyn CoreThreadDispatcher>,
64 round_tracker: Arc<RwLock<RoundTracker>>,
65 dag_state: Arc<RwLock<DagState>>,
66 network_client: Arc<C>,
67 ) -> Self {
68 Self {
69 context,
70 core_thread_dispatcher,
71 round_tracker,
72 dag_state,
73 network_client,
74 shutdown_notify: Arc::new(NotifyOnce::new()),
75 }
76 }
77
78 pub(crate) fn start(self) -> RoundProberHandle {
79 let shutdown_notify = self.shutdown_notify.clone();
80 let loop_shutdown_notify = shutdown_notify.clone();
81 let prober_task = tokio::spawn(async move {
82 let mut interval = tokio::time::interval(Duration::from_millis(
86 self.context.parameters.round_prober_interval_ms,
87 ));
88 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
89 loop {
90 tokio::select! {
91 _ = interval.tick() => {
92 self.probe().await;
93 }
94 _ = loop_shutdown_notify.wait() => {
95 break;
96 }
97 }
98 }
99 });
100 RoundProberHandle {
101 prober_task,
102 shutdown_notify,
103 }
104 }
105
106 pub(crate) async fn probe(&self) -> Round {
109 let _scope = monitored_scope("RoundProber");
110
111 let node_metrics = &self.context.metrics.node_metrics;
112 let request_timeout =
113 Duration::from_millis(self.context.parameters.round_prober_request_timeout_ms);
114 let own_index = self.context.own_index;
115 let mut requests = FuturesUnordered::new();
116
117 for (peer, _) in self.context.committee.authorities() {
118 if peer == own_index {
119 continue;
120 }
121 let network_client = self.network_client.clone();
122 requests.push(async move {
123 let result = tokio::time::timeout(
124 request_timeout,
125 network_client.get_latest_rounds(peer, request_timeout),
126 )
127 .await;
128 (peer, result)
129 });
130 }
131
132 let mut highest_received_rounds =
133 vec![vec![0; self.context.committee.size()]; self.context.committee.size()];
134 let mut highest_accepted_rounds =
135 vec![vec![0; self.context.committee.size()]; self.context.committee.size()];
136
137 let blocks = self
138 .dag_state
139 .read()
140 .get_last_cached_block_per_authority(Round::MAX);
141 let local_highest_accepted_rounds = blocks
142 .into_iter()
143 .map(|(block, _)| block.round())
144 .collect::<Vec<_>>();
145 let last_proposed_round = local_highest_accepted_rounds[own_index];
146
147 highest_received_rounds[own_index] =
150 self.round_tracker.read().local_highest_received_rounds();
151 highest_accepted_rounds[own_index] = local_highest_accepted_rounds;
152 highest_received_rounds[own_index][own_index] = last_proposed_round;
153 highest_accepted_rounds[own_index][own_index] = last_proposed_round;
154
155 loop {
156 tokio::select! {
157 result = requests.next() => {
158 let Some((peer, result)) = result else { break };
159 let peer_name = &self.context.committee.authority(peer).hostname;
160 match result {
161 Ok(Ok((received, accepted))) => {
162 if received.len() == self.context.committee.size()
163 {
164 highest_received_rounds[peer] = received;
165 } else {
166 node_metrics.round_prober_request_errors.with_label_values(&["invalid_received_rounds"]).inc();
167 tracing::warn!("Received invalid number of received rounds from peer {}", peer_name);
168 }
169
170 if accepted.len() == self.context.committee.size() {
171 highest_accepted_rounds[peer] = accepted;
172 } else {
173 node_metrics.round_prober_request_errors.with_label_values(&["invalid_accepted_rounds"]).inc();
174 tracing::warn!("Received invalid number of accepted rounds from peer {}", peer_name);
175 }
176 },
177 Ok(Err(err)) => {
188 node_metrics.round_prober_request_errors.with_label_values(&["failed_fetch"]).inc();
189 tracing::debug!("Failed to get latest rounds from peer {}: {:?}", peer_name, err);
190 },
191 Err(_) => {
192 node_metrics.round_prober_request_errors.with_label_values(&["timeout"]).inc();
193 tracing::debug!("Timeout while getting latest rounds from peer {}", peer_name);
194 },
195 }
196 }
197 _ = self.shutdown_notify.wait() => break,
198 }
199 }
200
201 self.round_tracker
202 .write()
203 .update_from_probe(highest_accepted_rounds, highest_received_rounds);
204 let propagation_delay = self
205 .round_tracker
206 .read()
207 .calculate_propagation_delay(last_proposed_round);
208
209 let _ = self
210 .core_thread_dispatcher
211 .set_propagation_delay(propagation_delay);
212
213 propagation_delay
214 }
215}
216
#[cfg(test)]
mod test {
    use std::{collections::BTreeSet, sync::Arc, time::Duration};

    use async_trait::async_trait;
    use bytes::Bytes;
    use consensus_config::AuthorityIndex;
    use consensus_types::block::{BlockRef, Round};
    use parking_lot::RwLock;

    use crate::{
        TestBlock, VerifiedBlock,
        commit::{CertifiedCommits, CommitRange},
        context::Context,
        core_thread::{CoreError, CoreThreadDispatcher},
        dag_state::DagState,
        error::{ConsensusError, ConsensusResult},
        network::{BlockStream, ValidatorNetworkClient},
        round_prober::RoundProber,
        round_tracker::RoundTracker,
        storage::mem_store::MemStore,
    };

    /// Core-thread stand-in for the prober test. Only `set_propagation_delay`
    /// is implemented, and it accepts any value; every other method panics via
    /// `unimplemented!()`.
    struct FakeThreadDispatcher {}

    impl FakeThreadDispatcher {
        fn new() -> Self {
            Self {}
        }
    }

    #[async_trait]
    impl CoreThreadDispatcher for FakeThreadDispatcher {
        async fn add_blocks(
            &self,
            _blocks: Vec<VerifiedBlock>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn check_block_refs(
            &self,
            _block_refs: Vec<BlockRef>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn add_certified_commits(
            &self,
            _commits: CertifiedCommits,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn new_block(&self, _round: Round, _force: bool) -> Result<(), CoreError> {
            unimplemented!()
        }

        async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        // The only dispatcher method `RoundProber::probe` calls; always succeeds.
        fn set_propagation_delay(&self, _propagation_delay: Round) -> Result<(), CoreError> {
            Ok(())
        }

        fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
            unimplemented!()
        }
    }

    /// Network-client stand-in that answers `get_latest_rounds` from fixed
    /// per-peer tables (outer index = peer). A peer whose received and accepted
    /// vectors are both empty returns `ConsensusError::NetworkRequestTimeout`
    /// instead of a response. Every other RPC panics via `unimplemented!`.
    struct FakeNetworkClient {
        highest_received_rounds: Vec<Vec<Round>>,
        highest_accepted_rounds: Vec<Vec<Round>>,
    }

    impl FakeNetworkClient {
        fn new(
            highest_received_rounds: Vec<Vec<Round>>,
            highest_accepted_rounds: Vec<Vec<Round>>,
        ) -> Self {
            Self {
                highest_received_rounds,
                highest_accepted_rounds,
            }
        }
    }

    #[async_trait]
    impl ValidatorNetworkClient for FakeNetworkClient {
        async fn send_block(
            &self,
            _peer: AuthorityIndex,
            _serialized_block: &VerifiedBlock,
            _timeout: Duration,
        ) -> ConsensusResult<()> {
            unimplemented!("Unimplemented")
        }

        async fn subscribe_blocks(
            &self,
            _peer: AuthorityIndex,
            _last_received: Round,
            _timeout: Duration,
        ) -> ConsensusResult<BlockStream> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_blocks(
            &self,
            _peer: AuthorityIndex,
            _block_refs: Vec<BlockRef>,
            _highest_accepted_rounds: Vec<Round>,
            _breadth_first: bool,
            _timeout: Duration,
        ) -> ConsensusResult<Vec<Bytes>> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_commits(
            &self,
            _peer: AuthorityIndex,
            _commit_range: CommitRange,
            _timeout: Duration,
        ) -> ConsensusResult<(Vec<Bytes>, Vec<Bytes>)> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_latest_blocks(
            &self,
            _peer: AuthorityIndex,
            _authorities: Vec<AuthorityIndex>,
            _timeout: Duration,
        ) -> ConsensusResult<Vec<Bytes>> {
            unimplemented!("Unimplemented")
        }

        // Returns the canned (received, accepted) rows for `peer`. Two empty
        // rows simulate a failed request.
        async fn get_latest_rounds(
            &self,
            peer: AuthorityIndex,
            _timeout: Duration,
        ) -> ConsensusResult<(Vec<Round>, Vec<Round>)> {
            let received_rounds = self.highest_received_rounds[peer].clone();
            let accepted_rounds = self.highest_accepted_rounds[peer].clone();
            if received_rounds.is_empty() && accepted_rounds.is_empty() {
                Err(ConsensusError::NetworkRequestTimeout("test".to_string()))
            } else {
                Ok((received_rounds, accepted_rounds))
            }
        }
    }

    /// Runs a single `probe()` against a 7-authority committee with a mix of
    /// valid, failing and malformed peer responses, and checks the resulting
    /// propagation delay.
    #[tokio::test]
    async fn test_round_prober() {
        telemetry_subscribers::init_for_testing();
        const NUM_AUTHORITIES: usize = 7;
        let context = Arc::new(Context::new_for_test(NUM_AUTHORITIES).0);
        let core_thread_dispatcher = Arc::new(FakeThreadDispatcher::new());
        let store = Arc::new(MemStore::new());
        let dag_state = Arc::new(RwLock::new(DagState::new(context.clone(), store)));
        // Canned peer responses (outer index = peer): received rounds first,
        // accepted rounds second.
        // - Rows 0 and 3 are empty, so those peers answer with an error. Row 0
        //   is presumably this node's own index in `Context::new_for_test` and
        //   is therefore never queried (TODO confirm).
        // - Row 6 has only 6 entries for a 7-authority committee, so `probe`
        //   should reject it as malformed and keep that peer's row at zero.
        let network_client = Arc::new(FakeNetworkClient::new(
            vec![
                vec![],
                vec![109, 121, 131, 0, 151, 161, 171],
                vec![101, 0, 103, 104, 105, 166, 107],
                vec![],
                vec![100, 102, 133, 0, 155, 106, 177],
                vec![105, 115, 103, 0, 125, 126, 127],
                vec![10, 20, 30, 40, 50, 60],
            ],
            vec![
                vec![],
                vec![0, 121, 131, 0, 151, 161, 171],
                vec![1, 0, 103, 104, 105, 166, 107],
                vec![],
                vec![0, 102, 133, 0, 155, 106, 177],
                vec![1, 115, 103, 0, 125, 126, 127],
                vec![1, 20, 30, 40, 50, 60],
            ],
        ));

        // NOTE(review): the second argument presumably seeds this node's local
        // highest received rounds per authority; confirm against RoundTracker::new.
        let round_tracker = Arc::new(RwLock::new(RoundTracker::new(
            context.clone(),
            vec![110, 120, 130, 140, 150, 160, 170],
        )));
        let prober = RoundProber::new(
            context.clone(),
            core_thread_dispatcher.clone(),
            round_tracker.clone(),
            dag_state.clone(),
            network_client.clone(),
        );

        // Accept one block per authority at rounds 110, 120, ..., 170. This
        // determines the last cached block per authority that `probe` reads
        // from DagState for this node's own accepted-rounds row.
        let blocks = (0..NUM_AUTHORITIES)
            .map(|authority| {
                let round = 110 + (authority as u32 * 10);
                VerifiedBlock::new_for_test(TestBlock::new(round, authority as u32).build())
            })
            .collect::<Vec<_>>();

        dag_state.write().accept_blocks(blocks);

        let propagation_delay = prober.probe().await;

        // NOTE(review): the expected value depends on
        // RoundTracker::calculate_propagation_delay, which is not shown in this
        // file; re-derive it there if the fixtures above change.
        assert_eq!(propagation_delay, 10);
    }
}