sui_network/randomness/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use self::{auth::AllowedPeersUpdatable, metrics::Metrics};
5use anemo::PeerId;
6use anyhow::Result;
7use fastcrypto::groups::bls12381;
8use fastcrypto_tbls::{
9    dkg_v1,
10    nodes::PartyId,
11    tbls::ThresholdBls,
12    types::{ShareIndex, ThresholdBls12381MinSig},
13};
14use mysten_common::ZipDebugEqIteratorExt;
15use mysten_metrics::spawn_monitored_task;
16use mysten_network::anemo_ext::NetworkExt;
17use serde::{Deserialize, Serialize};
18use std::{
19    collections::{HashMap, HashSet, btree_map::BTreeMap},
20    ops::Bound,
21    sync::Arc,
22    time::{self, Duration},
23};
24use sui_config::p2p::RandomnessConfig;
25use sui_macros::fail_point_if;
26use sui_types::{
27    base_types::AuthorityName,
28    committee::EpochId,
29    crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
30};
31use tokio::sync::{
32    OnceCell, {mpsc, oneshot},
33};
34use tracing::{debug, error, info, instrument, warn};
35
// Peer-authorization support (see `AllowedPeersUpdatable`) restricting which
// peers may use this service.
mod auth;
// Builder used to construct and start the randomness subsystem.
mod builder;
// Code generated at build time for the anemo `sui.Randomness` RPC service.
mod generated {
    include!(concat!(env!("OUT_DIR"), "/sui.Randomness.rs"));
}
// Prometheus metrics for the randomness subsystem.
mod metrics;
// Server-side handler for incoming randomness RPCs.
mod server;
#[cfg(test)]
mod tests;

pub use builder::{Builder, UnstartedRandomness};
pub use generated::{
    randomness_client::RandomnessClient,
    randomness_server::{Randomness, RandomnessServer},
};
51
/// Wire request exchanged between peers carrying a validator's partial signatures
/// (or an already-completed full signature) for one randomness round.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SendSignaturesRequest {
    // Epoch the signatures belong to; receivers drop messages from stale epochs.
    epoch: EpochId,
    // Randomness round the signatures are for.
    round: RandomnessRound,
    // BCS-serialized `RandomnessPartialSignature` values. We store raw bytes here to enable
    // defenses against too-large messages.
    // The protocol requires the signatures to be ordered by share index (as provided by fastcrypto).
    partial_sigs: Vec<Vec<u8>>,
    // If peer already has a full signature available for the round, it's provided here in lieu
    // of partial sigs.
    sig: Option<RandomnessSignature>,
}
64
/// A handle to the Randomness network subsystem.
///
/// This handle can be cloned and shared. Once all copies of a Randomness system's Handle have been
/// dropped, the Randomness system will be gracefully shutdown.
#[derive(Clone, Debug)]
pub struct Handle {
    // Mailbox into the `RandomnessEventLoop`; the loop exits once every sender is dropped.
    sender: mpsc::Sender<RandomnessMessage>,
}
73
74impl Handle {
75    /// Transitions the Randomness system to a new epoch. Cancels all partial signature sends for
76    /// prior epochs.
77    pub fn update_epoch(
78        &self,
79        new_epoch: EpochId,
80        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
81        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
82        aggregation_threshold: u16,
83        recovered_last_completed_round: Option<RandomnessRound>, // set to None if not starting up mid-epoch
84    ) {
85        self.sender
86            .try_send(RandomnessMessage::UpdateEpoch(
87                new_epoch,
88                authority_info,
89                dkg_output,
90                aggregation_threshold,
91                recovered_last_completed_round,
92            ))
93            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
94    }
95
96    /// Begins transmitting partial signatures for the given epoch and round until completed.
97    pub fn send_partial_signatures(&self, epoch: EpochId, round: RandomnessRound) {
98        self.sender
99            .try_send(RandomnessMessage::SendPartialSignatures(epoch, round))
100            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
101    }
102
103    /// Records the given round as complete, stopping any partial signature sends.
104    pub fn complete_round(&self, epoch: EpochId, round: RandomnessRound) {
105        self.sender
106            .try_send(RandomnessMessage::CompleteRound(epoch, round))
107            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
108    }
109
110    /// Admin interface handler: generates partial signatures for the given round at the
111    /// current epoch.
112    pub fn admin_get_partial_signatures(
113        &self,
114        round: RandomnessRound,
115        tx: oneshot::Sender<Vec<u8>>,
116    ) {
117        self.sender
118            .try_send(RandomnessMessage::AdminGetPartialSignatures(round, tx))
119            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
120    }
121
122    /// Admin interface handler: injects partial signatures for the given round at the
123    /// current epoch, skipping validity checks.
124    pub fn admin_inject_partial_signatures(
125        &self,
126        authority_name: AuthorityName,
127        round: RandomnessRound,
128        sigs: Vec<RandomnessPartialSignature>,
129        result_channel: oneshot::Sender<Result<()>>,
130    ) {
131        self.sender
132            .try_send(RandomnessMessage::AdminInjectPartialSignatures(
133                authority_name,
134                round,
135                sigs,
136                result_channel,
137            ))
138            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
139    }
140
141    /// Admin interface handler: injects full signature for the given round at the
142    /// current epoch, skipping validity checks.
143    pub fn admin_inject_full_signature(
144        &self,
145        round: RandomnessRound,
146        sig: RandomnessSignature,
147        result_channel: oneshot::Sender<Result<()>>,
148    ) {
149        self.sender
150            .try_send(RandomnessMessage::AdminInjectFullSignature(
151                round,
152                sig,
153                result_channel,
154            ))
155            .expect("RandomnessEventLoop mailbox should not overflow or be closed")
156    }
157
158    // For testing.
159    pub fn new_stub() -> Self {
160        let (sender, mut receiver) = mpsc::channel(1);
161        // Keep receiver open until all senders are closed.
162        tokio::spawn(async move {
163            loop {
164                tokio::select! {
165                    m = receiver.recv() => {
166                        if m.is_none() {
167                            break;
168                        }
169                    },
170                }
171            }
172        });
173        Self { sender }
174    }
175}
176
/// Internal messages delivered to the `RandomnessEventLoop` mailbox.
#[derive(Debug)]
enum RandomnessMessage {
    // Switches the loop to a new epoch, supplying the DKG output needed to
    // verify and aggregate that epoch's partial signatures.
    UpdateEpoch(
        EpochId,
        HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        u16,                     // aggregation_threshold
        Option<RandomnessRound>, // recovered_highest_completed_round
    ),
    // Starts sending our partial sigs for the given round until it completes.
    SendPartialSignatures(EpochId, RandomnessRound),
    // Marks the round complete, aborting its send task and pruning round state.
    CompleteRound(EpochId, RandomnessRound),
    // BCS-encoded partial sigs and/or a full signature received from a peer.
    ReceiveSignatures(
        PeerId,
        EpochId,
        RandomnessRound,
        Vec<Vec<u8>>,
        Option<RandomnessSignature>,
    ),
    // Possibly blocks a peer that sent invalid signatures (subject to a
    // configured cap on total ignored share weight).
    MaybeIgnoreByzantinePeer(EpochId, PeerId),
    // Admin/debug operations; see the corresponding `Handle` methods.
    AdminGetPartialSignatures(RandomnessRound, oneshot::Sender<Vec<u8>>),
    AdminInjectPartialSignatures(
        AuthorityName,
        RandomnessRound,
        Vec<RandomnessPartialSignature>,
        oneshot::Sender<Result<()>>,
    ),
    AdminInjectFullSignature(
        RandomnessRound,
        RandomnessSignature,
        oneshot::Sender<Result<()>>,
    ),
}
209
/// State owned by the randomness protocol event loop (driven by `start`).
struct RandomnessEventLoop {
    name: AuthorityName,
    config: RandomnessConfig,
    // Receiving end of the mailbox fed by `Handle` and the network server.
    mailbox: mpsc::Receiver<RandomnessMessage>,
    // Weak sender used to self-enqueue messages (e.g. MaybeIgnoreByzantinePeer)
    // without keeping the mailbox alive.
    mailbox_sender: mpsc::WeakSender<RandomnessMessage>,
    network: anemo::Network,
    // Updatable filter of peers allowed to send us randomness messages.
    allowed_peers: AllowedPeersUpdatable,
    allowed_peers_set: HashSet<PeerId>,
    metrics: Metrics,
    // Output channel delivering completed (epoch, round, BCS sig bytes) triples.
    randomness_tx: mpsc::Sender<(EpochId, RandomnessRound, Vec<u8>)>,

    // --- Per-epoch state below; reset/pruned by `update_epoch`. ---
    epoch: EpochId,
    authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
    // Share indices owned by each peer; None until DKG output is installed.
    peer_share_ids: Option<HashMap<PeerId, Vec<ShareIndex>>>,
    // Total number of shares belonging to peers we have blocked as byzantine.
    blocked_share_id_count: usize,
    dkg_output: Option<dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>>,
    aggregation_threshold: u16,
    // Highest round requested locally per epoch; aggregation waits for this.
    highest_requested_round: BTreeMap<EpochId, RandomnessRound>,
    // In-flight per-round send tasks, each with a cell that receives the full
    // signature once the round completes.
    send_tasks: BTreeMap<
        RandomnessRound,
        (
            tokio::task::JoinHandle<()>,
            Arc<OnceCell<RandomnessSignature>>,
        ),
    >,
    // When each round was requested; used for generation-latency metrics.
    round_request_time: BTreeMap<(EpochId, RandomnessRound), time::Instant>,
    // Partial sigs buffered because they arrived before we had DKG output for
    // their epoch.
    future_epoch_partial_sigs: BTreeMap<(EpochId, RandomnessRound, PeerId), Vec<Vec<u8>>>,
    received_partial_sigs: BTreeMap<(RandomnessRound, PeerId), Vec<RandomnessPartialSignature>>,
    completed_sigs: BTreeMap<RandomnessRound, RandomnessSignature>,
    highest_completed_round: BTreeMap<EpochId, RandomnessRound>,
}
241
242impl RandomnessEventLoop {
243    pub async fn start(mut self) {
244        info!("Randomness network event loop started");
245
246        loop {
247            tokio::select! {
248                maybe_message = self.mailbox.recv() => {
249                    // Once all handles to our mailbox have been dropped this
250                    // will yield `None` and we can terminate the event loop.
251                    if let Some(message) = maybe_message {
252                        self.handle_message(message);
253                    } else {
254                        break;
255                    }
256                },
257            }
258        }
259
260        info!("Randomness network event loop ended");
261    }
262
    /// Dispatches a single mailbox message to the appropriate handler.
    /// Admin-injection handlers report their result back over the provided oneshot.
    fn handle_message(&mut self, message: RandomnessMessage) {
        match message {
            RandomnessMessage::UpdateEpoch(
                epoch,
                authority_info,
                dkg_output,
                aggregation_threshold,
                recovered_highest_completed_round,
            ) => {
                if let Err(e) = self.update_epoch(
                    epoch,
                    authority_info,
                    dkg_output,
                    aggregation_threshold,
                    recovered_highest_completed_round,
                ) {
                    error!("BUG: failed to update epoch in RandomnessEventLoop: {e:?}");
                }
            }
            RandomnessMessage::SendPartialSignatures(epoch, round) => {
                self.send_partial_signatures(epoch, round)
            }
            RandomnessMessage::CompleteRound(epoch, round) => self.complete_round(epoch, round),
            RandomnessMessage::ReceiveSignatures(peer_id, epoch, round, partial_sigs, sig) => {
                // A full signature, when present, supersedes any partial sigs in the message.
                if let Some(sig) = sig {
                    self.receive_full_signature(peer_id, epoch, round, sig)
                } else {
                    self.receive_partial_signatures(peer_id, epoch, round, partial_sigs)
                }
            }
            RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id) => {
                self.maybe_ignore_byzantine_peer(epoch, peer_id)
            }
            RandomnessMessage::AdminGetPartialSignatures(round, tx) => {
                self.admin_get_partial_signatures(round, tx)
            }
            RandomnessMessage::AdminInjectPartialSignatures(
                authority_name,
                round,
                sigs,
                result_channel,
            ) => {
                // Ignore send failure: the admin caller may have gone away.
                let _ = result_channel.send(self.admin_inject_partial_signatures(
                    authority_name,
                    round,
                    sigs,
                ));
            }
            RandomnessMessage::AdminInjectFullSignature(round, sig, result_channel) => {
                let _ = result_channel.send(self.admin_inject_full_signature(round, sig));
            }
        }
    }
316
    /// Installs DKG output and the authority↔peer mapping for `new_epoch`, aborts all
    /// in-flight send tasks, prunes state from older epochs, and re-processes any
    /// partial signatures that were buffered early for the new epoch.
    ///
    /// `recovered_highest_completed_round` seeds the completed-round watermark when
    /// restarting mid-epoch.
    ///
    /// # Panics
    /// Panics (via `assert!`) if called with a non-increasing epoch after DKG output
    /// has already been installed.
    #[instrument(level = "debug", skip_all, fields(?new_epoch))]
    fn update_epoch(
        &mut self,
        new_epoch: EpochId,
        authority_info: HashMap<AuthorityName, (PeerId, PartyId)>,
        dkg_output: dkg_v1::Output<bls12381::G2Element, bls12381::G2Element>,
        aggregation_threshold: u16,
        recovered_highest_completed_round: Option<RandomnessRound>,
    ) -> Result<()> {
        // Epochs must be strictly increasing once DKG output has been installed.
        assert!(self.dkg_output.is_none() || new_epoch > self.epoch);

        debug!("updating randomness network loop to new epoch");

        // Precompute each peer's share indices from the DKG node set.
        self.peer_share_ids = Some(authority_info.iter().try_fold(
            HashMap::new(),
            |mut acc, (_name, (peer_id, party_id))| -> Result<_> {
                let ids = dkg_output
                    .nodes
                    .share_ids_of(*party_id)
                    .expect("party_id should be valid");
                acc.insert(*peer_id, ids);
                Ok(acc)
            },
        )?);
        self.allowed_peers_set = authority_info
            .values()
            .map(|(peer_id, _)| *peer_id)
            .collect();
        self.allowed_peers
            .update(Arc::new(self.allowed_peers_set.clone()));
        self.epoch = new_epoch;
        self.authority_info = Arc::new(authority_info);
        self.dkg_output = Some(dkg_output);
        self.aggregation_threshold = aggregation_threshold;
        if let Some(round) = recovered_highest_completed_round {
            self.highest_completed_round
                .entry(new_epoch)
                .and_modify(|r| *r = std::cmp::max(*r, round))
                .or_insert(round);
        }
        // Abort all send tasks left over from the previous epoch.
        for (_, (task, _)) in std::mem::take(&mut self.send_tasks) {
            task.abort();
        }
        self.metrics.set_epoch(new_epoch);

        // Throw away info from old epochs.
        self.highest_requested_round = self.highest_requested_round.split_off(&new_epoch);
        self.round_request_time = self
            .round_request_time
            .split_off(&(new_epoch, RandomnessRound(0)));
        self.received_partial_sigs.clear();
        self.completed_sigs.clear();
        self.highest_completed_round = self.highest_completed_round.split_off(&new_epoch);

        // Start any pending tasks for the new epoch.
        self.maybe_start_pending_tasks();

        // Aggregate any sigs received early from the new epoch.
        // (We can't call `maybe_aggregate_partial_signatures` directly while iterating,
        // because it takes `&mut self`, so we store in a Vec first.)
        for ((epoch, round, peer_id), sig_bytes) in
            std::mem::take(&mut self.future_epoch_partial_sigs)
        {
            // We can fully validate these now that we have current epoch DKG output.
            self.receive_partial_signatures(peer_id, epoch, round, sig_bytes);
        }
        let rounds_to_aggregate: Vec<_> =
            self.received_partial_sigs.keys().map(|(r, _)| *r).collect();
        for round in rounds_to_aggregate {
            self.maybe_aggregate_partial_signatures(new_epoch, round);
        }

        Ok(())
    }
391
392    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
393    fn send_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
394        if epoch < self.epoch {
395            error!(
396                "BUG: skipping sending partial sigs, we are already up to epoch {}",
397                self.epoch
398            );
399            debug_assert!(
400                false,
401                "skipping sending partial sigs, we are already up to higher epoch"
402            );
403            return;
404        }
405        if epoch == self.epoch
406            && let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
407            && round <= *highest_completed_round
408        {
409            info!("skipping sending partial sigs, we already have completed this round");
410            return;
411        }
412
413        self.highest_requested_round
414            .entry(epoch)
415            .and_modify(|r| *r = std::cmp::max(*r, round))
416            .or_insert(round);
417        self.round_request_time
418            .insert((epoch, round), time::Instant::now());
419        self.maybe_start_pending_tasks();
420    }
421
    /// Marks (`epoch`, `round`) as completed: updates the completed-round watermark,
    /// prunes state for this round and all earlier ones, and aborts their send tasks.
    /// Out-of-order completions behind the current watermark are no-ops.
    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
    fn complete_round(&mut self, epoch: EpochId, round: RandomnessRound) {
        debug!("completing randomness round");
        let new_highest_round = *self
            .highest_completed_round
            .entry(epoch)
            .and_modify(|r| *r = std::cmp::max(*r, round))
            .or_insert(round);
        if round != new_highest_round {
            // This round completion came out of order, and we're already ahead. Nothing more
            // to do in that case.
            return;
        }

        // Drop request timestamps for this round and everything before it.
        self.round_request_time = self.round_request_time.split_off(&(epoch, round + 1));

        if epoch == self.epoch {
            // Prune partial/completed sigs and abort send tasks up through `round`.
            self.remove_partial_sigs_in_range((
                Bound::Included((RandomnessRound(0), PeerId([0; 32]))),
                Bound::Excluded((round + 1, PeerId([0; 32]))),
            ));
            self.completed_sigs = self.completed_sigs.split_off(&(round + 1));
            for (_, (task, _)) in self.send_tasks.iter().take_while(|(r, _)| **r <= round) {
                task.abort();
            }
            self.send_tasks = self.send_tasks.split_off(&(round + 1));
            self.maybe_start_pending_tasks();
        }

        self.update_rounds_pending_metric();
    }
453
    /// Validates and records BCS-encoded partial signatures received from `peer_id`
    /// for (`epoch`, `round`), then attempts aggregation.
    ///
    /// Sigs for the next epoch (or received before our own DKG output is available)
    /// are buffered in `future_epoch_partial_sigs` for later processing; sigs that
    /// fail shape or share-ID checks are dropped with a log message.
    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
    fn receive_partial_signatures(
        &mut self,
        peer_id: PeerId,
        epoch: EpochId,
        round: RandomnessRound,
        sig_bytes: Vec<Vec<u8>>,
    ) {
        // Basic validity checks.
        if epoch < self.epoch {
            debug!(
                "skipping received partial sigs, we are already up to epoch {}",
                self.epoch
            );
            return;
        }
        if epoch > self.epoch + 1 {
            debug!(
                "skipping received partial sigs, we are still on epoch {}",
                self.epoch
            );
            return;
        }
        if epoch == self.epoch && self.completed_sigs.contains_key(&round) {
            debug!("skipping received partial sigs, we already have completed this sig");
            return;
        }
        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
        if let Some(highest_completed_round) = &highest_completed_round
            && *highest_completed_round >= round
        {
            debug!("skipping received partial sigs, we already have completed this round");
            return;
        }

        // If sigs are for a future epoch, we can't fully verify them without DKG output.
        // Save them for later use.
        if epoch != self.epoch || self.peer_share_ids.is_none() {
            // Bound the buffer by refusing rounds too far in the future.
            if round.0 >= self.config.max_partial_sigs_rounds_ahead() {
                debug!("skipping received partial sigs for future epoch, round too far ahead",);
                return;
            }

            debug!("saving partial sigs from future epoch for later use");
            self.future_epoch_partial_sigs
                .insert((epoch, round, peer_id), sig_bytes);
            return;
        }

        // Verify shape of sigs matches what we expect for the peer.
        let peer_share_ids = self.peer_share_ids.as_ref().expect("checked above");
        let expected_share_ids = if let Some(expected_share_ids) = peer_share_ids.get(&peer_id) {
            expected_share_ids
        } else {
            debug!("received partial sigs from unknown peer");
            return;
        };
        if sig_bytes.len() != expected_share_ids.len() as usize {
            warn!(
                "received partial sigs with wrong share ids count: expected {}, got {}",
                expected_share_ids.len(),
                sig_bytes.len(),
            );
            return;
        }

        // Accept partial signatures up to `max_partial_sigs_rounds_ahead` past the round of the
        // last completed signature, or the highest completed round, whichever is greater.
        let last_completed_signature = self.completed_sigs.last_key_value().map(|(r, _)| *r);
        let last_completed_round = std::cmp::max(last_completed_signature, highest_completed_round)
            .unwrap_or(RandomnessRound(0));
        if round.0
            >= last_completed_round
                .0
                .saturating_add(self.config.max_partial_sigs_rounds_ahead())
        {
            debug!(
                "skipping received partial sigs, most recent round we completed was only {last_completed_round}",
            );
            return;
        }

        // Deserialize the partial sigs.
        let partial_sigs =
            match sig_bytes
                .iter()
                .try_fold(Vec::new(), |mut acc, bytes| -> Result<_> {
                    let sig: RandomnessPartialSignature = bcs::from_bytes(bytes)?;
                    acc.push(sig);
                    Ok(acc)
                }) {
                Ok(partial_sigs) => partial_sigs,
                Err(e) => {
                    warn!("failed to deserialize partial sigs: {e:?}");
                    return;
                }
            };
        // Verify we received the expected share IDs (to protect against a validator that sends
        // valid signatures of other peers which will be successfully verified below).
        let received_share_ids = partial_sigs.iter().map(|s| s.index);
        if received_share_ids
            .zip_debug_eq(expected_share_ids.iter())
            .any(|(a, b)| a != *b)
        {
            let received_share_ids = partial_sigs.iter().map(|s| s.index).collect::<Vec<_>>();
            warn!(
                "received partial sigs with wrong share ids: expected {expected_share_ids:?}, received {received_share_ids:?}"
            );
            return;
        }

        // We passed all the checks, save the partial sigs.
        debug!("recording received partial signatures");
        self.received_partial_sigs
            .insert((round, peer_id), partial_sigs);

        self.maybe_aggregate_partial_signatures(epoch, round);
    }
572
    /// Attempts to aggregate received partial signatures for (`epoch`, `round`) into a
    /// full signature, verifying the result before accepting it.
    ///
    /// Returns without effect (to be retried on later calls) if the round is already
    /// complete, local consensus has not yet requested the round, the epoch is not
    /// current, or fewer than `aggregation_threshold` partial sigs are available. If
    /// whole-signature verification fails, invalid contributions are identified
    /// per-peer, removed, and their senders reported as possibly byzantine.
    #[instrument(level = "debug", skip_all, fields(?epoch, ?round))]
    fn maybe_aggregate_partial_signatures(&mut self, epoch: EpochId, round: RandomnessRound) {
        if let Some(highest_completed_round) = self.highest_completed_round.get(&epoch)
            && round <= *highest_completed_round
        {
            info!("skipping aggregation for already-completed round");
            return;
        }

        let highest_requested_round = self.highest_requested_round.get(&epoch);
        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
            // We have to wait here, because even if we have enough information from other nodes
            // to complete the signature, local shared object versions are not set until consensus
            // finishes processing the corresponding commit. This function will be called again
            // after maybe_start_pending_tasks begins this round locally.
            debug!(
                "waiting to aggregate randomness partial signatures until local consensus catches up"
            );
            return;
        }

        if epoch != self.epoch {
            debug!(
                "waiting to aggregate randomness partial signatures until DKG completes for epoch"
            );
            return;
        }

        if self.completed_sigs.contains_key(&round) {
            info!("skipping aggregation for already-completed signature");
            return;
        }

        let vss_pk = {
            let Some(dkg_output) = &self.dkg_output else {
                debug!("called maybe_aggregate_partial_signatures before DKG completed");
                return;
            };
            &dkg_output.vss_pk
        };

        // Key range covering all partial sigs stored for exactly this round.
        let sig_bounds = (
            Bound::Included((round, PeerId([0; 32]))),
            Bound::Excluded((round + 1, PeerId([0; 32]))),
        );

        // If we have enough partial signatures, aggregate them.
        let sig_range = self
            .received_partial_sigs
            .range(sig_bounds)
            .flat_map(|(_, sigs)| sigs);
        let mut sig =
            match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
                Ok(sig) => sig,
                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
                Err(e) => {
                    error!("error while aggregating randomness partial signatures: {e:?}");
                    return;
                }
            };

        // Try to verify the aggregated signature all at once. (Should work in the happy path.)
        if ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig).is_err() {
            // If verification fails, some of the inputs must be invalid. We have to go through
            // one-by-one to find which.
            // TODO: add test for individual sig verification.
            self.received_partial_sigs
                .retain(|&(r, peer_id), partial_sigs| {
                    if round != r {
                        return true;
                    }
                    if ThresholdBls12381MinSig::partial_verify_batch(
                        vss_pk,
                        &round.signature_message(),
                        partial_sigs.iter(),
                        &mut rand::thread_rng(),
                    )
                    .is_err()
                    {
                        warn!(
                            "received invalid partial signatures from possibly-Byzantine peer {peer_id}"
                        );
                        // Self-enqueue so the peer can be blocked outside this `retain` borrow.
                        if let Some(sender) = self.mailbox_sender.upgrade() {
                            sender.try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(
                                epoch,
                                peer_id,
                            ))
                            .expect("RandomnessEventLoop mailbox should not overflow or be closed");
                        }
                        return false;
                    }
                    true
                });
            // Re-aggregate using only the surviving (individually verified) partials.
            let sig_range = self
                .received_partial_sigs
                .range(sig_bounds)
                .flat_map(|(_, sigs)| sigs);
            sig = match ThresholdBls12381MinSig::aggregate(self.aggregation_threshold, sig_range) {
                Ok(sig) => sig,
                Err(fastcrypto::error::FastCryptoError::NotEnoughInputs) => return, // wait for more input
                Err(e) => {
                    error!("error while aggregating randomness partial signatures: {e:?}");
                    return;
                }
            };
            if let Err(e) =
                ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
            {
                error!(
                    "error while verifying randomness partial signatures after removing invalid partials: {e:?}"
                );
                debug_assert!(
                    false,
                    "error while verifying randomness partial signatures after removing invalid partials"
                );
                return;
            }
        }

        debug!("successfully generated randomness full signature");
        self.process_valid_full_signature(epoch, round, sig);
    }
695
    /// Validates a full signature received from `peer_id` for (`epoch`, `round`) and,
    /// if it verifies against our DKG public key, records it as the round's result.
    /// An invalid signature causes the peer to be reported as possibly byzantine.
    #[instrument(level = "debug", skip_all, fields(?peer_id, ?epoch, ?round))]
    fn receive_full_signature(
        &mut self,
        peer_id: PeerId,
        epoch: EpochId,
        round: RandomnessRound,
        sig: RandomnessSignature,
    ) {
        let vss_pk = {
            let Some(dkg_output) = &self.dkg_output else {
                debug!("called receive_full_signature before DKG completed");
                return;
            };
            &dkg_output.vss_pk
        };

        // Basic validity checks.
        if epoch != self.epoch {
            debug!("skipping received full sig, we are on epoch {}", self.epoch);
            return;
        }
        if self.completed_sigs.contains_key(&round) {
            debug!("skipping received full sigs, we already have completed this sig");
            return;
        }
        let highest_completed_round = self.highest_completed_round.get(&epoch).copied();
        if let Some(highest_completed_round) = &highest_completed_round
            && *highest_completed_round >= round
        {
            debug!("skipping received full sig, we already have completed this round");
            return;
        }

        let highest_requested_round = self.highest_requested_round.get(&epoch);
        if highest_requested_round.is_none() || round > *highest_requested_round.unwrap() {
            // Wait for local consensus to catch up if necessary.
            debug!(
                "skipping received full signature, local consensus is not caught up to its round"
            );
            return;
        }

        if let Err(e) =
            ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
        {
            info!("received invalid full signature from peer {peer_id}: {e:?}");
            // Report the sender as possibly byzantine via the event-loop mailbox.
            if let Some(sender) = self.mailbox_sender.upgrade() {
                sender
                    .try_send(RandomnessMessage::MaybeIgnoreByzantinePeer(epoch, peer_id))
                    .expect("RandomnessEventLoop mailbox should not overflow or be closed");
            }
            return;
        }

        debug!("received valid randomness full signature");
        self.process_valid_full_signature(epoch, round, sig);
    }
753
754    fn process_valid_full_signature(
755        &mut self,
756        epoch: EpochId,
757        round: RandomnessRound,
758        sig: RandomnessSignature,
759    ) {
760        assert_eq!(epoch, self.epoch);
761
762        if let Some((_, full_sig_cell)) = self.send_tasks.get(&round) {
763            full_sig_cell
764                .set(sig)
765                .expect("full signature should never be processed twice");
766        }
767        self.completed_sigs.insert(round, sig);
768        self.remove_partial_sigs_in_range((
769            Bound::Included((round, PeerId([0; 32]))),
770            Bound::Excluded((round + 1, PeerId([0; 32]))),
771        ));
772        self.metrics.record_completed_round(round);
773        if let Some(start_time) = self.round_request_time.get(&(epoch, round))
774            && let Some(metric) = self.metrics.round_generation_latency_metric()
775        {
776            metric.observe(start_time.elapsed().as_secs_f64());
777        }
778
779        let sig_bytes = bcs::to_bytes(&sig).expect("signature serialization should not fail");
780        self.randomness_tx
781            .try_send((epoch, round, sig_bytes))
782            .expect("RandomnessRoundReceiver mailbox should not overflow or be closed");
783    }
784
785    fn maybe_ignore_byzantine_peer(&mut self, epoch: EpochId, peer_id: PeerId) {
786        if epoch != self.epoch {
787            return; // make sure we're still on the same epoch
788        }
789        let Some(dkg_output) = &self.dkg_output else {
790            return; // can't ignore a peer if we haven't finished DKG
791        };
792        if !self.allowed_peers_set.contains(&peer_id) {
793            return; // peer is already disallowed
794        }
795        let Some(peer_share_ids) = &self.peer_share_ids else {
796            return; // can't ignore a peer if we haven't finished DKG
797        };
798        let Some(peer_shares) = peer_share_ids.get(&peer_id) else {
799            warn!("can't ignore unknown byzantine peer {peer_id:?}");
800            return;
801        };
802        let max_ignored_shares = (self.config.max_ignored_peer_weight_factor()
803            * (dkg_output.nodes.total_weight() as f64)) as usize;
804        if self.blocked_share_id_count + peer_shares.len() > max_ignored_shares {
805            warn!(
806                "ignoring byzantine peer {peer_id:?} with {} shares would exceed max ignored peer weight {max_ignored_shares}",
807                peer_shares.len()
808            );
809            return;
810        }
811
812        warn!(
813            "ignoring byzantine peer {peer_id:?} with {} shares",
814            peer_shares.len()
815        );
816        self.blocked_share_id_count += peer_shares.len();
817        self.allowed_peers_set.remove(&peer_id);
818        self.allowed_peers
819            .update(Arc::new(self.allowed_peers_set.clone()));
820        self.metrics.inc_num_ignored_byzantine_peers();
821    }
822
823    fn maybe_start_pending_tasks(&mut self) {
824        let dkg_output = if let Some(dkg_output) = &self.dkg_output {
825            dkg_output
826        } else {
827            return; // wait for DKG
828        };
829        let shares = if let Some(shares) = &dkg_output.shares {
830            shares
831        } else {
832            return; // can't participate in randomness generation without shares
833        };
834        let highest_requested_round =
835            if let Some(highest_requested_round) = self.highest_requested_round.get(&self.epoch) {
836                highest_requested_round
837            } else {
838                return; // no rounds to start
839            };
840        // Begin from the next round after the most recent one we've started (or, if none are running,
841        // after the highest completed round in the epoch).
842        let start_round = std::cmp::max(
843            if let Some(highest_completed_round) = self.highest_completed_round.get(&self.epoch) {
844                highest_completed_round.checked_add(1).unwrap()
845            } else {
846                RandomnessRound(0)
847            },
848            self.send_tasks
849                .last_key_value()
850                .map(|(r, _)| r.checked_add(1).unwrap())
851                .unwrap_or(RandomnessRound(0)),
852        );
853
854        let mut rounds_to_aggregate = Vec::new();
855        for round in start_round.0..=highest_requested_round.0 {
856            let round = RandomnessRound(round);
857
858            if self.send_tasks.len() >= self.config.max_partial_sigs_concurrent_sends() {
859                break; // limit concurrent tasks
860            }
861
862            let full_sig_cell = Arc::new(OnceCell::new());
863            self.send_tasks.entry(round).or_insert_with(|| {
864                let name = self.name;
865                let network = self.network.clone();
866                let retry_interval = self.config.partial_signature_retry_interval();
867                let metrics = self.metrics.clone();
868                let authority_info = self.authority_info.clone();
869                let epoch = self.epoch;
870                let partial_sigs = ThresholdBls12381MinSig::partial_sign_batch(
871                    shares.iter(),
872                    &round.signature_message(),
873                );
874                let full_sig_cell_clone = full_sig_cell.clone();
875
876                // Record own partial sigs.
877                if !self.completed_sigs.contains_key(&round) {
878                    self.received_partial_sigs
879                        .insert((round, self.network.peer_id()), partial_sigs.clone());
880                    rounds_to_aggregate.push((epoch, round));
881                }
882
883                debug!("sending partial sigs for epoch {epoch}, round {round}");
884                (
885                    spawn_monitored_task!(RandomnessEventLoop::send_signatures_task(
886                        name,
887                        network,
888                        retry_interval,
889                        metrics,
890                        authority_info,
891                        epoch,
892                        round,
893                        partial_sigs,
894                        full_sig_cell_clone,
895                    )),
896                    full_sig_cell,
897                )
898            });
899        }
900
901        self.update_rounds_pending_metric();
902
903        // After starting a round, we have generated our own partial sigs. Check if that's
904        // enough for us to aggregate already.
905        for (epoch, round) in rounds_to_aggregate {
906            self.maybe_aggregate_partial_signatures(epoch, round);
907        }
908    }
909
910    #[allow(clippy::type_complexity)]
911    fn remove_partial_sigs_in_range(
912        &mut self,
913        range: (
914            Bound<(RandomnessRound, PeerId)>,
915            Bound<(RandomnessRound, PeerId)>,
916        ),
917    ) {
918        let keys_to_remove: Vec<_> = self
919            .received_partial_sigs
920            .range(range)
921            .map(|(key, _)| *key)
922            .collect();
923        for key in keys_to_remove {
924            // Have to remove keys one-by-one because BTreeMap does not support range-removal.
925            self.received_partial_sigs.remove(&key);
926        }
927    }
928
    /// Long-running task that repeatedly broadcasts this node's signatures for (`epoch`,
    /// `round`) to all other authorities. The loop below never exits on its own; the task
    /// runs until it is aborted externally (its `JoinHandle` is held in `send_tasks`).
    ///
    /// While `full_sig` is unset, BCS-serialized partial signatures are sent; once the
    /// event loop fills the cell with an aggregated full signature, subsequent broadcasts
    /// carry only the full signature (with an empty partial-sig list).
    async fn send_signatures_task(
        name: AuthorityName,
        network: anemo::Network,
        retry_interval: Duration,
        metrics: Metrics,
        authority_info: Arc<HashMap<AuthorityName, (PeerId, PartyId)>>,
        epoch: EpochId,
        round: RandomnessRound,
        partial_sigs: Vec<RandomnessPartialSignature>,
        full_sig: Arc<OnceCell<RandomnessSignature>>,
    ) {
        // For simtests, we may test not sending partial signatures.
        #[allow(unused_mut)]
        let mut fail_point_skip_sending = false;
        fail_point_if!("rb-send-partial-signatures", || {
            fail_point_skip_sending = true;
        });
        if fail_point_skip_sending {
            warn!("skipping sending partial sigs due to simtest fail point");
            return;
        }

        // Guard lives for the whole task; presumably the timer records on drop (i.e. when
        // the task is aborted) — TODO confirm against the metrics implementation.
        let _metrics_guard = metrics
            .round_observation_latency_metric()
            .map(|metric| metric.start_timer());

        // NOTE(review): `waiting_peer` appears to return a handle usable even before the
        // peer connects — verify against anemo_ext docs.
        let peers: HashMap<_, _> = authority_info
            .iter()
            .map(|(name, (peer_id, _party_id))| (name, network.waiting_peer(*peer_id)))
            .collect();
        // Serialize partial sigs once up front; the same bytes are re-sent on every retry.
        let partial_sigs: Vec<_> = partial_sigs
            .iter()
            .map(|sig| bcs::to_bytes(sig).expect("message serialization should not fail"))
            .collect();

        loop {
            let mut requests = Vec::new();
            for (peer_name, peer) in &peers {
                if name == **peer_name {
                    continue; // don't send partial sigs to self
                }
                let mut client = RandomnessClient::new(peer.clone());
                const SEND_PARTIAL_SIGNATURES_TIMEOUT: Duration = Duration::from_secs(10);
                // Re-check the cell on every retry: the event loop may have completed the
                // round since the previous iteration.
                let full_sig = full_sig.get().cloned();
                let request = anemo::Request::new(SendSignaturesRequest {
                    epoch,
                    round,
                    // Once a full sig is available, stop sending the (larger) partial sigs.
                    partial_sigs: if full_sig.is_none() {
                        partial_sigs.clone()
                    } else {
                        Vec::new()
                    },
                    sig: full_sig,
                })
                .with_timeout(SEND_PARTIAL_SIGNATURES_TIMEOUT);
                requests.push(async move {
                    let result = client.send_signatures(request).await;
                    if let Err(_error) = result {
                        // TODO: add Display impl to anemo::rpc::Status, log it here
                        debug!("failed to send partial signatures to {peer_name}");
                    }
                });
            }

            // Process all requests.
            futures::future::join_all(requests).await;

            // Keep retrying send to all peers until task is aborted via external message.
            tokio::time::sleep(retry_interval).await;
        }
    }
1000
1001    fn update_rounds_pending_metric(&self) {
1002        let highest_requested_round = self
1003            .highest_requested_round
1004            .get(&self.epoch)
1005            .map(|r| r.0)
1006            .unwrap_or(0);
1007        let highest_completed_round = self
1008            .highest_completed_round
1009            .get(&self.epoch)
1010            .map(|r| r.0)
1011            .unwrap_or(0);
1012        let num_rounds_pending =
1013            highest_requested_round.saturating_sub(highest_completed_round) as i64;
1014        let prev_value = self.metrics.num_rounds_pending().unwrap_or_default();
1015        if num_rounds_pending / 100 > prev_value / 100 {
1016            warn!(
1017                // Recording multiples of 100 so tests can match on the log message.
1018                "RandomnessEventLoop randomness generation backlog: over {} rounds are pending (oldest is {:?})",
1019                (num_rounds_pending / 100) * 100,
1020                highest_completed_round + 1,
1021            );
1022        }
1023        self.metrics.set_num_rounds_pending(num_rounds_pending);
1024    }
1025
1026    fn admin_get_partial_signatures(&self, round: RandomnessRound, tx: oneshot::Sender<Vec<u8>>) {
1027        let shares = if let Some(shares) = self.dkg_output.as_ref().and_then(|d| d.shares.as_ref())
1028        {
1029            shares
1030        } else {
1031            let _ = tx.send(Vec::new()); // no error handling needed if receiver is already dropped
1032            return;
1033        };
1034
1035        let partial_sigs =
1036            ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), &round.signature_message());
1037        // no error handling needed if receiver is already dropped
1038        let _ = tx.send(bcs::to_bytes(&partial_sigs).expect("serialization should not fail"));
1039    }
1040
1041    fn admin_inject_partial_signatures(
1042        &mut self,
1043        authority_name: AuthorityName,
1044        round: RandomnessRound,
1045        sigs: Vec<RandomnessPartialSignature>,
1046    ) -> Result<()> {
1047        let peer_id = self
1048            .authority_info
1049            .get(&authority_name)
1050            .map(|(peer_id, _)| *peer_id)
1051            .ok_or(anyhow::anyhow!("unknown AuthorityName {authority_name:?}"))?;
1052        self.received_partial_sigs.insert((round, peer_id), sigs);
1053        self.maybe_aggregate_partial_signatures(self.epoch, round);
1054        Ok(())
1055    }
1056
1057    fn admin_inject_full_signature(
1058        &mut self,
1059        round: RandomnessRound,
1060        sig: RandomnessSignature,
1061    ) -> Result<()> {
1062        let vss_pk = {
1063            let Some(dkg_output) = &self.dkg_output else {
1064                return Err(anyhow::anyhow!(
1065                    "called admin_inject_full_signature before DKG completed"
1066                ));
1067            };
1068            &dkg_output.vss_pk
1069        };
1070
1071        ThresholdBls12381MinSig::verify(vss_pk.c0(), &round.signature_message(), &sig)
1072            .map_err(|e| anyhow::anyhow!("invalid full signature: {e:?}"))?;
1073
1074        self.process_valid_full_signature(self.epoch, round, sig);
1075        Ok(())
1076    }
1077}