sui_network/randomness/mod.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use self::{auth::AllowedPeersUpdatable, metrics::Metrics};
use anemo::PeerId;
use anyhow::Result;
use fastcrypto::groups::bls12381;
use fastcrypto_tbls::{
    dkg_v1,
    nodes::PartyId,
    tbls::ThresholdBls,
    types::{ShareIndex, ThresholdBls12381MinSig},
};
use mysten_metrics::spawn_monitored_task;
use mysten_network::anemo_ext::NetworkExt;
use serde::{Deserialize, Serialize};
use std::{
    collections::{BTreeMap, HashMap, HashSet},
    ops::Bound,
    sync::Arc,
    time::{self, Duration},
};
use sui_config::p2p::RandomnessConfig;
use sui_macros::fail_point_if;
use sui_types::{
    base_types::AuthorityName,
    committee::EpochId,
    crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
};
use tokio::sync::{OnceCell, mpsc, oneshot};
use tracing::{debug, error, info, instrument, warn};

mod auth;
mod builder;
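// Stubs for the `sui.Randomness` RPC service, generated at build time (see this crate's
// build script).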
mod generated {
    include!(concat!(env!("OUT_DIR"), "/sui.Randomness.rs"));
}
mod metrics;
mod server;
#[cfg(test)]
mod tests;

pub use builder::{Builder, UnstartedRandomness};
pub use generated::{
    randomness_client::RandomnessClient,
    randomness_server::{Randomness, RandomnessServer},
};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SendSignaturesRequest {
    epoch: EpochId,
    round: RandomnessRound,
    // BCS-serialized `RandomnessPartialSignature` values. We store raw bytes here to enable
    // defenses against too-large messages.
    // The protocol requires the signatures to be ordered by share index (as provided by
    // fastcrypto).
    partial_sigs: Vec<Vec<u8>>,
    // If the peer already has a full signature available for the round, it is provided here
    // in lieu of partial sigs.
    sig: Option<RandomnessSignature>,
}
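
// As a sketch, a sender fills `partial_sigs` by BCS-encoding each partial signature
// individually (this mirrors `send_signatures_task` below):
//
//     let partial_sigs: Vec<Vec<u8>> = sigs
//         .iter()
//         .map(|sig| bcs::to_bytes(sig).expect("message serialization should not fail"))
//         .collect();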

/// A handle to the Randomness network subsystem.
///
/// This handle can be cloned and shared. Once all copies of a Randomness system's Handle have
/// been dropped, the Randomness system will be gracefully shut down.
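///
/// # Example
///
/// A minimal usage sketch; the handle is assumed to come from this module's `Builder`
/// (construction not shown):
///
/// ```ignore
/// let cloned = handle.clone(); // clones share the same underlying mailbox
/// cloned.send_partial_signatures(epoch, RandomnessRound(0));
/// drop(handle);
/// drop(cloned); // dropping the last copy shuts the system down gracefully
/// ```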
#[derive(Clone, Debug)]
pub struct Handle {
    sender: mpsc::Sender<RandomnessMessage>,
}

impl Handle {
    /// Transitions the Randomness system to a new epoch. Cancels all partial signature sends for
    /// prior epochs.
    pub fn update_epoch(
        &self,
        new_epoch: EpochId,
        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        aggregation_threshold: u16,
        recovered_highest_completed_round: Option<RandomnessRound>, // set to None if not starting up mid-epoch
    ) {
        self.sender
            .try_send(RandomnessMessage::UpdateEpoch(
                new_epoch,
                authority_info,
                dkg_output,
                aggregation_threshold,
                recovered_highest_completed_round,
            ))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    /// Begins transmitting partial signatures for the given epoch and round until completed.
    pub fn send_partial_signatures(&self, epoch: EpochId, round: RandomnessRound) {
        self.sender
            .try_send(RandomnessMessage::SendPartialSignatures(epoch, round))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    /// Records the given round as complete, stopping any partial signature sends.
    pub fn complete_round(&self, epoch: EpochId, round: RandomnessRound) {
        self.sender
            .try_send(RandomnessMessage::CompleteRound(epoch, round))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    /// Admin interface handler: generates partial signatures for the given round at the
    /// current epoch.
    pub fn admin_get_partial_signatures(
        &self,
        round: RandomnessRound,
        tx: oneshot::Sender<Vec<u8>>,
    ) {
        self.sender
            .try_send(RandomnessMessage::AdminGetPartialSignatures(round, tx))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    /// Admin interface handler: injects partial signatures for the given round at the
    /// current epoch, skipping validity checks.
    pub fn admin_inject_partial_signatures(
        &self,
        authority_name: AuthorityName,
        round: RandomnessRound,
        sigs: Vec<RandomnessPartialSignature>,
        result_channel: oneshot::Sender<Result<()>>,
    ) {
        self.sender
            .try_send(RandomnessMessage::AdminInjectPartialSignatures(
                authority_name,
                round,
                sigs,
                result_channel,
            ))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    /// Admin interface handler: injects a full signature for the given round at the
    /// current epoch, skipping validity checks.
    pub fn admin_inject_full_signature(
        &self,
        round: RandomnessRound,
        sig: RandomnessSignature,
        result_channel: oneshot::Sender<Result<()>>,
    ) {
        self.sender
            .try_send(RandomnessMessage::AdminInjectFullSignature(
                round,
                sig,
                result_channel,
            ))
            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
    }

    // For testing.
    pub fn new_stub() -> Self {
        let (sender, mut receiver) = mpsc::channel(1);
        // Keep receiver open until all senders are closed.
        tokio::spawn(async move {
            while receiver.recv().await.is_some() {}
        });
        Self { sender }
    }
}

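// Internal messages processed by `RandomnessEventLoop`. `ReceiveSignatures` is forwarded
// by the network server (see `server.rs`); `MaybeIgnoreByzantinePeer` is self-sent by the
// event loop when signature verification implicates a peer.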
#[derive(Debug)]
enum RandomnessMessage {
    UpdateEpoch(
        EpochId,
        HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        u16,                     // aggregation_threshold
        Option<RandomnessRound>, // recovered_highest_completed_round
    ),
    SendPartialSignatures(EpochId, RandomnessRound),
    CompleteRound(EpochId, RandomnessRound),
    ReceiveSignatures(
        PeerId,
        EpochId,
        RandomnessRound,
        Vec<Vec<u8>>,
        Option<RandomnessSignature>,
    ),
    MaybeIgnoreByzantinePeer(EpochId, PeerId),
    AdminGetPartialSignatures(RandomnessRound, oneshot::Sender<Vec<u8>>),
    AdminInjectPartialSignatures(
        AuthorityName,
        RandomnessRound,
        Vec<RandomnessPartialSignature>,
        oneshot::Sender<Result<()>>,
    ),
    AdminInjectFullSignature(
        RandomnessRound,
        RandomnessSignature,
        oneshot::Sender<Result<()>>,
    ),
}

struct RandomnessEventLoop {
    name: AuthorityName,
    config: RandomnessConfig,
    mailbox: mpsc::Receiver<RandomnessMessage>,
    mailbox_sender: mpsc::WeakSender<RandomnessMessage>,
    network: anemo::Network,
    allowed_peers: AllowedPeersUpdatable,
    allowed_peers_set: HashSet<PeerId>,
    metrics: Metrics,
    randomness_tx: mpsc::Sender<(EpochId, RandomnessRound, Vec<u8>)>,

    epoch: EpochId,
    authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
    peer_share_ids: Option<HashMap<PeerId, Vec<ShareIndex>>>,
    blocked_share_id_count: usize,
    dkg_output: Option<dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>>,
    aggregation_threshold: u16,
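    // Highest round for which randomness generation has been requested locally, per epoch.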
    highest_requested_round: BTreeMap<EpochId, RandomnessRound>,
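    // One send task per round currently being broadcast. The `OnceCell` is filled once a
    // full signature for the round is known, letting retries send it in place of partial
    // sigs.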
    send_tasks: BTreeMap<
        RandomnessRound,
        (
            tokio::task::JoinHandle<()>,
            Arc<OnceCell<RandomnessSignature>>,
        ),
    >,
    round_request_time: BTreeMap<(EpochId, RandomnessRound), time::Instant>,
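    // Partial sigs that arrived for a future epoch (or before our DKG finished); they are
    // re-validated once `update_epoch` runs for that epoch.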
    future_epoch_partial_sigs: BTreeMap<(EpochId, RandomnessRound, PeerId), Vec<Vec<u8>>>,
    received_partial_sigs: BTreeMap<(RandomnessRound, PeerId), Vec<RandomnessPartialSignature>>,
    completed_sigs: BTreeMap<RandomnessRound, RandomnessSignature>,
    highest_completed_round: BTreeMap<EpochId, RandomnessRound>,
}

impl RandomnessEventLoop {
    pub async fn start(mut self) {
        info!("Randomness network event loop started");

        // Once all handles to our mailbox have been dropped this will yield `None`,
        // and we can terminate the event loop.
        while let Some(message) = self.mailbox.recv().await {
            self.handle_message(message);
        }

        info!("Randomness network event loop ended");
    }

    fn handle_message(&mut self, message: RandomnessMessage) {
        match message {
            RandomnessMessage::UpdateEpoch(
                epoch,
                authority_info,
                dkg_output,
                aggregation_threshold,
                recovered_highest_completed_round,
            ) => {
                if let Err(e) = self.update_epoch(
                    epoch,
                    authority_info,
                    dkg_output,
                    aggregation_threshold,
                    recovered_highest_completed_round,
                ) {
                    error!("BUG: failed to update epoch in RandomnessEventLoop: {e:?}");
                }
            }
            RandomnessMessage::SendPartialSignatures(epoch, round) => {
                self.send_partial_signatures(epoch, round)
            }
            RandomnessMessage::CompleteRound(epoch, round) => self.complete_round(epoch, round),
            RandomnessMessage::ReceiveSignatures(peer_id, epoch, round, partial_sigs, sig) => {
                if let Some(sig) = sig {
                    self.receive_full_signature(peer_id, epoch, round, sig)
                } else {
                    self.receive_partial_signatures(peer_id, epoch, round, partial_sigs)
                }
            }
            RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id) => {
                self.maybe_ignore_byzantine_peer(epoch, peer_id)
            }
            RandomnessMessage::AdminGetPartialSignatures(round, tx) => {
                self.admin_get_partial_signatures(round, tx)
            }
            RandomnessMessage::AdminInjectPartialSignatures(
                authority_name,
                round,
                sigs,
                result_channel,
            ) => {
                let _ = result_channel.send(self.admin_inject_partial_signatures(
                    authority_name,
                    round,
                    sigs,
                ));
            }
            RandomnessMessage::AdminInjectFullSignature(round, sig, result_channel) => {
                let _ = result_channel.send(self.admin_inject_full_signature(round, sig));
            }
        }
    }

    #[instrument(level = "debug", skip_all, fields(?new_epoch))]
    fn update_epoch(
        &mut self,
        new_epoch: EpochId,
        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        aggregation_threshold: u16,
        recovered_highest_completed_round: Option<RandomnessRound>,
    ) -> Result<()> {
        assert!(self.dkg_output.is_none() || new_epoch > self.epoch);

        debug!("updating randomness network loop to new epoch");

        self.peer_share_ids = Some(authority_info.iter().try_fold(
            HashMap::new(),
            |mut acc, (_name, (peer_id, party_id))| -> Result<_> {
                let ids = dkg_output
                    .nodes
                    .share_ids_of(*party_id)
                    .expect("party_id should be valid");
                acc.insert(*peer_id, ids);
                Ok(acc)
            },
        )?);
        self.allowed_peers_set = authority_info
            .values()
            .map(|(peer_id, _)| *peer_id)
            .collect();
        self.allowed_peers
            .update(Arc::new(self.allowed_peers_set.clone()));
        self.epoch = new_epoch;
        self.authority_info = Arc::new(authority_info);
        self.dkg_output = Some(dkg_output);
        self.aggregation_threshold = aggregation_threshold;
        if let Some(round) = recovered_highest_completed_round {
            self.highest_completed_round
                .entry(new_epoch)
                .and_modify(|r| *r = std::cmp::max(*r, round))
                .or_insert(round);
        }
        for (_, (task, _)) in std::mem::take(&mut self.send_tasks) {
            task.abort();
        }
        self.metrics.set_epoch(new_epoch);

        // Throw away info from old epochs.
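        // (`BTreeMap::split_off` returns the entries at or after the given key, so
        // reassigning its result keeps only data for `new_epoch` and later.)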
        self.highest_requested_round = self.highest_requested_round.split_off(&new_epoch);
        self.round_request_time = self
            .round_request_time
            .split_off(&(new_epoch, RandomnessRound(0)));
        self.received_partial_sigs.clear();
        self.completed_sigs.clear();
        self.highest_completed_round = self.highest_completed_round.split_off(&new_epoch);

        // Start any pending tasks for the new epoch.
        self.maybe_start_pending_tasks();

        // Aggregate any sigs received early from the new epoch.
        // (We can't call `maybe_aggregate_partial_signatures` directly while iterating,
        // because it takes `&mut self`, so we store in a Vec first.)
        for ((epoch, round, peer_id), sig_bytes) in
            std::mem::take(&mut self.future_epoch_partial_sigs)
        {
            // We can fully validate these now that we have current epoch DKG output.
            self.receive_partial_signatures(peer_id, epoch, round, sig_bytes);
        }
        let rounds_to_aggregate: Vec<_> =
            self.received_partial_sigs.keys().map(|(r, _)| *r).collect();
        for round in rounds_to_aggregate {
            self.maybe_aggregate_partial_signatures(new_epoch, round);
        }

        Ok(())
    }

    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
    fn send_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
        if epoch < self.epoch {
            error!(
                "BUG: skipping sending partial sigs, we are already up to epoch {}",
                self.epoch
            );
            debug_assert!(
                false,
                "skipping sending partial sigs, we are already up to higher epoch"
            );
            return;
        }
        if epoch == self.epoch
            && let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
            && round <= *highest_completed_round
        {
            info!("skipping sending partial sigs, we have already completed this round");
            return;
        }

        self.highest_requested_round
            .entry(epoch)
            .and_modify(|r| *r = std::cmp::max(*r, round))
            .or_insert(round);
        self.round_request_time
            .insert((epoch, round), time::Instant::now());
        self.maybe_start_pending_tasks();
    }

    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
    fn complete_round(&mut self, epoch: EpochId, round: RandomnessRound) {
        debug!("completing randomness round");
        let new_highest_round = *self
            .highest_completed_round
            .entry(epoch)
            .and_modify(|r| *r = std::cmp::max(*r, round))
            .or_insert(round);
        if round != new_highest_round {
            // This round completion came out of order, and we're already ahead. Nothing more
            // to do in that case.
            return;
        }

        self.round_request_time = self.round_request_time.split_off(&(epoch, round + 1));

        if epoch == self.epoch {
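            // `PeerId([0; 32])` is the minimum peer id, so this range covers all
            // (round, peer) keys for rounds `0..=round`.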
            self.remove_partial_sigs_in_range((
                Bound::Included((RandomnessRound(0), PeerId([0; 32]))),
                Bound::Excluded((round + 1, PeerId([0; 32]))),
            ));
            self.completed_sigs = self.completed_sigs.split_off(&(round + 1));
            for (_, (task, _)) in self.send_tasks.iter().take_while(|(r, _)| **r <= round) {
                task.abort();
            }
            self.send_tasks = self.send_tasks.split_off(&(round + 1));
            self.maybe_start_pending_tasks();
        }

        self.update_rounds_pending_metric();
    }

    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
    fn receive_partial_signatures(
        &mut self,
        peer_id: PeerId,
        epoch: EpochId,
        round: RandomnessRound,
        sig_bytes: Vec<Vec<u8>>,
    ) {
        // Basic validity checks.
        if epoch < self.epoch {
            debug!(
                "skipping received partial sigs, we are already up to epoch {}",
                self.epoch
            );
            return;
        }
        if epoch > self.epoch + 1 {
            debug!(
                "skipping received partial sigs, we are still on epoch {}",
                self.epoch
            );
            return;
        }
        if epoch == self.epoch && self.completed_sigs.contains_key(&round) {
            debug!("skipping received partial sigs, we have already completed this sig");
            return;
        }
        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
        if let Some(highest_completed_round) = &highest_completed_round
            && *highest_completed_round >= round
        {
            debug!("skipping received partial sigs, we have already completed this round");
            return;
        }

        // If sigs are for a future epoch, we can't fully verify them without DKG output.
        // Save them for later use.
        if epoch != self.epoch || self.peer_share_ids.is_none() {
            if round.0 >= self.config.max_partial_sigs_rounds_ahead() {
                debug!("skipping received partial sigs for future epoch, round too far ahead");
                return;
            }

            debug!("saving partial sigs from future epoch for later use");
            self.future_epoch_partial_sigs
                .insert((epoch, round, peer_id), sig_bytes);
            return;
        }

        // Verify the shape of the sigs matches what we expect for the peer.
        let peer_share_ids = self.peer_share_ids.as_ref().expect("checked above");
        let Some(expected_share_ids) = peer_share_ids.get(&peer_id) else {
            debug!("received partial sigs from unknown peer");
            return;
        };
        if sig_bytes.len() != expected_share_ids.len() {
            warn!(
                "received partial sigs with wrong share ids count: expected {}, got {}",
                expected_share_ids.len(),
                sig_bytes.len(),
            );
            return;
        }

        // Accept partial signatures up to `max_partial_sigs_rounds_ahead` past the round of the
        // last completed signature, or the highest completed round, whichever is greater.
        let last_completed_signature = self.completed_sigs.last_key_value().map(|(r, _)| *r);
        let last_completed_round = std::cmp::max(last_completed_signature, highest_completed_round)
            .unwrap_or(RandomnessRound(0));
        if round.0
            >= last_completed_round
                .0
                .saturating_add(self.config.max_partial_sigs_rounds_ahead())
        {
            debug!(
                "skipping received partial sigs, most recent round we completed was only {last_completed_round}",
            );
            return;
        }

        // Deserialize the partial sigs.
        let partial_sigs =
            match sig_bytes
                .iter()
                .try_fold(Vec::new(), |mut acc, bytes| -> Result<_> {
                    let sig: RandomnessPartialSignature = bcs::from_bytes(bytes)?;
                    acc.push(sig);
                    Ok(acc)
                }) {
                Ok(partial_sigs) => partial_sigs,
                Err(e) => {
                    warn!("failed to deserialize partial sigs: {e:?}");
                    return;
                }
            };
        // Verify we received the expected share IDs (to protect against a validator that sends
        // another validator's valid signatures, which would otherwise verify successfully below).
        let received_share_ids = partial_sigs.iter().map(|s| s.index);
        if received_share_ids
            .zip(expected_share_ids.iter())
            .any(|(a, b)| a != *b)
        {
            let received_share_ids = partial_sigs.iter().map(|s| s.index).collect::<Vec<_>>();
            warn!(
                "received partial sigs with wrong share ids: expected {expected_share_ids:?}, received {received_share_ids:?}"
            );
            return;
        }

        // We passed all the checks; save the partial sigs.
        debug!("recording received partial signatures");
        self.received_partial_sigs
            .insert((round, peer_id), partial_sigs);

        self.maybe_aggregate_partial_signatures(epoch, round);
    }

    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
    fn maybe_aggregate_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
        if let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
            && round <= *highest_completed_round
        {
            info!("skipping aggregation for already-completed round");
            return;
        }

        let highest_requested_round = self.highest_requested_round.get(&epoch);
        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
            // We have to wait here, because even if we have enough information from other nodes
            // to complete the signature, local shared object versions are not set until consensus
            // finishes processing the corresponding commit. This function will be called again
            // after maybe_start_pending_tasks begins this round locally.
            debug!(
                "waiting to aggregate randomness partial signatures until local consensus catches up"
            );
            return;
        }

        if epoch != self.epoch {
            debug!(
                "waiting to aggregate randomness partial signatures until DKG completes for epoch"
            );
            return;
        }

        if self.completed_sigs.contains_key(&round) {
            info!("skipping aggregation for already-completed signature");
            return;
        }

        let vss_pk = {
            let Some(dkg_output) = &self.dkg_output else {
                debug!("called maybe_aggregate_partial_signatures before DKG completed");
                return;
            };
            &dkg_output.vss_pk
        };

        let sig_bounds = (
            Bound::Included((round, PeerId([0; 32]))),
            Bound::Excluded((round + 1, PeerId([0; 32]))),
        );

        // If we have enough partial signatures, aggregate them.
        let sig_range = self
            .received_partial_sigs
            .range(sig_bounds)
            .flat_map(|(_, sigs)| sigs);
        let mut sig =
            match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
                Ok(sig) => sig,
                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
                Err(e) => {
                    error!("error while aggregating randomness partial signatures: {e:?}");
                    return;
                }
            };

        // Try to verify the aggregated signature all at once. (Should work in the happy path.)
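        // (`vss_pk.c0()` is the group public key: the constant term of the VSS polynomial
        // produced by the DKG.)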
        if ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig).is_err() {
            // If verification fails, some of the inputs must be invalid. We have to go through
            // them one by one to find out which.
            // TODO: add test for individual sig verification.
            self.received_partial_sigs
                .retain(|&(r, peer_id), partial_sigs| {
                    if round != r {
                        return true;
                    }
                    if ThresholdBls12381MinSig::partial_verify_batch(
                        vss_pk,
                        &round.signature_message(),
                        partial_sigs.iter(),
                        &mut rand::thread_rng(),
                    )
                    .is_err()
                    {
                        warn!(
                            "received invalid partial signatures from possibly-Byzantine peer {peer_id}"
                        );
                        if let Some(sender) = self.mailbox_sender.upgrade() {
                            sender.try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(
                                epoch,
                                peer_id,
                            ))
                            .expect("RandomnessEventLoop mailbox should not overflow or be closed");
                        }
                        return false;
                    }
                    true
                });
            let sig_range = self
                .received_partial_sigs
                .range(sig_bounds)
                .flat_map(|(_, sigs)| sigs);
            sig = match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
                Ok(sig) => sig,
                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
                Err(e) => {
                    error!("error while aggregating randomness partial signatures: {e:?}");
                    return;
                }
            };
            if let Err(e) =
                ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
            {
                error!(
                    "error while verifying randomness partial signatures after removing invalid partials: {e:?}"
                );
                debug_assert!(
                    false,
                    "error while verifying randomness partial signatures after removing invalid partials"
                );
                return;
            }
        }

        debug!("successfully generated randomness full signature");
        self.process_valid_full_signature(epoch, round, sig);
    }

    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
    fn receive_full_signature(
        &mut self,
        peer_id: PeerId,
        epoch: EpochId,
        round: RandomnessRound,
        sig: RandomnessSignature,
    ) {
        let vss_pk = {
            let Some(dkg_output) = &self.dkg_output else {
                debug!("called receive_full_signature before DKG completed");
                return;
            };
            &dkg_output.vss_pk
        };

        // Basic validity checks.
        if epoch != self.epoch {
            debug!("skipping received full sig, we are on epoch {}", self.epoch);
            return;
        }
        if self.completed_sigs.contains_key(&round) {
            debug!("skipping received full sig, we have already completed this sig");
            return;
        }
        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
        if let Some(highest_completed_round) = &highest_completed_round
            && *highest_completed_round >= round
        {
            debug!("skipping received full sig, we have already completed this round");
            return;
        }

        let highest_requested_round = self.highest_requested_round.get(&epoch);
        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
            // Wait for local consensus to catch up if necessary.
            debug!(
                "skipping received full signature, local consensus is not caught up to its round"
            );
            return;
        }

        if let Err(e) =
            ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
        {
            info!("received invalid full signature from peer {peer_id}: {e:?}");
            if let Some(sender) = self.mailbox_sender.upgrade() {
                sender
                    .try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id))
                    .expect("RandomnessEventLoop mailbox should not overflow or be closed");
            }
            return;
        }

        debug!("received valid randomness full signature");
        self.process_valid_full_signature(epoch, round, sig);
    }

    fn process_valid_full_signature(
        &mut self,
        epoch: EpochId,
        round: RandomnessRound,
        sig: RandomnessSignature,
    ) {
        assert_eq!(epoch, self.epoch);

        if let Some((_, full_sig_cell)) = self.send_tasks.get(&round) {
            full_sig_cell
                .set(sig)
                .expect("full signature should never be processed twice");
        }
        self.completed_sigs.insert(round, sig);
        self.remove_partial_sigs_in_range((
            Bound::Included((round, PeerId([0; 32]))),
            Bound::Excluded((round + 1, PeerId([0; 32]))),
        ));
        self.metrics.record_completed_round(round);
        if let Some(start_time) = self.round_request_time.get(&(epoch, round))
            && let Some(metric) = self.metrics.round_generation_latency_metric()
        {
            metric.observe(start_time.elapsed().as_secs_f64());
        }

        let sig_bytes = bcs::to_bytes(&sig).expect("signature serialization should not fail");
        self.randomness_tx
            .try_send((epoch, round, sig_bytes))
            .expect("RandomnessRoundReceiver mailbox should not overflow or be closed");
    }

    fn maybe_ignore_byzantine_peer(&mut self, epoch: EpochId, peer_id: PeerId) {
        if epoch != self.epoch {
            return; // make sure we're still on the same epoch
        }
        let Some(dkg_output) = &self.dkg_output else {
            return; // can't ignore a peer if we haven't finished DKG
        };
        if !self.allowed_peers_set.contains(&peer_id) {
            return; // peer is already disallowed
        }
        let Some(peer_share_ids) = &self.peer_share_ids else {
            return; // can't ignore a peer if we haven't finished DKG
        };
        let Some(peer_shares) = peer_share_ids.get(&peer_id) else {
            warn!("can't ignore unknown byzantine peer {peer_id:?}");
            return;
        };
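        // Cap the total share weight we are willing to ignore; ignoring too much weight
        // could otherwise leave too few honest shares to reach the aggregation threshold.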
        let max_ignored_shares = (self.config.max_ignored_peer_weight_factor()
            * (dkg_output.nodes.total_weight() as f64)) as usize;
        if self.blocked_share_id_count + peer_shares.len() > max_ignored_shares {
            warn!(
                "ignoring byzantine peer {peer_id:?} with {} shares would exceed max ignored peer weight {max_ignored_shares}",
                peer_shares.len()
            );
            return;
        }

        warn!(
            "ignoring byzantine peer {peer_id:?} with {} shares",
            peer_shares.len()
        );
        self.blocked_share_id_count += peer_shares.len();
        self.allowed_peers_set.remove(&peer_id);
        self.allowed_peers
            .update(Arc::new(self.allowed_peers_set.clone()));
        self.metrics.inc_num_ignored_byzantine_peers();
    }

    fn maybe_start_pending_tasks(&mut self) {
        let Some(dkg_output) = &self.dkg_output else {
            return; // wait for DKG
        };
        let Some(shares) = &dkg_output.shares else {
            return; // can't participate in randomness generation without shares
        };
        let Some(highest_requested_round) = self.highest_requested_round.get(&self.epoch) else {
            return; // no rounds to start
        };
        // Begin from the next round after the most recent one we've started (or, if none are
        // running, after the highest completed round in the epoch).
        let start_round = std::cmp::max(
            if let Some(highest_completed_round) = self.highest_completed_round.get(&self.epoch) {
                highest_completed_round.checked_add(1).unwrap()
            } else {
                RandomnessRound(0)
            },
            self.send_tasks
                .last_key_value()
                .map(|(r, _)| r.checked_add(1).unwrap())
                .unwrap_or(RandomnessRound(0)),
        );

        let mut rounds_to_aggregate = Vec::new();
        for round in start_round.0..=highest_requested_round.0 {
            let round = RandomnessRound(round);

            if self.send_tasks.len() >= self.config.max_partial_sigs_concurrent_sends() {
                break; // limit concurrent tasks
            }

            let full_sig_cell = Arc::new(OnceCell::new());
            self.send_tasks.entry(round).or_insert_with(|| {
                let name = self.name;
                let network = self.network.clone();
                let retry_interval = self.config.partial_signature_retry_interval();
                let metrics = self.metrics.clone();
                let authority_info = self.authority_info.clone();
                let epoch = self.epoch;
                let partial_sigs = ThresholdBls12381MinSig::partial_sign_batch(
                    shares.iter(),
                    &round.signature_message(),
                );
                let full_sig_cell_clone = full_sig_cell.clone();

                // Record own partial sigs.
                if !self.completed_sigs.contains_key(&round) {
                    self.received_partial_sigs
                        .insert((round, self.network.peer_id()), partial_sigs.clone());
                    rounds_to_aggregate.push((epoch, round));
                }

                debug!("sending partial sigs for epoch {epoch}, round {round}");
                (
                    spawn_monitored_task!(RandomnessEventLoop::send_signatures_task(
                        name,
                        network,
                        retry_interval,
                        metrics,
                        authority_info,
                        epoch,
                        round,
                        partial_sigs,
                        full_sig_cell_clone,
                    )),
                    full_sig_cell,
                )
            });
        }

        self.update_rounds_pending_metric();

        // After starting a round, we have generated our own partial sigs. Check if that's
        // enough for us to aggregate already.
        for (epoch, round) in rounds_to_aggregate {
            self.maybe_aggregate_partial_signatures(epoch, round);
        }
    }

    #[allow(clippy::type_complexity)]
    fn remove_partial_sigs_in_range(
        &mut self,
        range: (
            Bound<(RandomnessRound, PeerId)>,
            Bound<(RandomnessRound, PeerId)>,
        ),
    ) {
        let keys_to_remove: Vec<_> = self
            .received_partial_sigs
            .range(range)
            .map(|(key, _)| *key)
            .collect();
        for key in keys_to_remove {
            // Have to remove keys one-by-one because BTreeMap does not support range-removal.
            self.received_partial_sigs.remove(&key);
        }
    }

    async fn send_signatures_task(
        name: AuthorityName,
        network: anemo::Network,
        retry_interval: Duration,
        metrics: Metrics,
        authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
        epoch: EpochId,
        round: RandomnessRound,
        partial_sigs: Vec<RandomnessPartialSignature>,
        full_sig: Arc<OnceCell<RandomnessSignature>>,
    ) {
        // For simtests, we may test not sending partial signatures.
        #[allow(unused_mut)]
        let mut fail_point_skip_sending = false;
        fail_point_if!("rb-send-partial-signatures", || {
            fail_point_skip_sending = true;
        });
        if fail_point_skip_sending {
            warn!("skipping sending partial sigs due to simtest fail point");
            return;
        }

        let _metrics_guard = metrics
            .round_observation_latency_metric()
            .map(|metric| metric.start_timer());

        let peers: HashMap<_, _> = authority_info
            .iter()
            .map(|(name, (peer_id, _party_id))| (name, network.waiting_peer(*peer_id)))
            .collect();
        let partial_sigs: Vec<_> = partial_sigs
            .iter()
            .map(|sig| bcs::to_bytes(sig).expect("message serialization should not fail"))
            .collect();

        loop {
            let mut requests = Vec::new();
            for (peer_name, peer) in &peers {
                if name == **peer_name {
                    continue; // don't send partial sigs to self
                }
                let mut client = RandomnessClient::new(peer.clone());
                const SEND_PARTIAL_SIGNATURES_TIMEOUT: Duration = Duration::from_secs(10);
                let full_sig = full_sig.get().cloned();
                let request = anemo::Request::new(SendSignaturesRequest {
                    epoch,
                    round,
                    partial_sigs: if full_sig.is_none() {
                        partial_sigs.clone()
                    } else {
                        Vec::new()
                    },
                    sig: full_sig,
                })
                .with_timeout(SEND_PARTIAL_SIGNATURES_TIMEOUT);
                requests.push(async move {
                    let result = client.send_signatures(request).await;
                    if let Err(_error) = result {
                        // TODO: add Display impl to anemo::rpc::Status, log it here
                        debug!("failed to send partial signatures to {peer_name}");
                    }
                });
            }

            // Process all requests.
            futures::future::join_all(requests).await;

            // Keep retrying send to all peers until task is aborted via external message.
            tokio::time::sleep(retry_interval).await;
        }
    }

    fn update_rounds_pending_metric(&self) {
        let highest_requested_round = self
            .highest_requested_round
            .get(&self.epoch)
            .map(|r| r.0)
            .unwrap_or(0);
        let highest_completed_round = self
            .highest_completed_round
            .get(&self.epoch)
            .map(|r| r.0)
            .unwrap_or(0);
        let num_rounds_pending =
            highest_requested_round.saturating_sub(highest_completed_round) as i64;
        let prev_value = self.metrics.num_rounds_pending().unwrap_or_default();
        if num_rounds_pending / 100 > prev_value / 100 {
            warn!(
                // Recording multiples of 100 so tests can match on the log message.
                "RandomnessEventLoop randomness generation backlog: over {} rounds are pending (oldest is {:?})",
                (num_rounds_pending / 100) * 100,
                highest_completed_round + 1,
            );
        }
        self.metrics.set_num_rounds_pending(num_rounds_pending);
    }

    fn admin_get_partial_signatures(&self, round: RandomnessRound, tx: oneshot::Sender<Vec<u8>>) {
        let Some(shares) = self.dkg_output.as_ref().and_then(|d| d.shares.as_ref()) else {
            let _ = tx.send(Vec::new()); // no error handling needed if receiver is already dropped
            return;
        };

        let partial_sigs =
            ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), &round.signature_message());
        // no error handling needed if receiver is already dropped
        let _ = tx.send(bcs::to_bytes(&partial_sigs).expect("serialization should not fail"));
    }

    fn admin_inject_partial_signatures(
        &mut self,
        authority_name: AuthorityName,
        round: RandomnessRound,
        sigs: Vec<RandomnessPartialSignature>,
    ) -> Result<()> {
        let peer_id = self
            .authority_info
            .get(&authority_name)
            .map(|(peer_id, _)| *peer_id)
            .ok_or(anyhow::anyhow!("unknown AuthorityName {authority_name:?}"))?;
        self.received_partial_sigs.insert((round, peer_id), sigs);
        self.maybe_aggregate_partial_signatures(self.epoch, round);
        Ok(())
    }

    fn admin_inject_full_signature(
        &mut self,
        round: RandomnessRound,
        sig: RandomnessSignature,
    ) -> Result<()> {
        let vss_pk = {
            let Some(dkg_output) = &self.dkg_output else {
                return Err(anyhow::anyhow!(
                    "called admin_inject_full_signature before DKG completed"
                ));
            };
            &dkg_output.vss_pk
        };

        ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
            .map_err(|e| anyhow::anyhow!("invalid full signature: {e:?}"))?;

        self.process_valid_full_signature(self.epoch, round, sig);
        Ok(())
    }
}