1use crate::SuiNode;
5use axum::{
6 Router,
7 extract::{Query, State},
8 http::StatusCode,
9 routing::{get, post},
10};
11use base64::Engine;
12use fastcrypto::encoding::{Encoding, Hex};
13use fastcrypto::traits::ToFromBytes;
14use humantime::parse_duration;
15use mysten_network::Multiaddr;
16use serde::Deserialize;
17use std::sync::Arc;
18use std::{
19 net::{IpAddr, Ipv4Addr, SocketAddr},
20 str::FromStr,
21};
22use sui_network::endpoint_manager::{AddressSource, EndpointId};
23use sui_types::{
24 base_types::AuthorityName,
25 crypto::{NetworkPublicKey, RandomnessPartialSignature, RandomnessRound, RandomnessSignature},
26 digests::TransactionDigest,
27 error::SuiErrorKind,
28 traffic_control::TrafficControlReconfigParams,
29};
30use telemetry_subscribers::TracingHandle;
31use tokio::sync::oneshot;
32use tracing::info;
33
/// Text used when no tracing handle is available: returned by the tracing/logging
/// handlers and used as the filter value in the startup log line.
const NO_TRACING_HANDLE: &str = "tracing handle not available";

// Logging and tracing routes.
const LOGGING_ROUTE: &str = "/logging";
const TRACING_ROUTE: &str = "/enable-tracing";
const TRACING_RESET_ROUTE: &str = "/reset-tracing";

// Protocol-upgrade buffer stake and epoch routes.
const SET_BUFFER_STAKE_ROUTE: &str = "/set-override-buffer-stake";
const CLEAR_BUFFER_STAKE_ROUTE: &str = "/clear-override-buffer-stake";
const FORCE_CLOSE_EPOCH: &str = "/force-close-epoch";

// Node introspection routes.
const CAPABILITIES: &str = "/capabilities";
const NODE_CONFIG: &str = "/node-config";

// Randomness routes.
const RANDOMNESS_PARTIAL_SIGS_ROUTE: &str = "/randomness-partial-sigs";
const RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE: &str = "/randomness-inject-partial-sigs";
const RANDOMNESS_INJECT_FULL_SIG_ROUTE: &str = "/randomness-inject-full-sig";

// Transaction cost estimate routes.
const GET_TX_COST_ROUTE: &str = "/get-tx-cost";
const DUMP_CONSENSUS_TX_COST_ESTIMATES_ROUTE: &str = "/dump-consensus-tx-cost-estimates";

// Traffic control and endpoint management routes.
const TRAFFIC_CONTROL: &str = "/traffic-control";
const UPDATE_ENDPOINT: &str = "/update-endpoint";
106
107struct AppState {
108 node: Arc<SuiNode>,
109 tracing_handle: Option<TracingHandle>,
110}
111
112pub async fn run_admin_server(
113 node: Arc<SuiNode>,
114 port: u16,
115 tracing_handle: Option<TracingHandle>,
116) {
117 let filter = tracing_handle
118 .as_ref()
119 .and_then(|h| h.get_log().ok())
120 .unwrap_or_else(|| NO_TRACING_HANDLE.to_string());
121
122 let app_state = AppState {
123 node,
124 tracing_handle,
125 };
126
127 let app = Router::new()
128 .route(LOGGING_ROUTE, get(get_filter))
129 .route(CAPABILITIES, get(capabilities))
130 .route(NODE_CONFIG, get(node_config))
131 .route(LOGGING_ROUTE, post(set_filter))
132 .route(
133 SET_BUFFER_STAKE_ROUTE,
134 post(set_override_protocol_upgrade_buffer_stake),
135 )
136 .route(
137 CLEAR_BUFFER_STAKE_ROUTE,
138 post(clear_override_protocol_upgrade_buffer_stake),
139 )
140 .route(FORCE_CLOSE_EPOCH, post(force_close_epoch))
141 .route(TRACING_ROUTE, post(enable_tracing))
142 .route(TRACING_RESET_ROUTE, post(reset_tracing))
143 .route(RANDOMNESS_PARTIAL_SIGS_ROUTE, get(randomness_partial_sigs))
144 .route(
145 RANDOMNESS_INJECT_PARTIAL_SIGS_ROUTE,
146 post(randomness_inject_partial_sigs),
147 )
148 .route(
149 RANDOMNESS_INJECT_FULL_SIG_ROUTE,
150 post(randomness_inject_full_sig),
151 )
152 .route(GET_TX_COST_ROUTE, get(get_tx_cost))
153 .route(
154 DUMP_CONSENSUS_TX_COST_ESTIMATES_ROUTE,
155 get(dump_consensus_tx_cost_estimates),
156 )
157 .route(TRAFFIC_CONTROL, post(traffic_control))
158 .route(UPDATE_ENDPOINT, post(update_endpoint))
159 .with_state(Arc::new(app_state));
160
161 let socket_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
162 info!(
163 filter =% filter,
164 address =% socket_address,
165 "starting admin server"
166 );
167
168 let listener = tokio::net::TcpListener::bind(&socket_address)
169 .await
170 .unwrap();
171 axum::serve(
172 listener,
173 app.into_make_service_with_connect_info::<SocketAddr>(),
174 )
175 .await
176 .unwrap();
177}
178
179#[derive(Deserialize)]
180struct EnableTracing {
181 filter: Option<String>,
183 duration: Option<String>,
184
185 trace_file: Option<String>,
187
188 sample_rate: Option<f64>,
190}
191
192async fn enable_tracing(
193 State(state): State<Arc<AppState>>,
194 query: Query<EnableTracing>,
195) -> (StatusCode, String) {
196 let Some(tracing_handle) = &state.tracing_handle else {
197 return (StatusCode::UNPROCESSABLE_ENTITY, NO_TRACING_HANDLE.into());
198 };
199
200 let Query(EnableTracing {
201 filter,
202 duration,
203 trace_file,
204 sample_rate,
205 }) = query;
206
207 let mut response = Vec::new();
208
209 if let Some(sample_rate) = sample_rate {
210 tracing_handle.update_sampling_rate(sample_rate);
211 response.push(format!("sample rate set to {:?}", sample_rate));
212 }
213
214 if let Some(trace_file) = trace_file {
215 if let Err(err) = tracing_handle.update_trace_file(&trace_file) {
216 response.push(format!("can't update trace file: {:?}", err));
217 return (StatusCode::BAD_REQUEST, response.join("\n"));
218 } else {
219 response.push(format!("trace file set to {:?}", trace_file));
220 }
221 }
222
223 let Some(filter) = filter else {
224 return (StatusCode::OK, response.join("\n"));
225 };
226
227 let Some(duration) = duration else {
229 response.push("can't update filter: missing duration".into());
230 return (StatusCode::BAD_REQUEST, response.join("\n"));
231 };
232
233 let Ok(duration) = parse_duration(&duration) else {
234 response.push("can't update filter: invalid duration".into());
235 return (StatusCode::BAD_REQUEST, response.join("\n"));
236 };
237
238 match tracing_handle.update_trace_filter(&filter, duration) {
239 Ok(()) => {
240 response.push(format!("filter set to {:?}", filter));
241 response.push(format!("filter will be reset after {:?}", duration));
242 (StatusCode::OK, response.join("\n"))
243 }
244 Err(err) => {
245 response.push(format!("can't update filter: {:?}", err));
246 (StatusCode::BAD_REQUEST, response.join("\n"))
247 }
248 }
249}
250
251async fn reset_tracing(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
252 let Some(tracing_handle) = &state.tracing_handle else {
253 return (StatusCode::UNPROCESSABLE_ENTITY, NO_TRACING_HANDLE.into());
254 };
255 tracing_handle.reset_trace();
256 (
257 StatusCode::OK,
258 "tracing filter reset to TRACE_FILTER env var".into(),
259 )
260}
261
262async fn get_filter(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
263 let Some(tracing_handle) = &state.tracing_handle else {
264 return (StatusCode::UNPROCESSABLE_ENTITY, NO_TRACING_HANDLE.into());
265 };
266 match tracing_handle.get_log() {
267 Ok(filter) => (StatusCode::OK, filter),
268 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
269 }
270}
271
272async fn set_filter(
273 State(state): State<Arc<AppState>>,
274 new_filter: String,
275) -> (StatusCode, String) {
276 let Some(tracing_handle) = &state.tracing_handle else {
277 return (StatusCode::UNPROCESSABLE_ENTITY, NO_TRACING_HANDLE.into());
278 };
279 match tracing_handle.update_log(&new_filter) {
280 Ok(()) => {
281 info!(filter =% new_filter, "Log filter updated");
282 (StatusCode::OK, "".into())
283 }
284 Err(err) => (StatusCode::BAD_REQUEST, err.to_string()),
285 }
286}
287
288async fn capabilities(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
289 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
290
291 let capabilities = epoch_store.get_capabilities_v2();
292 let mut output = String::new();
293 for capability in capabilities.unwrap_or_default() {
294 output.push_str(&format!("{:?}\n", capability));
295 }
296
297 (StatusCode::OK, output)
298}
299
300async fn node_config(State(state): State<Arc<AppState>>) -> (StatusCode, String) {
301 let node_config = &state.node.config;
302
303 (StatusCode::OK, format!("{:#?}\n", node_config))
305}
306
307#[derive(Deserialize)]
308struct Epoch {
309 epoch: u64,
310}
311
312async fn clear_override_protocol_upgrade_buffer_stake(
313 State(state): State<Arc<AppState>>,
314 epoch: Query<Epoch>,
315) -> (StatusCode, String) {
316 let Query(Epoch { epoch }) = epoch;
317
318 match state
319 .node
320 .clear_override_protocol_upgrade_buffer_stake(epoch)
321 {
322 Ok(()) => (
323 StatusCode::OK,
324 "protocol upgrade buffer stake cleared\n".to_string(),
325 ),
326 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
327 }
328}
329
330#[derive(Deserialize)]
331struct SetBufferStake {
332 buffer_bps: u64,
333 epoch: u64,
334}
335
336async fn set_override_protocol_upgrade_buffer_stake(
337 State(state): State<Arc<AppState>>,
338 buffer_state: Query<SetBufferStake>,
339) -> (StatusCode, String) {
340 let Query(SetBufferStake { buffer_bps, epoch }) = buffer_state;
341
342 match state
343 .node
344 .set_override_protocol_upgrade_buffer_stake(epoch, buffer_bps)
345 {
346 Ok(()) => (
347 StatusCode::OK,
348 format!("protocol upgrade buffer stake set to '{}'\n", buffer_bps),
349 ),
350 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
351 }
352}
353
354async fn force_close_epoch(
355 State(state): State<Arc<AppState>>,
356 epoch: Query<Epoch>,
357) -> (StatusCode, String) {
358 let Query(Epoch {
359 epoch: expected_epoch,
360 }) = epoch;
361 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
362 let actual_epoch = epoch_store.epoch();
363 if actual_epoch != expected_epoch {
364 let err = SuiErrorKind::WrongEpoch {
365 expected_epoch,
366 actual_epoch,
367 };
368 return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string());
369 }
370
371 match state.node.close_epoch(&epoch_store).await {
372 Ok(()) => (
373 StatusCode::OK,
374 "close_epoch() called successfully\n".to_string(),
375 ),
376 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
377 }
378}
379
380#[derive(Deserialize)]
381struct Round {
382 round: u64,
383}
384
385async fn randomness_partial_sigs(
386 State(state): State<Arc<AppState>>,
387 round: Query<Round>,
388) -> (StatusCode, String) {
389 let Query(Round { round }) = round;
390
391 let (tx, rx) = oneshot::channel();
392 state
393 .node
394 .randomness_handle()
395 .admin_get_partial_signatures(RandomnessRound(round), tx);
396
397 let sigs = match rx.await {
398 Ok(sigs) => sigs,
399 Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
400 };
401
402 let output = format!(
403 "{}\n",
404 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sigs)
405 );
406
407 (StatusCode::OK, output)
408}
409
410#[derive(Deserialize)]
411struct PartialSigsToInject {
412 hex_authority_name: String,
413 round: u64,
414 base64_sigs: String,
415}
416
417async fn randomness_inject_partial_sigs(
418 State(state): State<Arc<AppState>>,
419 args: Query<PartialSigsToInject>,
420) -> (StatusCode, String) {
421 let Query(PartialSigsToInject {
422 hex_authority_name,
423 round,
424 base64_sigs,
425 }) = args;
426
427 let authority_name = match AuthorityName::from_str(hex_authority_name.as_str()) {
428 Ok(authority_name) => authority_name,
429 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
430 };
431
432 let sigs: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sigs) {
433 Ok(sigs) => sigs,
434 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
435 };
436
437 let sigs: Vec<RandomnessPartialSignature> = match bcs::from_bytes(&sigs) {
438 Ok(sigs) => sigs,
439 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
440 };
441
442 let (tx_result, rx_result) = oneshot::channel();
443 state
444 .node
445 .randomness_handle()
446 .admin_inject_partial_signatures(authority_name, RandomnessRound(round), sigs, tx_result);
447
448 match rx_result.await {
449 Ok(Ok(())) => (StatusCode::OK, "partial signatures injected\n".to_string()),
450 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
451 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
452 }
453}
454
455#[derive(Deserialize)]
456struct FullSigToInject {
457 round: u64,
458 base64_sig: String,
459}
460
461async fn randomness_inject_full_sig(
462 State(state): State<Arc<AppState>>,
463 args: Query<FullSigToInject>,
464) -> (StatusCode, String) {
465 let Query(FullSigToInject { round, base64_sig }) = args;
466
467 let sig: Vec<u8> = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(base64_sig) {
468 Ok(sig) => sig,
469 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
470 };
471
472 let sig: RandomnessSignature = match bcs::from_bytes(&sig) {
473 Ok(sig) => sig,
474 Err(err) => return (StatusCode::BAD_REQUEST, err.to_string()),
475 };
476
477 let (tx_result, rx_result) = oneshot::channel();
478 state.node.randomness_handle().admin_inject_full_signature(
479 RandomnessRound(round),
480 sig,
481 tx_result,
482 );
483
484 match rx_result.await {
485 Ok(Ok(())) => (StatusCode::OK, "full signature injected\n".to_string()),
486 Ok(Err(e)) => (StatusCode::BAD_REQUEST, e.to_string()),
487 Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
488 }
489}
490
491#[derive(Deserialize)]
492struct GetTxCost {
493 tx_digest: String,
494}
495
496async fn get_tx_cost(
497 State(state): State<Arc<AppState>>,
498 args: Query<GetTxCost>,
499) -> (StatusCode, String) {
500 let Query(GetTxCost { tx_digest }) = args;
501 let tx_digest = TransactionDigest::from_str(tx_digest.as_str()).unwrap();
502
503 let Some(transaction) = state
504 .node
505 .state()
506 .get_transaction_cache_reader()
507 .get_transaction_block(&tx_digest)
508 else {
509 return (StatusCode::BAD_REQUEST, "Transaction not found".to_string());
510 };
511
512 let Some(cost) = state
513 .node
514 .state()
515 .load_epoch_store_one_call_per_task()
516 .get_estimated_tx_cost(transaction.transaction_data())
517 .await
518 else {
519 return (StatusCode::BAD_REQUEST, "No estimate available".to_string());
520 };
521
522 (StatusCode::OK, cost.to_string())
523}
524
525async fn dump_consensus_tx_cost_estimates(
526 State(state): State<Arc<AppState>>,
527) -> (StatusCode, String) {
528 let epoch_store = state.node.state().load_epoch_store_one_call_per_task();
529 let estimates = epoch_store.get_consensus_tx_cost_estimates().await;
530 (StatusCode::OK, format!("{:#?}", estimates))
531}
532
533async fn traffic_control(
534 State(state): State<Arc<AppState>>,
535 args: Query<TrafficControlReconfigParams>,
536) -> (StatusCode, String) {
537 let Query(params) = args;
538 match state.node.state().reconfigure_traffic_control(params).await {
539 Ok(updated_state) => (
540 StatusCode::OK,
541 format!(
542 "Traffic control configured with:\n\
543 Error threshold: {:?}\n\
544 Spam threshold: {:?}\n\
545 Dry run: {:?}\n",
546 updated_state.error_threshold, updated_state.spam_threshold, updated_state.dry_run
547 ),
548 ),
549 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
550 }
551}
552
553#[derive(Deserialize)]
554struct UpdateEndpointArgs {
555 endpoint_type: String,
556 id: String,
557 addresses: String,
558}
559
560async fn update_endpoint(
561 State(state): State<Arc<AppState>>,
562 args: Query<UpdateEndpointArgs>,
563) -> (StatusCode, String) {
564 let Query(UpdateEndpointArgs {
565 endpoint_type,
566 id,
567 addresses,
568 }) = args;
569
570 let endpoint_id = match endpoint_type.as_str() {
571 "p2p" => {
572 let peer_id_bytes = match Hex::decode(&id) {
573 Ok(bytes) => bytes,
574 Err(err) => {
575 return (
576 StatusCode::BAD_REQUEST,
577 format!("Invalid id hex encoding: {err}"),
578 );
579 }
580 };
581
582 let peer_id_bytes: [u8; 32] = match peer_id_bytes.try_into() {
583 Ok(bytes) => bytes,
584 Err(_) => {
585 return (
586 StatusCode::BAD_REQUEST,
587 "p2p id must be 32 bytes".to_string(),
588 );
589 }
590 };
591
592 EndpointId::P2p(anemo::PeerId(peer_id_bytes))
593 }
594 "consensus" => {
595 let network_pubkey_bytes = match Hex::decode(&id) {
596 Ok(bytes) => bytes,
597 Err(err) => {
598 return (
599 StatusCode::BAD_REQUEST,
600 format!("Invalid id hex encoding: {err}"),
601 );
602 }
603 };
604
605 let network_pubkey = match NetworkPublicKey::from_bytes(&network_pubkey_bytes) {
606 Ok(key) => key,
607 Err(err) => {
608 return (
609 StatusCode::BAD_REQUEST,
610 format!("Invalid network public key: {err:?}"),
611 );
612 }
613 };
614
615 EndpointId::Consensus(network_pubkey)
616 }
617 _ => {
618 return (
619 StatusCode::BAD_REQUEST,
620 format!("Unknown endpoint_type: {endpoint_type}"),
621 );
622 }
623 };
624
625 let mut parsed_addresses = Vec::new();
626 for addr_str in addresses.split(',') {
627 let addr_str = addr_str.trim();
628 if addr_str.is_empty() {
629 continue;
630 }
631 match addr_str.parse::<Multiaddr>() {
632 Ok(addr) => parsed_addresses.push(addr),
633 Err(err) => {
634 return (
635 StatusCode::BAD_REQUEST,
636 format!("Invalid address '{addr_str}': {err}"),
637 );
638 }
639 }
640 }
641
642 if let Err(e) = state.node.endpoint_manager().update_endpoint(
643 endpoint_id,
644 AddressSource::Admin,
645 parsed_addresses.clone(),
646 ) {
647 return (StatusCode::BAD_REQUEST, e.to_string());
648 }
649
650 (
651 StatusCode::OK,
652 format!(
653 "Endpoint updated for {endpoint_type} endpoint {id} with {} address(es)\n",
654 parsed_addresses.len(),
655 ),
656 )
657}