sui_core/traffic_controller/nodefw_test_server.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

4use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses};
5use axum::{
6    Json, Router,
7    extract::State,
8    http::StatusCode,
9    response::IntoResponse,
10    routing::{get, post},
11};
12use std::time::{Duration, SystemTime};
13use std::{collections::HashMap, net::SocketAddr, sync::Arc};
14use tokio::sync::{Mutex, Notify};
15use tokio::task::JoinHandle;
16
17#[derive(Clone)]
18struct AppState {
19    /// BlockAddress -> expiry time
20    blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
21}
22
23pub struct NodeFwTestServer {
24    server_handle: Option<JoinHandle<()>>,
25    shutdown_signal: Arc<Notify>,
26    state: AppState,
27}
28
29impl NodeFwTestServer {
30    pub fn new() -> Self {
31        Self {
32            server_handle: None,
33            shutdown_signal: Arc::new(Notify::new()),
34            state: AppState {
35                blocklist: Arc::new(Mutex::new(HashMap::new())),
36            },
37        }
38    }
39
40    pub async fn start(&mut self, port: u16) {
41        let app_state = self.state.clone();
42        let app = Router::new()
43            .route("/list_addresses", get(Self::list_addresses))
44            .route("/block_addresses", post(Self::block_addresses))
45            .with_state(app_state.clone());
46
47        let addr = SocketAddr::from(([127, 0, 0, 1], port));
48
49        let handle = tokio::spawn(async move {
50            let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
51            axum::serve(listener, app).await.unwrap();
52        });
53
54        tokio::spawn(Self::periodically_remove_expired_addresses(
55            app_state.blocklist.clone(),
56        ));
57
58        self.server_handle = Some(handle);
59    }
60
61    /// Direct access api for test verification
62    pub async fn list_addresses_rpc(&self) -> Vec<BlockAddress> {
63        let blocklist = self.state.blocklist.lock().await;
64        blocklist.keys().cloned().collect()
65    }
66
67    /// Endpoint handler to list addresses
68    async fn list_addresses(State(state): State<AppState>) -> impl IntoResponse {
69        let blocklist = state.blocklist.lock().await;
70        let block_addresses = blocklist.keys().cloned().collect();
71        Json(BlockAddresses {
72            addresses: block_addresses,
73        })
74    }
75
76    async fn periodically_remove_expired_addresses(
77        blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
78    ) {
79        loop {
80            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
81            let mut blocklist = blocklist.lock().await;
82            let now = SystemTime::now();
83            blocklist.retain(|_address, expiry| now < *expiry);
84        }
85    }
86
87    /// Endpoint handler to block addresses
88    async fn block_addresses(
89        State(state): State<AppState>,
90        Json(addresses): Json<BlockAddresses>,
91    ) -> impl IntoResponse {
92        let mut blocklist = state.blocklist.lock().await;
93        for addr in addresses.addresses.iter() {
94            blocklist.insert(
95                addr.clone(),
96                SystemTime::now() + Duration::from_secs(addr.ttl),
97            );
98        }
99        (StatusCode::CREATED, "created")
100    }
101
102    pub async fn stop(&self) {
103        self.shutdown_signal.notify_one();
104    }
105}
106
107impl Default for NodeFwTestServer {
108    fn default() -> Self {
109        Self::new()
110    }
111}