1use crate::SuiNode;
5use axum::{
6 Router,
7 extract::{Query, State},
8 http::StatusCode,
9 routing::{get, post},
10};
11use base64::Engine;
12use humantime::parse_duration;
13use serde::Deserialize;
14use std::sync::Arc;
15use std::{
16 net::{IpAddr, Ipv4Addr, SocketAddr},
17 str::FromStr,
18};
19use sui_types::{
20 base_types::AuthorityName,
21 crypto::{RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
22 digests::TransactionDigest,
23 error::SuiErrorKind,
24 traffic_control::TrafficControlReconfigParams,
25};
26use telemetry_subscribers::TracingHandle;
27use tokio::sync::oneshot;
28use tracing::info;
29
// ---- Logging and tracing --------------------------------------------------------------------

/// GET reads the current log filter, POST replaces it with the request body.
const LOGGING_ROUTE: &str = "/logging";
/// Changes the trace sample rate / trace file and/or sets a temporary trace filter.
const TRACING_ROUTE: &str = "/enable-tracing";
/// Resets the trace filter.
const TRACING_RESET_ROUTE: &str = "/reset-tracing";

// ---- Protocol upgrade and epoch control -----------------------------------------------------

/// Overrides the protocol-upgrade buffer stake (`buffer_bps`) for a given `epoch`.
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
/// Clears the protocol-upgrade buffer stake override for a given `epoch`.
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
/// Closes the current epoch; the caller must name the current epoch explicitly.
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";

// ---- Introspection --------------------------------------------------------------------------

/// Dumps the capabilities held by the current epoch store.
const CAPABILITIES: &str = "/capabilities";
/// Dumps the node configuration.
const NODE_CONFIG: &str = "/node-config";

// ---- Randomness -----------------------------------------------------------------------------

/// Fetches this node's partial signatures for a randomness round.
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
/// Injects another authority's partial signatures for a randomness round.
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
/// Injects a full randomness signature for a round.
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";

// ---- Transaction cost estimation ------------------------------------------------------------

/// Returns the estimated cost of a single transaction.
const GET_TX_COST_ROUTE: &str = "/get-tx-cost";
/// Dumps all consensus transaction cost estimates for the current epoch.
const DUMP_CONSENSUS_TX_COST_ESTIMATES_ROUTE: &str = "/dump-consensus-tx-cost-estimates";

// ---- Traffic control ------------------------------------------------------------------------

/// Reconfigures traffic control at runtime.
const TRAFFIC_CONTROL: &str = "/traffic-control";
95
96struct AppState {
97 node: Arc<SuiNode>,
98 tracing_handle: TracingHandle,
99}
100
101pub async fn run_admin_server(node: Arc<SuiNode>, port: u16, tracing_handle: TracingHandle) {
102 let filter = tracing_handle.get_log().unwrap();
103
104 let app_state = AppState {
105 node,
106 tracing_handle,
107 };
108
109 let app = Router::new()
110 .route(LOGGING_ROUTE, get(get_filter))
111 .route(CAPABILITIES, get(capabilities))
112 .route(NODE_CONFIG, get(node_config))
113 .route(LOGGING_ROUTE, post(set_filter))
114 .route(
115 SET_BUFFER_STAKE_ROUTE,
116 post(set_override_protocol_upgrade_buffer_stake),
117 )
118 .route(
119 CLEAR_BUFFER_STAKE_ROUTE,
120 post(clear_override_protocol_upgrade_buffer_stake),
121 )
122 .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
123 .route(TRACING_ROUTE, post(enable_tracing))
124 .route(TRACING_RESET_ROUTE, post(reset_tracing))
125 .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
126 .route(
127 RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
128 post(randomness_inject_partial_sigs),
129 )
130 .route(
131 RANDOMNESS_INJECT_FULL_SIG_ROUTE,
132 post(randomness_inject_full_sig),
133 )
134 .route(GET_TX_COST_ROUTE, get(get_tx_cost))
135 .route(
136 DUMP_CONSENSUS_TX_COST_ESTIMATES_ROUTE,
137 get(dump_consensus_tx_cost_estimates),
138 )
139 .route(TRAFFIC_CONTROL, post(traffic_control))
140 .with_state(Arc::new(app_state));
141
142 let socket_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
143 info!(
144 filter =% filter,
145 address =% socket_address,
146 "starting admin server"
147 );
148
149 let listener = tokio::net::TcpListener::bind(&socket_address)
150 .await
151 .unwrap();
152 axum::serve(
153 listener,
154 app.into_make_service_with_connect_info::<SocketAddr>(),
155 )
156 .await
157 .unwrap();
158}
159
160#[derive(Deserialize)]
161struct EnableTracing {
162 filter: Option<String>,
164 duration: Option<String>,
165
166 trace_file: Option<String>,
168
169 sample_rate: Option<f64>,
171}
172
173async fn enable_tracing(
174 State(state): State<Arc<AppState>>,
175 query: Query<EnableTracing>,
176) -> (StatusCode, String) {
177 let Query(EnableTracing {
178 filter,
179 duration,
180 trace_file,
181 sample_rate,
182 }) = query;
183
184 let mut response = Vec::new();
185
186 if let Some(sample_rate) = sample_rate {
187 state.tracing_handle.update_sampling_rate(sample_rate);
188 response.push(format!("sample rate set to {:?}", sample_rate));
189 }
190
191 if let Some(trace_file) = trace_file {
192 if let Err(err) = state.tracing_handle.update_trace_file(&trace_file) {
193 response.push(format!("can't update trace file: {:?}", err));
194 return (StatusCode::BAD_REQUEST, response.join("\n"));
195 } else {
196 response.push(format!("trace file set to {:?}", trace_file));
197 }
198 }
199
200 let Some(filter) = filter else {
201 return (StatusCode::OK, response.join("\n"));
202 };
203
204 let Some(duration) = duration else {
206 response.push("can't update filter: missing duration".into());
207 return (StatusCode::BAD_REQUEST, response.join("\n"));
208 };
209
210 let Ok(duration) = parse_duration(&duration) else {
211 response.push("can't update filter: invalid duration".into());
212 return (StatusCode::BAD_REQUEST, response.join("\n"));
213 };
214
215 match state.tracing_handle.update_trace_filter(&filter, duration) {
216 Ok(()) => {
217 response.push(format!("filter set to {:?}", filter));
218 response.push(format!("filter will be reset after {:?}", duration));
219 (StatusCode::OK, response.join("\n"))
220 }
221 Err(err) => {
222 response.push(format!("can't update filter: {:?}", err));
223 (StatusCode::BAD_REQUEST, response.join("\n"))
224 }
225 }
226}
227
228async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
229 state.tracing_handle.reset_trace();
230 (
231 StatusCode::OK,
232 "tracing filter reset to TRACE_FILTER env var".into(),
233 )
234}
235
236async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
237 match state.tracing_handle.get_log() {
238 Ok(filter) => (StatusCode::OK, filter),
239 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
240 }
241}
242
243async fn set_filter(
244 State(state): State<Arc<AppState>>,
245 new_filter: String,
246) -> (StatusCode, String) {
247 match state.tracing_handle.update_log(&new_filter) {
248 Ok(()) => {
249 info!(filter =% new_filter, "Log filter updated");
250 (StatusCode::OK, "".into())
251 }
252 Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
253 }
254}
255
256async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
257 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
258
259 let capabilities = epoch_store.get_capabilities_v1();
261 let mut output = String::new();
262 for capability in capabilities.unwrap_or_default() {
263 output.push_str(&format!("{:?}\n", capability));
264 }
265
266 let capabilities = epoch_store.get_capabilities_v2();
267 for capability in capabilities.unwrap_or_default() {
268 output.push_str(&format!("{:?}\n", capability));
269 }
270
271 (StatusCode::OK, output)
272}
273
274async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
275 let node_config = &state.node.config;
276
277 (StatusCode::OK, format!("{:#?}\n", node_config))
279}
280
281#[derive(Deserialize)]
282struct Epoch {
283 epoch: u64,
284}
285
286async fn clear_override_protocol_upgrade_buffer_stake(
287 State(state): State<Arc<AppState>>,
288 epoch: Query<Epoch>,
289) -> (StatusCode, String) {
290 let Query(Epoch { epoch }) = epoch;
291
292 match state
293 .node
294 .clear_override_protocol_upgrade_buffer_stake(epoch)
295 {
296 Ok(()) => (
297 StatusCode::OK,
298 "protocol upgrade buffer stake cleared\n".to_string(),
299 ),
300 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
301 }
302}
303
304#[derive(Deserialize)]
305struct SetBufferStake {
306 buffer_bps: u64,
307 epoch: u64,
308}
309
310async fn set_override_protocol_upgrade_buffer_stake(
311 State(state): State<Arc<AppState>>,
312 buffer_state: Query<SetBufferStake>,
313) -> (StatusCode, String) {
314 let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
315
316 match state
317 .node
318 .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
319 {
320 Ok(()) => (
321 StatusCode::OK,
322 format!("protocol upgrade buffer stake set to '{}'\n", buffer_bps),
323 ),
324 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
325 }
326}
327
328async fn force_close_epoch(
329 State(state): State<Arc<AppState>>,
330 epoch: Query<Epoch>,
331) -> (StatusCode, String) {
332 let Query(Epoch {
333 epoch: expected_epoch,
334 }) = epoch;
335 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
336 let actual_epoch = epoch_store.epoch();
337 if actual_epoch != expected_epoch {
338 let err = SuiErrorKind::WrongEpoch {
339 expected_epoch,
340 actual_epoch,
341 };
342 return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
343 }
344
345 match state.node.close_epoch(&epoch_store).await {
346 Ok(()) => (
347 StatusCode::OK,
348 "close_epoch() called successfully\n".to_string(),
349 ),
350 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
351 }
352}
353
354#[derive(Deserialize)]
355struct Round {
356 round: u64,
357}
358
359async fn randomness_partial_sigs(
360 State(state): State<Arc<AppState>>,
361 round: Query<Round>,
362) -> (StatusCode, String) {
363 let Query(Round { round }) = round;
364
365 let (tx, rx) = oneshot::channel();
366 state
367 .node
368 .randomness_handle()
369 .admin_get_partial_signatures(RandomnessRound(round), tx);
370
371 let sigs = match rx.await {
372 Ok(sigs) => sigs,
373 Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
374 };
375
376 let output = format!(
377 "{}\n",
378 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
379 );
380
381 (StatusCode::OK, output)
382}
383
384#[derive(Deserialize)]
385struct PartialSigsToInject {
386 hex_authority_name: String,
387 round: u64,
388 base64_sigs: String,
389}
390
391async fn randomness_inject_partial_sigs(
392 State(state): State<Arc<AppState>>,
393 args: Query<PartialSigsToInject>,
394) -> (StatusCode, String) {
395 let Query(PartialSigsToInject {
396 hex_authority_name,
397 round,
398 base64_sigs,
399 }) = args;
400
401 let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
402 Ok(authority_name) => authority_name,
403 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
404 };
405
406 let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
407 Ok(sigs) => sigs,
408 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
409 };
410
411 let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
412 Ok(sigs) => sigs,
413 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
414 };
415
416 let (tx_result, rx_result) = oneshot::channel();
417 state
418 .node
419 .randomness_handle()
420 .admin_inject_partial_signatures(authority_name, RandomnessRound(round), sigs, tx_result);
421
422 match rx_result.await {
423 Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
424 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
425 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
426 }
427}
428
429#[derive(Deserialize)]
430struct FullSigToInject {
431 round: u64,
432 base64_sig: String,
433}
434
435async fn randomness_inject_full_sig(
436 State(state): State<Arc<AppState>>,
437 args: Query<FullSigToInject>,
438) -> (StatusCode, String) {
439 let Query(FullSigToInject { round, base64_sig }) = args;
440
441 let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
442 Ok(sig) => sig,
443 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
444 };
445
446 let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
447 Ok(sig) => sig,
448 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
449 };
450
451 let (tx_result, rx_result) = oneshot::channel();
452 state.node.randomness_handle().admin_inject_full_signature(
453 RandomnessRound(round),
454 sig,
455 tx_result,
456 );
457
458 match rx_result.await {
459 Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
460 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
461 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
462 }
463}
464
465#[derive(Deserialize)]
466struct GetTxCost {
467 tx_digest: String,
468}
469
470async fn get_tx_cost(
471 State(state): State<Arc<AppState>>,
472 args: Query<GetTxCost>,
473) -> (StatusCode, String) {
474 let Query(GetTxCost { tx_digest }) = args;
475 let tx_digest = TransactionDigest::from_str(tx_digest.as_str()).unwrap();
476
477 let Some(transaction) = state
478 .node
479 .state()
480 .get_transaction_cache_reader()
481 .get_transaction_block(&tx_digest)
482 else {
483 return (StatusCode::BAD_REQUEST, "Transaction not found".to_string());
484 };
485
486 let Some(cost) = state
487 .node
488 .state()
489 .load_epoch_store_one_call_per_task()
490 .get_estimated_tx_cost(transaction.transaction_data())
491 .await
492 else {
493 return (StatusCode::BAD_REQUEST, "No estimate available".to_string());
494 };
495
496 (StatusCode::OK, cost.to_string())
497}
498
499async fn dump_consensus_tx_cost_estimates(
500 State(state): State<Arc<AppState>>,
501) -> (StatusCode, String) {
502 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
503 let estimates = epoch_store.get_consensus_tx_cost_estimates().await;
504 (StatusCode::OK, format!("{:#?}", estimates))
505}
506
507async fn traffic_control(
508 State(state): State<Arc<AppState>>,
509 args: Query<TrafficControlReconfigParams>,
510) -> (StatusCode, String) {
511 let Query(params) = args;
512 match state.node.state().reconfigure_traffic_control(params).await {
513 Ok(updated_state) => (
514 StatusCode::OK,
515 format!(
516 "Traffic control configured with:\n\
517 Error threshold: {:?}\n\
518 Spam threshold: {:?}\n\
519 Dry run: {:?}\n",
520 updated_state.error_threshold, updated_state.spam_threshold, updated_state.dry_run
521 ),
522 ),
523 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
524 }
525}