1use std::sync::{Arc, Mutex};
5
6use arc_swap::ArcSwapOption;
7use mysten_network::Multiaddr;
8use serde::{Deserialize, Serialize};
9use sui_types::crypto::{NetworkPublicKey, ToFromBytes};
10use sui_types::error::SuiResult;
11use tap::TapFallible;
12use tracing::{info, warn};
13
14use crate::discovery;
15
16#[derive(Clone)]
19pub struct EndpointManager {
20 inner: Arc<Inner>,
21}
22
23struct Inner {
24 discovery_sender: discovery::Sender,
25 consensus_address_updater: ArcSwapOption<Arc<dyn ConsensusAddressUpdater>>,
26 pending_consensus_updates: Mutex<Vec<(NetworkPublicKey, AddressSource, Vec<Multiaddr>)>>,
27}
28
29pub trait ConsensusAddressUpdater: Send + Sync + 'static {
30 fn update_address(
31 &self,
32 network_pubkey: NetworkPublicKey,
33 source: AddressSource,
34 addresses: Vec<Multiaddr>,
35 ) -> SuiResult<()>;
36}
37
38impl EndpointManager {
39 pub fn new(discovery_sender: discovery::Sender) -> Self {
40 Self {
41 inner: Arc::new(Inner {
42 discovery_sender,
43 consensus_address_updater: ArcSwapOption::empty(),
44 pending_consensus_updates: Mutex::new(Vec::new()),
45 }),
46 }
47 }
48
49 pub fn set_consensus_address_updater(
50 &self,
51 consensus_address_updater: Arc<dyn ConsensusAddressUpdater>,
52 ) {
53 let mut pending = self.inner.pending_consensus_updates.lock().unwrap();
54
55 for (pubkey, source, addrs) in pending.drain(..) {
56 if let Err(e) = consensus_address_updater.update_address(pubkey.clone(), source, addrs)
57 {
58 warn!(
59 ?pubkey,
60 "Error replaying buffered consensus address update: {e:?}"
61 );
62 }
63 }
64
65 self.inner
66 .consensus_address_updater
67 .store(Some(Arc::new(consensus_address_updater)));
68 }
69
70 pub fn update_endpoint(
75 &self,
76 endpoint: EndpointId,
77 source: AddressSource,
78 addresses: Vec<Multiaddr>,
79 ) -> SuiResult<()> {
80 match endpoint {
81 EndpointId::P2p(peer_id) => {
82 let anemo_addresses: Vec<_> = addresses
83 .into_iter()
84 .filter_map(|addr| {
85 addr.to_anemo_address()
86 .tap_err(|_| {
87 warn!(
88 ?addr,
89 "Skipping peer address: can't convert to anemo address"
90 )
91 })
92 .ok()
93 })
94 .collect();
95
96 self.inner
97 .discovery_sender
98 .peer_address_change(peer_id, source, anemo_addresses);
99 }
100 EndpointId::Consensus(network_pubkey) => {
101 let mut pending = self.inner.pending_consensus_updates.lock().unwrap();
106 if let Some(updater) = self.inner.consensus_address_updater.load_full() {
107 drop(pending);
108 updater
109 .update_address(network_pubkey.clone(), source, addresses)
110 .map_err(|e| {
111 warn!(?network_pubkey, "Error updating consensus address: {e:?}");
112 e
113 })?;
114 } else {
115 info!(
116 ?network_pubkey,
117 "Buffering consensus address update (updater not yet set)"
118 );
119 pending.push((network_pubkey, source, addresses));
120 }
121 }
122 }
123
124 Ok(())
125 }
126
127 pub fn clear_source(&self, peer_id: anemo::PeerId, source: AddressSource) {
129 let _ = self.update_endpoint(EndpointId::P2p(peer_id), source, vec![]);
130 if let Ok(network_pubkey) = NetworkPublicKey::from_bytes(&peer_id.0) {
131 let _ = self.update_endpoint(EndpointId::Consensus(network_pubkey), source, vec![]);
132 }
133
134 fn _assert_all_variants_handled(id: &EndpointId) {
138 match id {
139 EndpointId::P2p(_) | EndpointId::Consensus(_) => {}
140 }
141 }
142 }
143}
144
145#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
146pub enum EndpointId {
147 P2p(anemo::PeerId),
148 Consensus(NetworkPublicKey),
149}
150
/// Where an address update came from.
///
/// Discriminants are explicit, start at 1, and double as metric codes (see
/// [`AddressSource::metric_code`]); 0 is left for
/// [`AddressSource::DEFAULT_ADDRESS_SOURCE_CODE`]. The derived `Ord` ranks
/// `Admin < Config < Discovery < Seed < Chain`. NOTE(review): whether that ordering
/// is meant to encode a priority is not established in this file — confirm before
/// relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddressSource {
    Admin = 1,
    Config = 2,
    Discovery = 3,
    Seed = 4,
    Chain = 5,
}

impl AddressSource {
    /// Metric code outside the range of every variant's code (which start at 1);
    /// presumably reported when no specific source applies — verify against callers.
    pub const DEFAULT_ADDRESS_SOURCE_CODE: i64 = 0;

    /// Numeric metric code for this source: its explicit enum discriminant.
    pub const fn metric_code(self) -> i64 {
        self as i64
    }
}
170
#[cfg(test)]
mod tests {
    use super::*;
    use fastcrypto::traits::KeyPair;
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use sui_types::crypto::{NetworkKeyPair, get_key_pair};

    /// One recorded `update_address` call: key, source and addresses, as received.
    type UpdateEntry = (NetworkPublicKey, AddressSource, Vec<Multiaddr>);

    /// Test double that records every `update_address` call in call order.
    struct MockConsensusAddressUpdater {
        updates: Arc<Mutex<Vec<UpdateEntry>>>,
    }

    impl MockConsensusAddressUpdater {
        /// Returns the mock together with a shared handle to its call log.
        fn new() -> (Self, Arc<Mutex<Vec<UpdateEntry>>>) {
            let updates = Arc::new(Mutex::new(Vec::new()));
            let updater = Self {
                updates: updates.clone(),
            };
            (updater, updates)
        }
    }

    impl ConsensusAddressUpdater for MockConsensusAddressUpdater {
        fn update_address(
            &self,
            network_pubkey: NetworkPublicKey,
            source: AddressSource,
            addresses: Vec<Multiaddr>,
        ) -> SuiResult<()> {
            self.updates
                .lock()
                .unwrap()
                .push((network_pubkey, source, addresses));
            Ok(())
        }
    }

    /// Builds an `EndpointManager` through the discovery builder with default P2P config.
    fn create_mock_endpoint_manager() -> EndpointManager {
        use sui_config::p2p::P2pConfig;

        let config = P2pConfig::default();
        let (_unstarted, _server, endpoint_manager) =
            discovery::Builder::new().config(config).build();
        endpoint_manager
    }

    /// Generates a fresh random consensus network public key.
    fn random_network_pubkey() -> NetworkPublicKey {
        let (_, network_key): (_, NetworkKeyPair) = get_key_pair();
        network_key.public().clone()
    }

    /// Number of consensus updates currently waiting in the manager's buffer.
    fn buffered_len(endpoint_manager: &EndpointManager) -> usize {
        endpoint_manager
            .inner
            .pending_consensus_updates
            .lock()
            .unwrap()
            .len()
    }

    #[tokio::test]
    async fn test_update_consensus_endpoint() {
        let endpoint_manager = create_mock_endpoint_manager();

        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        endpoint_manager.set_consensus_address_updater(Arc::new(mock_updater));

        let network_pubkey = random_network_pubkey();
        let addresses: Vec<Multiaddr> = vec![
            "/ip4/127.0.0.1/udp/9000".parse().unwrap(),
            "/ip4/127.0.0.1/udp/9001".parse().unwrap(),
        ];

        let result = endpoint_manager.update_endpoint(
            EndpointId::Consensus(network_pubkey.clone()),
            AddressSource::Admin,
            addresses.clone(),
        );
        assert!(result.is_ok());

        // Forwarded directly (nothing buffered), with key, source and addresses intact.
        assert_eq!(buffered_len(&endpoint_manager), 0);
        let recorded_updates = updates.lock().unwrap();
        assert_eq!(
            *recorded_updates,
            vec![(network_pubkey, AddressSource::Admin, addresses)]
        );
    }

    #[tokio::test]
    async fn test_update_consensus_endpoint_without_updater_buffers() {
        let endpoint_manager = create_mock_endpoint_manager();

        let network_pubkey = random_network_pubkey();
        let addresses: Vec<Multiaddr> = vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()];

        // No updater yet: the update is accepted and held in the buffer.
        let result = endpoint_manager.update_endpoint(
            EndpointId::Consensus(network_pubkey.clone()),
            AddressSource::Discovery,
            addresses.clone(),
        );
        assert!(result.is_ok());
        assert_eq!(buffered_len(&endpoint_manager), 1);

        // Installing the updater replays it (source included) and empties the buffer.
        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        endpoint_manager.set_consensus_address_updater(Arc::new(mock_updater));

        assert_eq!(buffered_len(&endpoint_manager), 0);
        let recorded_updates = updates.lock().unwrap();
        assert_eq!(
            *recorded_updates,
            vec![(network_pubkey, AddressSource::Discovery, addresses)]
        );
    }

    #[tokio::test]
    async fn test_buffered_updates_replay_in_arrival_order() {
        let endpoint_manager = create_mock_endpoint_manager();

        let network_pubkey = random_network_pubkey();
        let first: Vec<Multiaddr> = vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()];
        let second: Vec<Multiaddr> = vec!["/ip4/127.0.0.1/udp/9001".parse().unwrap()];

        endpoint_manager
            .update_endpoint(
                EndpointId::Consensus(network_pubkey.clone()),
                AddressSource::Discovery,
                first.clone(),
            )
            .unwrap();
        endpoint_manager
            .update_endpoint(
                EndpointId::Consensus(network_pubkey.clone()),
                AddressSource::Admin,
                second.clone(),
            )
            .unwrap();
        assert_eq!(buffered_len(&endpoint_manager), 2);

        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        endpoint_manager.set_consensus_address_updater(Arc::new(mock_updater));

        // Replay must preserve arrival order so a later update is applied last.
        assert_eq!(buffered_len(&endpoint_manager), 0);
        let recorded_updates = updates.lock().unwrap();
        assert_eq!(
            *recorded_updates,
            vec![
                (network_pubkey.clone(), AddressSource::Discovery, first),
                (network_pubkey, AddressSource::Admin, second),
            ]
        );
    }

    #[tokio::test]
    async fn test_concurrent_update_endpoint_and_set_updater_no_lost_updates() {
        use std::sync::Barrier;

        let endpoint_manager = create_mock_endpoint_manager();

        let num_buffered = 5;
        let num_concurrent = 20;
        let mut expected_keys = HashSet::new();

        // Updates sent before any updater exists; these must be buffered and replayed.
        for _ in 0..num_buffered {
            let pubkey = random_network_pubkey();
            expected_keys.insert(pubkey.clone());
            endpoint_manager
                .update_endpoint(
                    EndpointId::Consensus(pubkey),
                    AddressSource::Discovery,
                    vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()],
                )
                .unwrap();
        }

        // All writer threads plus the setter thread are released together.
        let barrier = Arc::new(Barrier::new(num_concurrent + 1));
        let mut handles = Vec::new();

        for i in 0..num_concurrent {
            let em = endpoint_manager.clone();
            let b = barrier.clone();
            let pubkey = random_network_pubkey();
            expected_keys.insert(pubkey.clone());
            handles.push(std::thread::spawn(move || {
                b.wait();
                // Stagger half the writers to widen the interleavings with the setter.
                if i % 2 == 0 {
                    std::thread::yield_now();
                }
                em.update_endpoint(
                    EndpointId::Consensus(pubkey),
                    AddressSource::Discovery,
                    vec!["/ip4/127.0.0.1/udp/9000".parse().unwrap()],
                )
                .unwrap();
            }));
        }

        let (mock_updater, updates) = MockConsensusAddressUpdater::new();
        let em = endpoint_manager.clone();
        let b = barrier.clone();
        let setter_handle = std::thread::spawn(move || {
            b.wait();
            em.set_consensus_address_updater(Arc::new(mock_updater));
        });

        for h in handles {
            h.join().unwrap();
        }
        setter_handle.join().unwrap();

        // Nothing may be stranded in the buffer once the updater is installed.
        assert_eq!(
            buffered_len(&endpoint_manager),
            0,
            "updates left in the buffer after the updater was installed"
        );

        let recorded = updates.lock().unwrap();
        assert_eq!(
            recorded.len(),
            num_buffered + num_concurrent,
            "expected {} updates but got {} — some were lost",
            num_buffered + num_concurrent,
            recorded.len(),
        );

        // Every key delivered exactly once: a duplicated replay cannot mask a lost update.
        let recorded_keys: HashSet<_> = recorded.iter().map(|(key, _, _)| key.clone()).collect();
        assert_eq!(recorded_keys, expected_keys);
    }
}