1use std::{sync::Arc, time::Duration};
20
21use consensus_types::block::Round;
22use futures::stream::{FuturesUnordered, StreamExt as _};
23use mysten_common::sync::notify_once::NotifyOnce;
24use mysten_metrics::monitored_scope;
25use parking_lot::RwLock;
26use tokio::{task::JoinHandle, time::MissedTickBehavior};
27
28use crate::{
29 BlockAPI as _, context::Context, core_thread::CoreThreadDispatcher, dag_state::DagState,
30 network::NetworkClient, round_tracker::PeerRoundTracker,
31};
32
33pub(crate) struct RoundProberHandle {
35 prober_task: JoinHandle<()>,
36 shutdown_notify: Arc<NotifyOnce>,
37}
38
39impl RoundProberHandle {
40 pub(crate) async fn stop(self) {
41 let _ = self.shutdown_notify.notify();
42 if let Err(e) = self.prober_task.await
44 && e.is_panic()
45 {
46 std::panic::resume_unwind(e.into_panic());
47 }
48 }
49}
50
51pub(crate) struct RoundProber<C: NetworkClient> {
52 context: Arc<Context>,
53 core_thread_dispatcher: Arc<dyn CoreThreadDispatcher>,
54 round_tracker: Arc<RwLock<PeerRoundTracker>>,
55 dag_state: Arc<RwLock<DagState>>,
56 network_client: Arc<C>,
57 shutdown_notify: Arc<NotifyOnce>,
58}
59
60impl<C: NetworkClient> RoundProber<C> {
61 pub(crate) fn new(
62 context: Arc<Context>,
63 core_thread_dispatcher: Arc<dyn CoreThreadDispatcher>,
64 round_tracker: Arc<RwLock<PeerRoundTracker>>,
65 dag_state: Arc<RwLock<DagState>>,
66 network_client: Arc<C>,
67 ) -> Self {
68 Self {
69 context,
70 core_thread_dispatcher,
71 round_tracker,
72 dag_state,
73 network_client,
74 shutdown_notify: Arc::new(NotifyOnce::new()),
75 }
76 }
77
78 pub(crate) fn start(self) -> RoundProberHandle {
79 let shutdown_notify = self.shutdown_notify.clone();
80 let loop_shutdown_notify = shutdown_notify.clone();
81 let prober_task = tokio::spawn(async move {
82 let mut interval = tokio::time::interval(Duration::from_millis(
86 self.context.parameters.round_prober_interval_ms,
87 ));
88 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
89 loop {
90 tokio::select! {
91 _ = interval.tick() => {
92 self.probe().await;
93 }
94 _ = loop_shutdown_notify.wait() => {
95 break;
96 }
97 }
98 }
99 });
100 RoundProberHandle {
101 prober_task,
102 shutdown_notify,
103 }
104 }
105
106 pub(crate) async fn probe(&self) -> Round {
109 let _scope = monitored_scope("RoundProber");
110
111 let node_metrics = &self.context.metrics.node_metrics;
112 let request_timeout =
113 Duration::from_millis(self.context.parameters.round_prober_request_timeout_ms);
114 let own_index = self.context.own_index;
115 let mut requests = FuturesUnordered::new();
116
117 for (peer, _) in self.context.committee.authorities() {
118 if peer == own_index {
119 continue;
120 }
121 let network_client = self.network_client.clone();
122 requests.push(async move {
123 let result = tokio::time::timeout(
124 request_timeout,
125 network_client.get_latest_rounds(peer, request_timeout),
126 )
127 .await;
128 (peer, result)
129 });
130 }
131
132 let mut highest_received_rounds =
133 vec![vec![0; self.context.committee.size()]; self.context.committee.size()];
134 let mut highest_accepted_rounds =
135 vec![vec![0; self.context.committee.size()]; self.context.committee.size()];
136
137 let blocks = self
138 .dag_state
139 .read()
140 .get_last_cached_block_per_authority(Round::MAX);
141 let local_highest_accepted_rounds = blocks
142 .into_iter()
143 .map(|(block, _)| block.round())
144 .collect::<Vec<_>>();
145 let last_proposed_round = local_highest_accepted_rounds[own_index];
146
147 highest_received_rounds[own_index] = self.core_thread_dispatcher.highest_received_rounds();
150 highest_accepted_rounds[own_index] = local_highest_accepted_rounds;
151 highest_received_rounds[own_index][own_index] = last_proposed_round;
152 highest_accepted_rounds[own_index][own_index] = last_proposed_round;
153
154 loop {
155 tokio::select! {
156 result = requests.next() => {
157 let Some((peer, result)) = result else { break };
158 let peer_name = &self.context.committee.authority(peer).hostname;
159 match result {
160 Ok(Ok((received, accepted))) => {
161 if received.len() == self.context.committee.size()
162 {
163 highest_received_rounds[peer] = received;
164 } else {
165 node_metrics.round_prober_request_errors.with_label_values(&["invalid_received_rounds"]).inc();
166 tracing::warn!("Received invalid number of received rounds from peer {}", peer_name);
167 }
168
169 if accepted.len() == self.context.committee.size() {
170 highest_accepted_rounds[peer] = accepted;
171 } else {
172 node_metrics.round_prober_request_errors.with_label_values(&["invalid_accepted_rounds"]).inc();
173 tracing::warn!("Received invalid number of accepted rounds from peer {}", peer_name);
174 }
175 },
176 Ok(Err(err)) => {
187 node_metrics.round_prober_request_errors.with_label_values(&["failed_fetch"]).inc();
188 tracing::debug!("Failed to get latest rounds from peer {}: {:?}", peer_name, err);
189 },
190 Err(_) => {
191 node_metrics.round_prober_request_errors.with_label_values(&["timeout"]).inc();
192 tracing::debug!("Timeout while getting latest rounds from peer {}", peer_name);
193 },
194 }
195 }
196 _ = self.shutdown_notify.wait() => break,
197 }
198 }
199
200 self.round_tracker
201 .write()
202 .update_from_probe(highest_accepted_rounds, highest_received_rounds);
203 let propagation_delay = self
204 .round_tracker
205 .read()
206 .calculate_propagation_delay(last_proposed_round);
207
208 let _ = self
209 .core_thread_dispatcher
210 .set_propagation_delay(propagation_delay);
211
212 propagation_delay
213 }
214}
215
#[cfg(test)]
mod test {
    use std::{collections::BTreeSet, sync::Arc, time::Duration};

    use async_trait::async_trait;
    use bytes::Bytes;
    use consensus_config::AuthorityIndex;
    use consensus_types::block::{BlockRef, Round};
    use parking_lot::RwLock;

    use crate::{
        TestBlock, VerifiedBlock,
        commit::{CertifiedCommits, CommitRange},
        context::Context,
        core_thread::{CoreError, CoreThreadDispatcher},
        dag_state::DagState,
        error::{ConsensusError, ConsensusResult},
        network::{BlockStream, NetworkClient},
        round_prober::RoundProber,
        round_tracker::PeerRoundTracker,
        storage::mem_store::MemStore,
    };

    /// Core-thread stand-in. It only serves `highest_received_rounds` and
    /// accepts `set_propagation_delay`; every other method is unimplemented.
    struct FakeThreadDispatcher {
        highest_received_rounds: Vec<Round>,
    }

    impl FakeThreadDispatcher {
        fn new(highest_received_rounds: Vec<Round>) -> Self {
            Self {
                highest_received_rounds,
            }
        }
    }

    #[async_trait]
    impl CoreThreadDispatcher for FakeThreadDispatcher {
        async fn add_blocks(
            &self,
            _blocks: Vec<VerifiedBlock>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn check_block_refs(
            &self,
            _block_refs: Vec<BlockRef>,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn add_certified_commits(
            &self,
            _commits: CertifiedCommits,
        ) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        async fn new_block(&self, _round: Round, _force: bool) -> Result<(), CoreError> {
            unimplemented!()
        }

        async fn get_missing_blocks(&self) -> Result<BTreeSet<BlockRef>, CoreError> {
            unimplemented!()
        }

        fn set_propagation_delay(&self, _propagation_delay: Round) -> Result<(), CoreError> {
            Ok(())
        }

        fn set_last_known_proposed_round(&self, _round: Round) -> Result<(), CoreError> {
            unimplemented!()
        }

        fn highest_received_rounds(&self) -> Vec<Round> {
            self.highest_received_rounds.clone()
        }
    }

    /// Network stand-in that serves canned `get_latest_rounds` responses,
    /// indexed by peer. A peer whose received and accepted rows are both empty
    /// answers with an error.
    struct FakeNetworkClient {
        highest_received_rounds: Vec<Vec<Round>>,
        highest_accepted_rounds: Vec<Vec<Round>>,
    }

    impl FakeNetworkClient {
        fn new(
            highest_received_rounds: Vec<Vec<Round>>,
            highest_accepted_rounds: Vec<Vec<Round>>,
        ) -> Self {
            Self {
                highest_received_rounds,
                highest_accepted_rounds,
            }
        }
    }

    #[async_trait]
    impl NetworkClient for FakeNetworkClient {
        async fn send_block(
            &self,
            _peer: AuthorityIndex,
            _serialized_block: &VerifiedBlock,
            _timeout: Duration,
        ) -> ConsensusResult<()> {
            unimplemented!("Unimplemented")
        }

        async fn subscribe_blocks(
            &self,
            _peer: AuthorityIndex,
            _last_received: Round,
            _timeout: Duration,
        ) -> ConsensusResult<BlockStream> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_blocks(
            &self,
            _peer: AuthorityIndex,
            _block_refs: Vec<BlockRef>,
            _highest_accepted_rounds: Vec<Round>,
            _breadth_first: bool,
            _timeout: Duration,
        ) -> ConsensusResult<Vec<Bytes>> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_commits(
            &self,
            _peer: AuthorityIndex,
            _commit_range: CommitRange,
            _timeout: Duration,
        ) -> ConsensusResult<(Vec<Bytes>, Vec<Bytes>)> {
            unimplemented!("Unimplemented")
        }

        async fn fetch_latest_blocks(
            &self,
            _peer: AuthorityIndex,
            _authorities: Vec<AuthorityIndex>,
            _timeout: Duration,
        ) -> ConsensusResult<Vec<Bytes>> {
            unimplemented!("Unimplemented")
        }

        async fn get_latest_rounds(
            &self,
            peer: AuthorityIndex,
            _timeout: Duration,
        ) -> ConsensusResult<(Vec<Round>, Vec<Round>)> {
            let received = &self.highest_received_rounds[peer];
            let accepted = &self.highest_accepted_rounds[peer];
            // Both rows empty simulates an unreachable peer.
            if received.is_empty() && accepted.is_empty() {
                return Err(ConsensusError::NetworkRequestTimeout("test".to_string()));
            }
            Ok((received.clone(), accepted.clone()))
        }
    }

    #[tokio::test]
    async fn test_round_prober() {
        telemetry_subscribers::init_for_testing();
        const NUM_AUTHORITIES: usize = 7;
        let (context, _) = Context::new_for_test(NUM_AUTHORITIES);
        let context = Arc::new(context);

        // Locally observed received rounds, one entry per authority.
        let local_received_rounds = vec![110, 120, 130, 140, 150, 160, 170];
        let core_thread_dispatcher = Arc::new(FakeThreadDispatcher::new(local_received_rounds));

        let store = Arc::new(MemStore::new());
        let dag_state = Arc::new(RwLock::new(DagState::new(context.clone(), store)));

        // Row `i` is what peer `i` reports. Row 0 is expected never to be
        // requested, presumably because this node is index 0 (TODO confirm).
        // Row 3 is empty, so the fake returns an error. Row 6 has only six
        // entries, so the prober rejects it as the wrong length.
        let peer_received_rounds = vec![
            vec![],
            vec![109, 121, 131, 0, 151, 161, 171],
            vec![101, 0, 103, 104, 105, 166, 107],
            vec![],
            vec![100, 102, 133, 0, 155, 106, 177],
            vec![105, 115, 103, 0, 125, 126, 127],
            vec![10, 20, 30, 40, 50, 60],
        ];
        let peer_accepted_rounds = vec![
            vec![],
            vec![0, 121, 131, 0, 151, 161, 171],
            vec![1, 0, 103, 104, 105, 166, 107],
            vec![],
            vec![0, 102, 133, 0, 155, 106, 177],
            vec![1, 115, 103, 0, 125, 126, 127],
            vec![1, 20, 30, 40, 50, 60],
        ];
        let network_client = Arc::new(FakeNetworkClient::new(
            peer_received_rounds,
            peer_accepted_rounds,
        ));

        let round_tracker = Arc::new(RwLock::new(PeerRoundTracker::new(context.clone())));
        let prober = RoundProber::new(
            context.clone(),
            core_thread_dispatcher.clone(),
            round_tracker.clone(),
            dag_state.clone(),
            network_client.clone(),
        );

        // Authority `a` gets one block at round 110 + 10 * a.
        let blocks: Vec<VerifiedBlock> = (0..NUM_AUTHORITIES as u32)
            .map(|authority| {
                let block_round = 110 + authority * 10;
                VerifiedBlock::new_for_test(TestBlock::new(block_round, authority).build())
            })
            .collect();
        dag_state.write().accept_blocks(blocks);

        let propagation_delay = prober.probe().await;

        assert_eq!(propagation_delay, 10);
    }
}