sui_network/randomness/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use self::{auth::AllowedPeersUpdatable, metrics::Metrics};
5use anemo::PeerId;
6use anyhow::Result;
7use fastcrypto::groups::bls12381;
8use fastcrypto_tbls::{
9    dkg_v1,
10    nodes::PartyId,
11    tbls::ThresholdBls,
12    types::{ShareIndex, ThresholdBls12381MinSig},
13};
14use mysten_common::ZipDebugEqIteratorExt;
15use mysten_metrics::spawn_monitored_task;
16use mysten_network::anemo_ext::NetworkExt;
17use serde::{Deserialize, Serialize};
18use std::{
19    collections::{HashMap, HashSet, btree_map::BTreeMap},
20    ops::Bound,
21    sync::Arc,
22    time::{self, Duration},
23};
24use sui_config::p2p::RandomnessConfig;
25use sui_macros::fail_point_if;
26use sui_types::{
27    base_types::AuthorityName,
28    committee::EpochId,
29    crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
30};
31use tokio::sync::{
32    OnceCell, {mpsc, oneshot},
33};
34use tracing::{debug, error, info, instrument, warn};
35
36mod auth;
37mod builder;
// Code emitted by the build step into `OUT_DIR` (`sui.Randomness.rs`), pulled in verbatim.
// The `RandomnessClient`, `Randomness` and `RandomnessServer` items re-exported below live here.
mod generated {
    include!(concat!(env!("OUT_DIR"), "/sui.Randomness.rs"));
}
41mod metrics;
42mod server;
43#[cfg(test)]
44mod tests;
45
46pub use builder::{Builder, UnstartedRandomness};
47pub use generated::{
48    randomness_client::RandomnessClient,
49    randomness_server::{Randomness, RandomnessServer},
50};
51
/// Request carrying signatures for one randomness round of one epoch.
///
/// Holds either a list of serialized partial signatures or, when the sender already has it,
/// the full signature for the round.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SendSignaturesRequest {
    /// Epoch the signatures belong to.
    epoch: EpochId,
    /// Randomness round the signatures belong to.
    round: RandomnessRound,
    /// BCS-serialized `RandomnessPartialSignature` values. We store raw bytes here to enable
    /// defenses against too-large messages.
    /// The protocol requires the signatures to be ordered by share index (as provided by
    /// fastcrypto).
    partial_sigs: Vec<Vec<u8>>,
    /// If peer already has a full signature available for the round, it's provided here in lieu
    /// of partial sigs.
    sig: Option<RandomnessSignature>,
}
64
/// A handle to the Randomness network subsystem.
///
/// This handle can be cloned and shared. Once all copies of a Randomness system's Handle have been
/// dropped, the Randomness system will be gracefully shutdown.
#[derive(Clone, Debug)]
pub struct Handle {
    /// Sending half of the `RandomnessEventLoop` mailbox; every method on the handle enqueues
    /// a `RandomnessMessage` here with a non-blocking `try_send`.
    sender: mpsc::Sender<RandomnessMessage>,
}
73
74impl Handle {
75    /// Transitions the Randomness system to a new epoch. Cancels all partial signature sends for
76    /// prior epochs.
77    pub fn update_epoch(
78        &self,
79        new_epoch: EpochId,
80        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
81        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
82        aggregation_threshold: u16,
83        recovered_last_completed_round: Option<RandomnessRound>, // set to None if not starting up mid-epoch
84    ) {
85        self.sender
86            .try_send(RandomnessMessage::UpdateEpoch(
87                new_epoch,
88                authority_info,
89                dkg_output,
90                aggregation_threshold,
91                recovered_last_completed_round,
92            ))
93            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
94    }
95
96    /// Begins transmitting partial signatures for the given epoch and round until completed.
97    pub fn send_partial_signatures(&self, epoch: EpochId, round: RandomnessRound) {
98        self.sender
99            .try_send(RandomnessMessage::SendPartialSignatures(epoch, round))
100            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
101    }
102
103    /// Records the given round as complete, stopping any partial signature sends.
104    pub fn complete_round(&self, epoch: EpochId, round: RandomnessRound) {
105        self.sender
106            .try_send(RandomnessMessage::CompleteRound(epoch, round))
107            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
108    }
109
110    /// Admin interface handler: generates partial signatures for the given round at the
111    /// current epoch.
112    pub fn admin_get_partial_signatures(
113        &self,
114        round: RandomnessRound,
115        tx: oneshot::Sender<Vec<u8>>,
116    ) {
117        self.sender
118            .try_send(RandomnessMessage::AdminGetPartialSignatures(round, tx))
119            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
120    }
121
122    /// Admin interface handler: injects partial signatures for the given round at the
123    /// current epoch, skipping validity checks.
124    pub fn admin_inject_partial_signatures(
125        &self,
126        authority_name: AuthorityName,
127        round: RandomnessRound,
128        sigs: Vec<RandomnessPartialSignature>,
129        result_channel: oneshot::Sender<Result<()>>,
130    ) {
131        self.sender
132            .try_send(RandomnessMessage::AdminInjectPartialSignatures(
133                authority_name,
134                round,
135                sigs,
136                result_channel,
137            ))
138            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
139    }
140
141    /// Admin interface handler: injects full signature for the given round at the
142    /// current epoch, skipping validity checks.
143    pub fn admin_inject_full_signature(
144        &self,
145        round: RandomnessRound,
146        sig: RandomnessSignature,
147        result_channel: oneshot::Sender<Result<()>>,
148    ) {
149        self.sender
150            .try_send(RandomnessMessage::AdminInjectFullSignature(
151                round,
152                sig,
153                result_channel,
154            ))
155            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
156    }
157
158    // For testing.
159    pub fn new_stub() -> Self {
160        let (sender, mut receiver) = mpsc::channel(100);
161        // Keep receiver open until all senders are closed.
162        tokio::spawn(async move {
163            loop {
164                tokio::select! {
165                    m = receiver.recv() => {
166                        if m.is_none() {
167                            break;
168                        }
169                    },
170                }
171            }
172        });
173        Self { sender }
174    }
175}
176
/// Messages processed by `RandomnessEventLoop::handle_message`, one variant per operation.
#[derive(Debug)]
enum RandomnessMessage {
    /// Switch to a new epoch. Sent by `Handle::update_epoch`.
    UpdateEpoch(
        EpochId,
        HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        u16,                     // aggregation_threshold
        Option<RandomnessRound>, // recovered_highest_completed_round
    ),
    /// Request partial signature sending for a round. Sent by `Handle::send_partial_signatures`.
    SendPartialSignatures(EpochId, RandomnessRound),
    /// Record a round as complete. Sent by `Handle::complete_round`.
    CompleteRound(EpochId, RandomnessRound),
    /// Signatures received from a peer for an epoch and round: serialized partial signatures
    /// plus an optional full signature. When the full signature is present the handler uses
    /// it and ignores the partial signatures.
    ReceiveSignatures(
        PeerId,
        EpochId,
        RandomnessRound,
        Vec<Vec<u8>>,
        Option<RandomnessSignature>,
    ),
    /// Enqueued by the event loop itself after a peer's signatures failed verification.
    MaybeIgnoreByzantinePeer(EpochId, PeerId),
    /// Admin request for the partial signatures of a round; the reply goes to the oneshot sender.
    AdminGetPartialSignatures(RandomnessRound, oneshot::Sender<Vec<u8>>),
    /// Admin request to inject partial signatures for a round; the result goes to the oneshot
    /// sender.
    AdminInjectPartialSignatures(
        AuthorityName,
        RandomnessRound,
        Vec<RandomnessPartialSignature>,
        oneshot::Sender<Result<()>>,
    ),
    /// Admin request to inject a full signature for a round; the result goes to the oneshot
    /// sender.
    AdminInjectFullSignature(
        RandomnessRound,
        RandomnessSignature,
        oneshot::Sender<Result<()>>,
    ),
}
209
/// State owned by the randomness event loop task. All mutation happens while handling
/// `RandomnessMessage`s taken from `mailbox`.
struct RandomnessEventLoop {
    name: AuthorityName,
    // Read for `max_partial_sigs_rounds_ahead` and `max_ignored_peer_weight_factor`.
    config: RandomnessConfig,
    // Incoming messages; the loop ends when this yields `None`.
    mailbox: mpsc::Receiver<RandomnessMessage>,
    // Weak sender to our own mailbox, used to enqueue `MaybeIgnoreByzantinePeer` follow-ups.
    mailbox_sender: mpsc::WeakSender<RandomnessMessage>,
    network: anemo::Network,
    // Updated on every epoch change with a copy of `allowed_peers_set`.
    allowed_peers: AllowedPeersUpdatable,
    // Peer IDs of the current epoch's authorities (rebuilt in `update_epoch`).
    allowed_peers_set: HashSet<PeerId>,
    metrics: Metrics,
    // Receives (epoch, round, BCS-serialized full signature) for each completed signature.
    randomness_tx: mpsc::Sender<(EpochId, RandomnessRound, Vec<u8>)>,

    // Current epoch, set by `update_epoch`.
    epoch: EpochId,
    authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
    // Share IDs held by each peer, derived from the DKG output in `update_epoch`. `None` is
    // treated as "not ready to validate partial sigs yet".
    peer_share_ids: Option<HashMap<PeerId, Vec<ShareIndex>>>,
    // Compared (together with a candidate peer's share count) against the maximum number of
    // shares that may be ignored in `maybe_ignore_byzantine_peer`.
    blocked_share_id_count: usize,
    // DKG output for the current epoch; `None` until `update_epoch` has run.
    dkg_output: Option<dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>>,
    // Threshold passed to `ThresholdBls12381MinSig::aggregate`.
    aggregation_threshold: u16,
    // Highest round requested via `send_partial_signatures`, per epoch.
    highest_requested_round: BTreeMap<EpochId, RandomnessRound>,
    // Per-round task handle plus a cell that is set once the round's full signature is known.
    // Tasks are aborted on epoch change and when their round completes.
    send_tasks: BTreeMap<
        RandomnessRound,
        (
            tokio::task::JoinHandle<()>,
            Arc<OnceCell<RandomnessSignature>>,
        ),
    >,
    // When each round was requested; used to observe the round generation latency metric.
    round_request_time: BTreeMap<(EpochId, RandomnessRound), time::Instant>,
    // Raw partial sigs that could not be validated yet (future epoch or no DKG output);
    // replayed through `receive_partial_signatures` on the next epoch change.
    future_epoch_partial_sigs: BTreeMap<(EpochId, RandomnessRound, PeerId), Vec<Vec<u8>>>,
    // Deserialized partial sigs for the current epoch, keyed by round and sending peer.
    received_partial_sigs: BTreeMap<(RandomnessRound, PeerId), Vec<RandomnessPartialSignature>>,
    // Full signatures obtained for rounds of the current epoch.
    completed_sigs: BTreeMap<RandomnessRound, RandomnessSignature>,
    // Highest round reported complete, per epoch.
    highest_completed_round: BTreeMap<EpochId, RandomnessRound>,
}
241
242impl RandomnessEventLoop {
243    pub async fn start(mut self) {
244        info!("Randomness network event loop started");
245
246        loop {
247            tokio::select! {
248                maybe_message = self.mailbox.recv() => {
249                    // Once all handles to our mailbox have been dropped this
250                    // will yield `None` and we can terminate the event loop.
251                    if let Some(message) = maybe_message {
252                        self.handle_message(message);
253                    } else {
254                        break;
255                    }
256                },
257            }
258        }
259
260        info!("Randomness network event loop ended");
261    }
262
263    fn handle_message(&mut self, message: RandomnessMessage) {
264        match message {
265            RandomnessMessage::UpdateEpoch(
266                epoch,
267                authority_info,
268                dkg_output,
269                aggregation_threshold,
270                recovered_highest_completed_round,
271            ) => {
272                if let Err(e) = self.update_epoch(
273                    epoch,
274                    authority_info,
275                    dkg_output,
276                    aggregation_threshold,
277                    recovered_highest_completed_round,
278                ) {
279                    error!("BUG: failed to update epoch in RandomnessEventLoop: {e:?}");
280                }
281            }
282            RandomnessMessage::SendPartialSignatures(epoch, round) => {
283                self.send_partial_signatures(epoch, round)
284            }
285            RandomnessMessage::CompleteRound(epoch, round) => self.complete_round(epoch, round),
286            RandomnessMessage::ReceiveSignatures(peer_id, epoch, round, partial_sigs, sig) => {
287                if let Some(sig) = sig {
288                    self.receive_full_signature(peer_id, epoch, round, sig)
289                } else {
290                    self.receive_partial_signatures(peer_id, epoch, round, partial_sigs)
291                }
292            }
293            RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id) => {
294                self.maybe_ignore_byzantine_peer(epoch, peer_id)
295            }
296            RandomnessMessage::AdminGetPartialSignatures(round, tx) => {
297                self.admin_get_partial_signatures(round, tx)
298            }
299            RandomnessMessage::AdminInjectPartialSignatures(
300                authority_name,
301                round,
302                sigs,
303                result_channel,
304            ) => {
305                let _ = result_channel.send(self.admin_inject_partial_signatures(
306                    authority_name,
307                    round,
308                    sigs,
309                ));
310            }
311            RandomnessMessage::AdminInjectFullSignature(round, sig, result_channel) => {
312                let _ = result_channel.send(self.admin_inject_full_signature(round, sig));
313            }
314        }
315    }
316
317    #[instrument(level = "debug", skip_all, fields(?new_epoch))]
318    fn update_epoch(
319        &mut self,
320        new_epoch: EpochId,
321        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
322        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
323        aggregation_threshold: u16,
324        recovered_highest_completed_round: Option<RandomnessRound>,
325    ) -> Result<()> {
326        assert!(self.dkg_output.is_none() || new_epoch > self.epoch);
327
328        debug!("updating randomness network loop to new epoch");
329
330        self.peer_share_ids = Some(authority_info.iter().try_fold(
331            HashMap::new(),
332            |mut acc, (_name, (peer_id, party_id))| -> Result<_> {
333                let ids = dkg_output
334                    .nodes
335                    .share_ids_of(*party_id)
336                    .expect("party_id should be valid");
337                acc.insert(*peer_id, ids);
338                Ok(acc)
339            },
340        )?);
341        self.allowed_peers_set = authority_info
342            .values()
343            .map(|(peer_id, _)| *peer_id)
344            .collect();
345        self.allowed_peers
346            .update(Arc::new(self.allowed_peers_set.clone()));
347        self.epoch = new_epoch;
348        self.authority_info = Arc::new(authority_info);
349        self.dkg_output = Some(dkg_output);
350        self.aggregation_threshold = aggregation_threshold;
351        if let Some(round) = recovered_highest_completed_round {
352            self.highest_completed_round
353                .entry(new_epoch)
354                .and_modify(|r| *r = std::cmp::max(*r, round))
355                .or_insert(round);
356        }
357        for (_, (task, _)) in std::mem::take(&mut self.send_tasks) {
358            task.abort();
359        }
360        self.metrics.set_epoch(new_epoch);
361
362        // Throw away info from old epochs.
363        self.highest_requested_round = self.highest_requested_round.split_off(&new_epoch);
364        self.round_request_time = self
365            .round_request_time
366            .split_off(&(new_epoch, RandomnessRound(0)));
367        self.received_partial_sigs.clear();
368        self.completed_sigs.clear();
369        self.highest_completed_round = self.highest_completed_round.split_off(&new_epoch);
370
371        // Start any pending tasks for the new epoch.
372        self.maybe_start_pending_tasks();
373
374        // Aggregate any sigs received early from the new epoch.
375        // (We can't call `maybe_aggregate_partial_signatures` directly while iterating,
376        // because it takes `&mut self`, so we store in a Vec first.)
377        for ((epoch, round, peer_id), sig_bytes) in
378            std::mem::take(&mut self.future_epoch_partial_sigs)
379        {
380            // We can fully validate these now that we have current epoch DKG output.
381            self.receive_partial_signatures(peer_id, epoch, round, sig_bytes);
382        }
383        let rounds_to_aggregate: Vec<_> =
384            self.received_partial_sigs.keys().map(|(r, _)| *r).collect();
385        for round in rounds_to_aggregate {
386            self.maybe_aggregate_partial_signatures(new_epoch, round);
387        }
388
389        Ok(())
390    }
391
392    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
393    fn send_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
394        if epoch < self.epoch {
395            error!(
396                "BUG: skipping sending partial sigs, we are already up to epoch {}",
397                self.epoch
398            );
399            debug_assert!(
400                false,
401                "skipping sending partial sigs, we are already up to higher epoch"
402            );
403            return;
404        }
405        if epoch == self.epoch
406            && let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
407            && round <= *highest_completed_round
408        {
409            info!("skipping sending partial sigs, we already have completed this round");
410            return;
411        }
412
413        self.highest_requested_round
414            .entry(epoch)
415            .and_modify(|r| *r = std::cmp::max(*r, round))
416            .or_insert(round);
417        self.round_request_time
418            .insert((epoch, round), time::Instant::now());
419        self.maybe_start_pending_tasks();
420    }
421
422    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
423    fn complete_round(&mut self, epoch: EpochId, round: RandomnessRound) {
424        debug!("completing randomness round");
425        let new_highest_round = *self
426            .highest_completed_round
427            .entry(epoch)
428            .and_modify(|r| *r = std::cmp::max(*r, round))
429            .or_insert(round);
430        if round != new_highest_round {
431            // This round completion came out of order, and we're already ahead. Nothing more
432            // to do in that case.
433            return;
434        }
435
436        self.round_request_time = self.round_request_time.split_off(&(epoch, round + 1));
437
438        if epoch == self.epoch {
439            self.remove_partial_sigs_in_range((
440                Bound::Included((RandomnessRound(0), PeerId([0; 32]))),
441                Bound::Excluded((round + 1, PeerId([0; 32]))),
442            ));
443            self.completed_sigs = self.completed_sigs.split_off(&(round + 1));
444            for (_, (task, _)) in self.send_tasks.iter().take_while(|(r, _)| **r <= round) {
445                task.abort();
446            }
447            self.send_tasks = self.send_tasks.split_off(&(round + 1));
448            self.maybe_start_pending_tasks();
449        }
450
451        self.update_rounds_pending_metric();
452    }
453
454    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
455    fn receive_partial_signatures(
456        &mut self,
457        peer_id: PeerId,
458        epoch: EpochId,
459        round: RandomnessRound,
460        sig_bytes: Vec<Vec<u8>>,
461    ) {
462        // Basic validity checks.
463        if epoch < self.epoch {
464            debug!(
465                "skipping received partial sigs, we are already up to epoch {}",
466                self.epoch
467            );
468            return;
469        }
470        if epoch > self.epoch + 1 {
471            debug!(
472                "skipping received partial sigs, we are still on epoch {}",
473                self.epoch
474            );
475            return;
476        }
477        if epoch == self.epoch && self.completed_sigs.contains_key(&round) {
478            debug!("skipping received partial sigs, we already have completed this sig");
479            return;
480        }
481        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
482        if let Some(highest_completed_round) = &highest_completed_round
483            && *highest_completed_round >= round
484        {
485            debug!("skipping received partial sigs, we already have completed this round");
486            return;
487        }
488
489        // If sigs are for a future epoch, we can't fully verify them without DKG output.
490        // Save them for later use.
491        if epoch != self.epoch || self.peer_share_ids.is_none() {
492            if round.0 >= self.config.max_partial_sigs_rounds_ahead() {
493                debug!("skipping received partial sigs for future epoch, round too far ahead",);
494                return;
495            }
496
497            debug!("saving partial sigs from future epoch for later use");
498            self.future_epoch_partial_sigs
499                .insert((epoch, round, peer_id), sig_bytes);
500            return;
501        }
502
503        // Verify shape of sigs matches what we expect for the peer.
504        let peer_share_ids = self.peer_share_ids.as_ref().expect("checked above");
505        let expected_share_ids = if let Some(expected_share_ids) = peer_share_ids.get(&peer_id) {
506            expected_share_ids
507        } else {
508            debug!("received partial sigs from unknown peer");
509            return;
510        };
511        if sig_bytes.len() != expected_share_ids.len() as usize {
512            warn!(
513                "received partial sigs with wrong share ids count: expected {}, got {}",
514                expected_share_ids.len(),
515                sig_bytes.len(),
516            );
517            return;
518        }
519
520        // Accept partial signatures up to `max_partial_sigs_rounds_ahead` past the round of the
521        // last completed signature, or the highest completed round, whichever is greater.
522        let last_completed_signature = self.completed_sigs.last_key_value().map(|(r, _)| *r);
523        let last_completed_round = std::cmp::max(last_completed_signature, highest_completed_round)
524            .unwrap_or(RandomnessRound(0));
525        if round.0
526            >= last_completed_round
527                .0
528                .saturating_add(self.config.max_partial_sigs_rounds_ahead())
529        {
530            debug!(
531                "skipping received partial sigs, most recent round we completed was only {last_completed_round}",
532            );
533            return;
534        }
535
536        // Deserialize the partial sigs.
537        let partial_sigs =
538            match sig_bytes
539                .iter()
540                .try_fold(Vec::new(), |mut acc, bytes| -> Result<_> {
541                    let sig: RandomnessPartialSignature = bcs::from_bytes(bytes)?;
542                    acc.push(sig);
543                    Ok(acc)
544                }) {
545                Ok(partial_sigs) => partial_sigs,
546                Err(e) => {
547                    warn!("failed to deserialize partial sigs: {e:?}");
548                    return;
549                }
550            };
551        // Verify we received the expected share IDs (to protect against a validator that sends
552        // valid signatures of other peers which will be successfully verified below).
553        let received_share_ids = partial_sigs.iter().map(|s| s.index);
554        if received_share_ids
555            .zip_debug_eq(expected_share_ids.iter())
556            .any(|(a, b)| a != *b)
557        {
558            let received_share_ids = partial_sigs.iter().map(|s| s.index).collect::<Vec<_>>();
559            warn!(
560                "received partial sigs with wrong share ids: expected {expected_share_ids:?}, received {received_share_ids:?}"
561            );
562            return;
563        }
564
565        // We passed all the checks, save the partial sigs.
566        debug!("recording received partial signatures");
567        self.received_partial_sigs
568            .insert((round, peer_id), partial_sigs);
569
570        self.maybe_aggregate_partial_signatures(epoch, round);
571    }
572
573    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
574    fn maybe_aggregate_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
575        if let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
576            && round <= *highest_completed_round
577        {
578            info!("skipping aggregation for already-completed round");
579            return;
580        }
581
582        let highest_requested_round = self.highest_requested_round.get(&epoch);
583        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
584            // We have to wait here, because even if we have enough information from other nodes
585            // to complete the signature, local shared object versions are not set until consensus
586            // finishes processing the corresponding commit. This function will be called again
587            // after maybe_start_pending_tasks begins this round locally.
588            debug!(
589                "waiting to aggregate randomness partial signatures until local consensus catches up"
590            );
591            return;
592        }
593
594        if epoch != self.epoch {
595            debug!(
596                "waiting to aggregate randomness partial signatures until DKG completes for epoch"
597            );
598            return;
599        }
600
601        if self.completed_sigs.contains_key(&round) {
602            info!("skipping aggregation for already-completed signature");
603            return;
604        }
605
606        let vss_pk = {
607            let Some(dkg_output) = &self.dkg_output else {
608                debug!("called maybe_aggregate_partial_signatures before DKG completed");
609                return;
610            };
611            &dkg_output.vss_pk
612        };
613
614        let sig_bounds = (
615            Bound::Included((round, PeerId([0; 32]))),
616            Bound::Excluded((round + 1, PeerId([0; 32]))),
617        );
618
619        // If we have enough partial signatures, aggregate them.
620        let sig_range = self
621            .received_partial_sigs
622            .range(sig_bounds)
623            .flat_map(|(_, sigs)| sigs);
624        let mut sig =
625            match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
626                Ok(sig) => sig,
627                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
628                Err(e) => {
629                    error!("error while aggregating randomness partial signatures: {e:?}");
630                    return;
631                }
632            };
633
634        // Try to verify the aggregated signature all at once. (Should work in the happy path.)
635        if ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig).is_err() {
636            // If verifiation fails, some of the inputs must be invalid. We have to go through
637            // one-by-one to find which.
638            // TODO: add test for individual sig verification.
639            self.received_partial_sigs
640                .retain(|&(r, peer_id), partial_sigs| {
641                    if round != r {
642                        return true;
643                    }
644                    if ThresholdBls12381MinSig::partial_verify_batch(
645                        vss_pk,
646                        &round.signature_message(),
647                        partial_sigs.iter(),
648                        &mut rand::thread_rng(),
649                    )
650                    .is_err()
651                    {
652                        warn!(
653                            "received invalid partial signatures from possibly-Byzantine peer {peer_id}"
654                        );
655                        if let Some(sender) = self.mailbox_sender.upgrade() {
656                            sender.try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(
657                                epoch,
658                                peer_id,
659                            ))
660                            .expect("RandomnessEventLoop mailbox should not overflow or be closed");
661                        }
662                        return false;
663                    }
664                    true
665                });
666            let sig_range = self
667                .received_partial_sigs
668                .range(sig_bounds)
669                .flat_map(|(_, sigs)| sigs);
670            sig = match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
671                Ok(sig) => sig,
672                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
673                Err(e) => {
674                    error!("error while aggregating randomness partial signatures: {e:?}");
675                    return;
676                }
677            };
678            if let Err(e) =
679                ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
680            {
681                error!(
682                    "error while verifying randomness partial signatures after removing invalid partials: {e:?}"
683                );
684                debug_assert!(
685                    false,
686                    "error while verifying randomness partial signatures after removing invalid partials"
687                );
688                return;
689            }
690        }
691
692        debug!("successfully generated randomness full signature");
693        self.process_valid_full_signature(epoch, round, sig);
694    }
695
696    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
697    fn receive_full_signature(
698        &mut self,
699        peer_id: PeerId,
700        epoch: EpochId,
701        round: RandomnessRound,
702        sig: RandomnessSignature,
703    ) {
704        let vss_pk = {
705            let Some(dkg_output) = &self.dkg_output else {
706                debug!("called receive_full_signature before DKG completed");
707                return;
708            };
709            &dkg_output.vss_pk
710        };
711
712        // Basic validity checks.
713        if epoch != self.epoch {
714            debug!("skipping received full sig, we are on epoch {}", self.epoch);
715            return;
716        }
717        if self.completed_sigs.contains_key(&round) {
718            debug!("skipping received full sigs, we already have completed this sig");
719            return;
720        }
721        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
722        if let Some(highest_completed_round) = &highest_completed_round
723            && *highest_completed_round >= round
724        {
725            debug!("skipping received full sig, we already have completed this round");
726            return;
727        }
728
729        let highest_requested_round = self.highest_requested_round.get(&epoch);
730        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
731            // Wait for local consensus to catch up if necessary.
732            debug!(
733                "skipping received full signature, local consensus is not caught up to its round"
734            );
735            return;
736        }
737
738        if let Err(e) =
739            ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
740        {
741            info!("received invalid full signature from peer {peer_id}: {e:?}");
742            if let Some(sender) = self.mailbox_sender.upgrade() {
743                sender
744                    .try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id))
745                    .expect("RandomnessEventLoop mailbox should not overflow or be closed");
746            }
747            return;
748        }
749
750        debug!("received valid randomness full signature");
751        self.process_valid_full_signature(epoch, round, sig);
752    }
753
754    fn process_valid_full_signature(
755        &mut self,
756        epoch: EpochId,
757        round: RandomnessRound,
758        sig: RandomnessSignature,
759    ) {
760        assert_eq!(epoch, self.epoch);
761
762        if let Some((_, full_sig_cell)) = self.send_tasks.get(&round) {
763            full_sig_cell
764                .set(sig)
765                .expect("full signature should never be processed twice");
766        }
767        self.completed_sigs.insert(round, sig);
768        self.remove_partial_sigs_in_range((
769            Bound::Included((round, PeerId([0; 32]))),
770            Bound::Excluded((round + 1, PeerId([0; 32]))),
771        ));
772        self.metrics.record_completed_round(round);
773        if let Some(start_time) = self.round_request_time.get(&(epoch, round))
774            && let Some(metric) = self.metrics.round_generation_latency_metric()
775        {
776            metric.observe(start_time.elapsed().as_secs_f64());
777        }
778
779        let sig_bytes = bcs::to_bytes(&sig).expect("signature serialization should not fail");
780        if let Err(e) = self.randomness_tx.try_send((epoch, round, sig_bytes)) {
781            match e {
782                // Receiver is torn down during node shutdown; dropping the round is harmless.
783                mpsc::error::TrySendError::Closed(_) => {
784                    info!("dropping completed randomness round {round}: receiver channel closed");
785                }
786                // Mailbox capacity is huge (default 1M); a full mailbox means a real bug.
787                mpsc::error::TrySendError::Full(_) => {
788                    panic!("RandomnessRoundReceiver mailbox should not overflow");
789                }
790            }
791        }
792    }
793
794    fn maybe_ignore_byzantine_peer(&mut self, epoch: EpochId, peer_id: PeerId) {
795        if epoch != self.epoch {
796            return; // make sure we're still on the same epoch
797        }
798        let Some(dkg_output) = &self.dkg_output else {
799            return; // can't ignore a peer if we haven't finished DKG
800        };
801        if !self.allowed_peers_set.contains(&peer_id) {
802            return; // peer is already disallowed
803        }
804        let Some(peer_share_ids) = &self.peer_share_ids else {
805            return; // can't ignore a peer if we haven't finished DKG
806        };
807        let Some(peer_shares) = peer_share_ids.get(&peer_id) else {
808            warn!("can't ignore unknown byzantine peer {peer_id:?}");
809            return;
810        };
811        let max_ignored_shares = (self.config.max_ignored_peer_weight_factor()
812            * (dkg_output.nodes.total_weight() as f64)) as usize;
813        if self.blocked_share_id_count + peer_shares.len() > max_ignored_shares {
814            warn!(
815                "ignoring byzantine peer {peer_id:?} with {} shares would exceed max ignored peer weight {max_ignored_shares}",
816                peer_shares.len()
817            );
818            return;
819        }
820
821        warn!(
822            "ignoring byzantine peer {peer_id:?} with {} shares",
823            peer_shares.len()
824        );
825        self.blocked_share_id_count += peer_shares.len();
826        self.allowed_peers_set.remove(&peer_id);
827        self.allowed_peers
828            .update(Arc::new(self.allowed_peers_set.clone()));
829        self.metrics.inc_num_ignored_byzantine_peers();
830    }
831
832    fn maybe_start_pending_tasks(&mut self) {
833        let dkg_output = if let Some(dkg_output) = &self.dkg_output {
834            dkg_output
835        } else {
836            return; // wait for DKG
837        };
838        let shares = if let Some(shares) = &dkg_output.shares {
839            shares
840        } else {
841            return; // can't participate in randomness generation without shares
842        };
843        let highest_requested_round =
844            if let Some(highest_requested_round) = self.highest_requested_round.get(&self.epoch) {
845                highest_requested_round
846            } else {
847                return; // no rounds to start
848            };
849        // Begin from the next round after the most recent one we've started (or, if none are running,
850        // after the highest completed round in the epoch).
851        let start_round = std::cmp::max(
852            if let Some(highest_completed_round) = self.highest_completed_round.get(&self.epoch) {
853                highest_completed_round.checked_add(1).unwrap()
854            } else {
855                RandomnessRound(0)
856            },
857            self.send_tasks
858                .last_key_value()
859                .map(|(r, _)| r.checked_add(1).unwrap())
860                .unwrap_or(RandomnessRound(0)),
861        );
862
863        let mut rounds_to_aggregate = Vec::new();
864        for round in start_round.0..=highest_requested_round.0 {
865            let round = RandomnessRound(round);
866
867            if self.send_tasks.len() >= self.config.max_partial_sigs_concurrent_sends() {
868                break; // limit concurrent tasks
869            }
870
871            let full_sig_cell = Arc::new(OnceCell::new());
872            self.send_tasks.entry(round).or_insert_with(|| {
873                let name = self.name;
874                let network = self.network.clone();
875                let retry_interval = self.config.partial_signature_retry_interval();
876                let metrics = self.metrics.clone();
877                let authority_info = self.authority_info.clone();
878                let epoch = self.epoch;
879                let partial_sigs = ThresholdBls12381MinSig::partial_sign_batch(
880                    shares.iter(),
881                    &round.signature_message(),
882                );
883                let full_sig_cell_clone = full_sig_cell.clone();
884
885                // Record own partial sigs.
886                if !self.completed_sigs.contains_key(&round) {
887                    self.received_partial_sigs
888                        .insert((round, self.network.peer_id()), partial_sigs.clone());
889                    rounds_to_aggregate.push((epoch, round));
890                }
891
892                debug!("sending partial sigs for epoch {epoch}, round {round}");
893                (
894                    spawn_monitored_task!(RandomnessEventLoop::send_signatures_task(
895                        name,
896                        network,
897                        retry_interval,
898                        metrics,
899                        authority_info,
900                        epoch,
901                        round,
902                        partial_sigs,
903                        full_sig_cell_clone,
904                    )),
905                    full_sig_cell,
906                )
907            });
908        }
909
910        self.update_rounds_pending_metric();
911
912        // After starting a round, we have generated our own partial sigs. Check if that's
913        // enough for us to aggregate already.
914        for (epoch, round) in rounds_to_aggregate {
915            self.maybe_aggregate_partial_signatures(epoch, round);
916        }
917    }
918
919    #[allow(clippy::type_complexity)]
920    fn remove_partial_sigs_in_range(
921        &mut self,
922        range: (
923            Bound<(RandomnessRound, PeerId)>,
924            Bound<(RandomnessRound, PeerId)>,
925        ),
926    ) {
927        let keys_to_remove: Vec<_> = self
928            .received_partial_sigs
929            .range(range)
930            .map(|(key, _)| *key)
931            .collect();
932        for key in keys_to_remove {
933            // Have to remove keys one-by-one because BTreeMap does not support range-removal.
934            self.received_partial_sigs.remove(&key);
935        }
936    }
937
938    async fn send_signatures_task(
939        name: AuthorityName,
940        network: anemo::Network,
941        retry_interval: Duration,
942        metrics: Metrics,
943        authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
944        epoch: EpochId,
945        round: RandomnessRound,
946        partial_sigs: Vec<RandomnessPartialSignature>,
947        full_sig: Arc<OnceCell<RandomnessSignature>>,
948    ) {
949        // For simtests, we may test not sending partial signatures.
950        #[allow(unused_mut)]
951        let mut fail_point_skip_sending = false;
952        fail_point_if!("rb-send-partial-signatures", || {
953            fail_point_skip_sending = true;
954        });
955        if fail_point_skip_sending {
956            warn!("skipping sending partial sigs due to simtest fail point");
957            return;
958        }
959
960        let _metrics_guard = metrics
961            .round_observation_latency_metric()
962            .map(|metric| metric.start_timer());
963
964        let peers: HashMap<_, _> = authority_info
965            .iter()
966            .map(|(name, (peer_id, _party_id))| (name, network.waiting_peer(*peer_id)))
967            .collect();
968        let partial_sigs: Vec<_> = partial_sigs
969            .iter()
970            .map(|sig| bcs::to_bytes(sig).expect("message serialization should not fail"))
971            .collect();
972
973        loop {
974            let mut requests = Vec::new();
975            for (peer_name, peer) in &peers {
976                if name == **peer_name {
977                    continue; // don't send partial sigs to self
978                }
979                let mut client = RandomnessClient::new(peer.clone());
980                const SEND_PARTIAL_SIGNATURES_TIMEOUT: Duration = Duration::from_secs(10);
981                let full_sig = full_sig.get().cloned();
982                let request = anemo::Request::new(SendSignaturesRequest {
983                    epoch,
984                    round,
985                    partial_sigs: if full_sig.is_none() {
986                        partial_sigs.clone()
987                    } else {
988                        Vec::new()
989                    },
990                    sig: full_sig,
991                })
992                .with_timeout(SEND_PARTIAL_SIGNATURES_TIMEOUT);
993                requests.push(async move {
994                    let result = client.send_signatures(request).await;
995                    if let Err(_error) = result {
996                        // TODO: add Display impl to anemo::rpc::Status, log it here
997                        debug!("failed to send partial signatures to {peer_name}");
998                    }
999                });
1000            }
1001
1002            // Process all requests.
1003            futures::future::join_all(requests).await;
1004
1005            // Keep retrying send to all peers until task is aborted via external message.
1006            tokio::time::sleep(retry_interval).await;
1007        }
1008    }
1009
1010    fn update_rounds_pending_metric(&self) {
1011        let highest_requested_round = self
1012            .highest_requested_round
1013            .get(&self.epoch)
1014            .map(|r| r.0)
1015            .unwrap_or(0);
1016        let highest_completed_round = self
1017            .highest_completed_round
1018            .get(&self.epoch)
1019            .map(|r| r.0)
1020            .unwrap_or(0);
1021        let num_rounds_pending =
1022            highest_requested_round.saturating_sub(highest_completed_round) as i64;
1023        let prev_value = self.metrics.num_rounds_pending().unwrap_or_default();
1024        if num_rounds_pending / 100 > prev_value / 100 {
1025            warn!(
1026                // Recording multiples of 100 so tests can match on the log message.
1027                "RandomnessEventLoop randomness generation backlog: over {} rounds are pending (oldest is {:?})",
1028                (num_rounds_pending / 100) * 100,
1029                highest_completed_round + 1,
1030            );
1031        }
1032        self.metrics.set_num_rounds_pending(num_rounds_pending);
1033    }
1034
1035    fn admin_get_partial_signatures(&self, round: RandomnessRound, tx: oneshot::Sender<Vec<u8>>) {
1036        let shares = if let Some(shares) = self.dkg_output.as_ref().and_then(|d| d.shares.as_ref())
1037        {
1038            shares
1039        } else {
1040            let _ = tx.send(Vec::new()); // no error handling needed if receiver is already dropped
1041            return;
1042        };
1043
1044        let partial_sigs =
1045            ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), &round.signature_message());
1046        // no error handling needed if receiver is already dropped
1047        let _ = tx.send(bcs::to_bytes(&partial_sigs).expect("serialization should not fail"));
1048    }
1049
1050    fn admin_inject_partial_signatures(
1051        &mut self,
1052        authority_name: AuthorityName,
1053        round: RandomnessRound,
1054        sigs: Vec<RandomnessPartialSignature>,
1055    ) -> Result<()> {
1056        let peer_id = self
1057            .authority_info
1058            .get(&authority_name)
1059            .map(|(peer_id, _)| *peer_id)
1060            .ok_or(anyhow::anyhow!("unknown AuthorityName {authority_name:?}"))?;
1061        self.received_partial_sigs.insert((round, peer_id), sigs);
1062        self.maybe_aggregate_partial_signatures(self.epoch, round);
1063        Ok(())
1064    }
1065
1066    fn admin_inject_full_signature(
1067        &mut self,
1068        round: RandomnessRound,
1069        sig: RandomnessSignature,
1070    ) -> Result<()> {
1071        let vss_pk = {
1072            let Some(dkg_output) = &self.dkg_output else {
1073                return Err(anyhow::anyhow!(
1074                    "called admin_inject_full_signature before DKG completed"
1075                ));
1076            };
1077            &dkg_output.vss_pk
1078        };
1079
1080        ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
1081            .map_err(|e| anyhow::anyhow!("invalid full signature: {e:?}"))?;
1082
1083        self.process_valid_full_signature(self.epoch, round, sig);
1084        Ok(())
1085    }
1086}