1use std::sync::{Arc, Mutex};
5
6use arc_swap::ArcSwapOption;
7use mysten_network::Multiaddr;
8use serde::{Deserialize, Serialize};
9use sui_types::crypto::{NetworkPublicKey, ToFromBytes};
10use sui_types::error::SuiResult;
11use tap::TapFallible;
12use tracing::{info, warn};
13
14use crate::discovery;
15
16#[derive(Clone)]
19pub struct EndpointManager {
20 inner: Arc<Inner>,
21}
22
23struct Inner {
24 discovery_sender: discovery::Sender,
25 consensus_address_updater: ArcSwapOption<Arc<dyn ConsensusAddressUpdater>>,
26 pending_consensus_updates: Mutex<Vec<(NetworkPublicKey, AddressSource, Vec<Multiaddr>)>>,
27}
28
29pub trait ConsensusAddressUpdater: Send + Sync + 'static {
30 fn update_address(
31 &self,
32 network_pubkey: NetworkPublicKey,
33 source: AddressSource,
34 addresses: Vec<Multiaddr>,
35 ) -> SuiResult<()>;
36}
37
38impl EndpointManager {
39 pub fn new(discovery_sender: discovery::Sender) -> Self {
40 Self {
41 inner: Arc::new(Inner {
42 discovery_sender,
43 consensus_address_updater: ArcSwapOption::empty(),
44 pending_consensus_updates: Mutex::new(Vec::new()),
45 }),
46 }
47 }
48
49 pub fn set_consensus_address_updater(
50 &self,
51 consensus_address_updater: Arc<dyn ConsensusAddressUpdater>,
52 ) {
53 let mut pending = self.inner.pending_consensus_updates.lock().unwrap();
54
55 for (pubkey, source, addrs) in pending.drain(..) {
56 if let Err(e) = consensus_address_updater.update_address(pubkey.clone(), source, addrs)
57 {
58 warn!(
59 ?pubkey,
60 "Error replaying buffered consensus address update: {e:?}"
61 );
62 }
63 }
64
65 self.inner
66 .consensus_address_updater
67 .store(Some(Arc::new(consensus_address_updater)));
68 }
69
70 pub fn update_endpoint(
75 &self,
76 endpoint: EndpointId,
77 source: AddressSource,
78 addresses: Vec<Multiaddr>,
79 ) -> SuiResult<()> {
80 match endpoint {
81 EndpointId::P2p(peer_id) => {
82 let anemo_addresses: Vec<_> = addresses
83 .into_iter()
84 .filter_map(|addr| {
85 addr.to_anemo_address()
86 .tap_err(|_| {
87 warn!(
88 ?addr,
89 "Skipping peer address: can't convert to anemo address"
90 )
91 })
92 .ok()
93 })
94 .collect();
95
96 self.inner
97 .discovery_sender
98 .peer_address_change(peer_id, source, anemo_addresses);
99 }
100 EndpointId::Consensus(network_pubkey) => {
101 let mut pending = self.inner.pending_consensus_updates.lock().unwrap();
106 if let Some(updater) = self.inner.consensus_address_updater.load_full() {
107 drop(pending);
108 updater
109 .update_address(network_pubkey.clone(), source, addresses)
110 .map_err(|e| {
111 warn!(?network_pubkey, "Error updating consensus address: {e:?}");
112 e
113 })?;
114 } else {
115 info!(
116 ?network_pubkey,
117 "Buffering consensus address update (updater not yet set)"
118 );
119 pending.push((network_pubkey, source, addresses));
120 }
121 }
122 }
123
124 Ok(())
125 }
126
127 pub fn clear_source(&self, peer_id: anemo::PeerId, source: AddressSource) {
129 let _ = self.update_endpoint(EndpointId::P2p(peer_id), source, vec![]);
130 if let Ok(network_pubkey) = NetworkPublicKey::from_bytes(&peer_id.0) {
131 let _ = self.update_endpoint(EndpointId::Consensus(network_pubkey), source, vec![]);
132 }
133
134 fn _assert_all_variants_handled(id: &EndpointId) {
138 match id {
139 EndpointId::P2p(_) | EndpointId::Consensus(_) => {}
140 }
141 }
142 }
143}
144
145#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
146pub enum EndpointId {
147 P2p(anemo::PeerId),
148 Consensus(NetworkPublicKey),
149}
150
/// Origin of an address update.
///
/// Declaration order is significant: it defines the derived `Ord`
/// (`Admin < Config < Discovery < Seed < Chain`).
/// NOTE(review): presumably callers use this ordering to prioritise sources. Confirm
/// before reordering variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddressSource {
    Admin,
    Config,
    Discovery,
    Seed,
    Chain,
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use fastcrypto::traits::KeyPair;
    use std::sync::{Arc, Mutex};
    use sui_types::crypto::{NetworkKeyPair, get_key_pair};

    /// One recorded call to the mock: the peer key and the addresses it was given.
    type UpdateEntry = (NetworkPublicKey, Vec<Multiaddr>);

    /// Shared, append-only log of the calls the mock received, in call order.
    type UpdateLog = Arc<Mutex<Vec<UpdateEntry>>>;

    /// Test double that records every `update_address` call instead of acting on it.
    struct MockConsensusAddressUpdater {
        updates: UpdateLog,
    }

    impl MockConsensusAddressUpdater {
        /// Builds a mock together with a handle to the log it writes into.
        fn new() -> (Self, UpdateLog) {
            let log = UpdateLog::default();
            let mock = Self {
                updates: Arc::clone(&log),
            };
            (mock, log)
        }
    }

    impl ConsensusAddressUpdater for MockConsensusAddressUpdater {
        fn update_address(
            &self,
            network_pubkey: NetworkPublicKey,
            _source: AddressSource,
            addresses: Vec<Multiaddr>,
        ) -> SuiResult<()> {
            let mut log = self.updates.lock().unwrap();
            log.push((network_pubkey, addresses));
            Ok(())
        }
    }

    /// Builds an `EndpointManager` through the real discovery builder with default config.
    fn create_mock_endpoint_manager() -> EndpointManager {
        use sui_config::p2p::P2pConfig;

        let (_unstarted, _server, endpoint_manager) = discovery::Builder::new()
            .config(P2pConfig::default())
            .build();
        endpoint_manager
    }

    /// Generates a fresh random network public key.
    fn random_network_pubkey() -> NetworkPublicKey {
        let (_, keypair): (_, NetworkKeyPair) = get_key_pair();
        keypair.public().clone()
    }

    /// Sends one consensus update with a fixed address and panics if it fails.
    fn send_consensus_update(em: &EndpointManager, key: NetworkPublicKey) {
        em.update_endpoint(
            EndpointId::Consensus(key),
            AddressSource::Discovery,
            vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()],
        )
        .unwrap();
    }

    /// Asserts that the log holds exactly one update, for `key` with `addresses`.
    fn assert_single_update(log: &UpdateLog, key: &NetworkPublicKey, addresses: &[Multiaddr]) {
        let recorded = log.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        let (recorded_key, recorded_addrs) = &recorded[0];
        assert_eq!(recorded_key, key);
        assert_eq!(recorded_addrs.as_slice(), addresses);
    }

    #[tokio::test]
    async fn test_update_consensus_endpoint() {
        let endpoint_manager = create_mock_endpoint_manager();

        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        endpoint_manager.set_consensus_address_updater(Arc::new(mock_updater));

        let network_pubkey = random_network_pubkey();
        let addresses: Vec<Multiaddr> = ["/ip4/127.0.0.1/udp/9000", "/ip4/127.0.0.1/udp/9001"]
            .into_iter()
            .map(|raw| raw.parse().unwrap())
            .collect();

        let result = endpoint_manager.update_endpoint(
            EndpointId::Consensus(network_pubkey.clone()),
            AddressSource::Admin,
            addresses.clone(),
        );
        assert!(result.is_ok());

        assert_single_update(&updates, &network_pubkey, &addresses);
    }

    #[tokio::test]
    async fn test_update_consensus_endpoint_without_updater_buffers() {
        let endpoint_manager = create_mock_endpoint_manager();

        let network_pubkey = random_network_pubkey();
        let addresses: Vec<Multiaddr> = vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()];

        // No updater installed yet: the update is accepted and held back.
        let result = endpoint_manager.update_endpoint(
            EndpointId::Consensus(network_pubkey.clone()),
            AddressSource::Discovery,
            addresses.clone(),
        );
        assert!(result.is_ok());

        // Installing the updater replays the buffered update.
        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        endpoint_manager.set_consensus_address_updater(Arc::new(mock_updater));

        assert_single_update(&updates, &network_pubkey, &addresses);
    }

    #[tokio::test]
    async fn test_concurrent_update_endpoint_and_set_updater_no_lost_updates() {
        use std::sync::Barrier;

        const NUM_BUFFERED: usize = 5;
        const NUM_CONCURRENT: usize = 20;

        let endpoint_manager = create_mock_endpoint_manager();

        // Updates sent before any updater exists; these must all be buffered.
        for _ in 0..NUM_BUFFERED {
            send_consensus_update(&endpoint_manager, random_network_pubkey());
        }

        // Every sender thread plus the setter thread are released together.
        let start_line = Arc::new(Barrier::new(NUM_CONCURRENT + 1));

        let senders: Vec<_> = (0..NUM_CONCURRENT)
            .map(|i| {
                let em = endpoint_manager.clone();
                let start = Arc::clone(&start_line);
                let key = random_network_pubkey();
                std::thread::spawn(move || {
                    start.wait();
                    // Stagger half the senders so they hit both sides of the setter.
                    if i % 2 == 0 {
                        std::thread::yield_now();
                    }
                    send_consensus_update(&em, key);
                })
            })
            .collect();

        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        let setter = {
            let em = endpoint_manager.clone();
            let start = Arc::clone(&start_line);
            std::thread::spawn(move || {
                start.wait();
                em.set_consensus_address_updater(Arc::new(mock_updater));
            })
        };

        senders
            .into_iter()
            .for_each(|handle| handle.join().unwrap());
        setter.join().unwrap();

        let recorded = updates.lock().unwrap();
        assert_eq!(
            recorded.len(),
            NUM_BUFFERED + NUM_CONCURRENT,
            "expected {} updates but got {} — some were lost",
            NUM_BUFFERED + NUM_CONCURRENT,
            recorded.len(),
        );
    }
}