1use axum::{extract::Extension, http::StatusCode, routing::get, Router};
4use mysten_metrics::RegistryService;
5use prometheus::{Registry, TextEncoder};
6use std::net::TcpListener;
7use std::sync::{Arc, RwLock};
8use tower::ServiceBuilder;
9use tower_http::trace::{DefaultOnResponse, TraceLayer};
10use tower_http::LatencyUnit;
11use tracing::Level;
12
13const METRICS_ROUTE: &str = "/metrics";
14const POD_HEALTH_ROUTE: &str = "/pod_health";
15
16type HealthCheckMetrics = Arc<RwLock<HealthCheck>>;
17
/// State consulted by the pod health endpoint.
///
/// The metrics handler copies the `consumer_operations_submitted` counter value into this
/// struct; the pod health handler reads it back.
#[derive(Debug)]
struct HealthCheck {
    /// Most recently copied value of the `consumer_operations_submitted` counter.
    consumer_operations_submitted: f64,
}

impl HealthCheck {
    /// Creates health state in which no consumer operations have been recorded (value `0.0`).
    fn new() -> Self {
        let consumer_operations_submitted = f64::default();
        Self {
            consumer_operations_submitted,
        }
    }
}
34
35pub fn start_prometheus_server(listener: TcpListener) -> RegistryService {
39 let registry = Registry::new();
40
41 let registry_service = RegistryService::new(registry);
42
43 let pod_health_data = Arc::new(RwLock::new(HealthCheck::new()));
44
45 let app = Router::new()
46 .route(METRICS_ROUTE, get(metrics))
47 .route(POD_HEALTH_ROUTE, get(pod_health))
48 .layer(Extension(registry_service.clone()))
49 .layer(Extension(pod_health_data.clone()))
50 .layer(
51 ServiceBuilder::new().layer(
52 TraceLayer::new_for_http().on_response(
53 DefaultOnResponse::new()
54 .level(Level::INFO)
55 .latency_unit(LatencyUnit::Seconds),
56 ),
57 ),
58 );
59
60 tokio::spawn(async move {
61 listener.set_nonblocking(true).unwrap();
62 let listener = tokio::net::TcpListener::from_std(listener).unwrap();
63 axum::serve(listener, app).await.unwrap();
64 });
65
66 registry_service
67}
68
69async fn metrics(
71 Extension(registry_service): Extension<RegistryService>,
72 Extension(pod_health): Extension<HealthCheckMetrics>,
73) -> (StatusCode, String) {
74 let mut metric_families = registry_service.gather_all();
75 metric_families.extend(prometheus::gather());
76
77 if let Some(consumer_operations_submitted) = metric_families
78 .iter()
79 .filter_map(|v| {
80 if v.get_name() == "consumer_operations_submitted" {
81 v.get_metric().first().map(|m| m.get_counter().get_value())
83 } else {
84 None
85 }
86 })
87 .next()
88 {
89 pod_health
90 .write()
91 .expect("unable to write to pod health metrics")
92 .consumer_operations_submitted = consumer_operations_submitted;
93 };
94 match TextEncoder.encode_to_string(&metric_families) {
95 Ok(metrics) => (StatusCode::OK, metrics),
96 Err(error) => (
97 StatusCode::INTERNAL_SERVER_ERROR,
98 format!("unable to encode metrics: {error}"),
99 ),
100 }
101}
102
103async fn pod_health(Extension(pod_health): Extension<HealthCheckMetrics>) -> (StatusCode, String) {
105 let consumer_operations_submitted = pod_health
106 .read()
107 .expect("unable to read pod health metrics")
108 .consumer_operations_submitted;
109
110 if consumer_operations_submitted > 0.0 {
111 (StatusCode::OK, consumer_operations_submitted.to_string())
112 } else {
113 (
114 StatusCode::SERVICE_UNAVAILABLE,
115 consumer_operations_submitted.to_string(),
116 )
117 }
118}