sui_core/traffic_controller/
nodefw_test_server.rs1use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses};
5use axum::{
6 Json, Router,
7 extract::State,
8 http::StatusCode,
9 response::IntoResponse,
10 routing::{get, post},
11};
12use std::time::{Duration, SystemTime};
13use std::{collections::HashMap, net::SocketAddr, sync::Arc};
14use tokio::sync::{Mutex, Notify};
15use tokio::task::JoinHandle;
16
17#[derive(Clone)]
18struct AppState {
19 blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
21}
22
23pub struct NodeFwTestServer {
24 server_handle: Option<JoinHandle<()>>,
25 shutdown_signal: Arc<Notify>,
26 state: AppState,
27}
28
29impl NodeFwTestServer {
30 pub fn new() -> Self {
31 Self {
32 server_handle: None,
33 shutdown_signal: Arc::new(Notify::new()),
34 state: AppState {
35 blocklist: Arc::new(Mutex::new(HashMap::new())),
36 },
37 }
38 }
39
40 pub async fn start(&mut self, port: u16) {
41 let app_state = self.state.clone();
42 let app = Router::new()
43 .route("/list_addresses", get(Self::list_addresses))
44 .route("/block_addresses", post(Self::block_addresses))
45 .with_state(app_state.clone());
46
47 let addr = SocketAddr::from(([127, 0, 0, 1], port));
48
49 let handle = tokio::spawn(async move {
50 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
51 axum::serve(listener, app).await.unwrap();
52 });
53
54 tokio::spawn(Self::periodically_remove_expired_addresses(
55 app_state.blocklist.clone(),
56 ));
57
58 self.server_handle = Some(handle);
59 }
60
61 pub async fn list_addresses_rpc(&self) -> Vec<BlockAddress> {
63 let blocklist = self.state.blocklist.lock().await;
64 blocklist.keys().cloned().collect()
65 }
66
67 async fn list_addresses(State(state): State<AppState>) -> impl IntoResponse {
69 let blocklist = state.blocklist.lock().await;
70 let block_addresses = blocklist.keys().cloned().collect();
71 Json(BlockAddresses {
72 addresses: block_addresses,
73 })
74 }
75
76 async fn periodically_remove_expired_addresses(
77 blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
78 ) {
79 loop {
80 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
81 let mut blocklist = blocklist.lock().await;
82 let now = SystemTime::now();
83 blocklist.retain(|_address, expiry| now < *expiry);
84 }
85 }
86
87 async fn block_addresses(
89 State(state): State<AppState>,
90 Json(addresses): Json<BlockAddresses>,
91 ) -> impl IntoResponse {
92 let mut blocklist = state.blocklist.lock().await;
93 for addr in addresses.addresses.iter() {
94 blocklist.insert(
95 addr.clone(),
96 SystemTime::now() + Duration::from_secs(addr.ttl),
97 );
98 }
99 (StatusCode::CREATED, "created")
100 }
101
102 pub async fn stop(&self) {
103 self.shutdown_signal.notify_one();
104 }
105}
106
107impl Default for NodeFwTestServer {
108 fn default() -> Self {
109 Self::new()
110 }
111}