1use crate::{consumer::ProtobufDecoder, peers::SuiNodeProvider};
4use axum::{
5 body::Body,
6 body::Bytes,
7 extract::{Extension, FromRequest},
8 http::{Request, StatusCode},
9 middleware::Next,
10 response::Response,
11};
12use axum_extra::headers::{ContentLength, ContentType};
13use axum_extra::typed_header::TypedHeader;
14use bytes::Buf;
15use hyper::header::CONTENT_ENCODING;
16use once_cell::sync::Lazy;
17use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
18use std::sync::Arc;
19use sui_tls::TlsConnectionInfo;
20use tracing::error;
21
22static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
23 register_counter_vec!(
24 "middleware_operations",
25 "Operations counters and status for axum middleware.",
26 &["operation", "status"]
27 )
28 .unwrap()
29});
30
31pub async fn expect_content_length(
33 TypedHeader(_content_length): TypedHeader<ContentLength>,
34 request: Request<Body>,
35 next: Next,
36) -> Result<Response, (StatusCode, &'static str)> {
37 Ok(next.run(request).await)
38}
39
40pub async fn expect_mysten_proxy_header(
42 TypedHeader(content_type): TypedHeader<ContentType>,
43 request: Request<Body>,
44 next: Next,
45) -> Result<Response, (StatusCode, &'static str)> {
46 match format!("{content_type}").as_str() {
47 prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
48 ct => {
49 error!("invalid content-type; {ct}");
50 MIDDLEWARE_OPS
51 .with_label_values(&["expect_mysten_proxy_header", "invalid-content-type"])
52 .inc();
53 Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
54 }
55 }
56}
57
58pub async fn expect_valid_public_key(
61 Extension(allower): Extension<Arc<SuiNodeProvider>>,
62 Extension(tls_connect_info): Extension<TlsConnectionInfo>,
63 mut request: Request<Body>,
64 next: Next,
65) -> Result<Response, (StatusCode, &'static str)> {
66 let Some(public_key) = tls_connect_info.public_key() else {
67 error!("unable to obtain public key from connecting client");
68 MIDDLEWARE_OPS
69 .with_label_values(&["expect_valid_public_key", "missing-public-key"])
70 .inc();
71 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
72 };
73 let Some(peer) = allower.get(public_key) else {
74 error!("node with unknown pub key tried to connect {}", public_key);
75 MIDDLEWARE_OPS
76 .with_label_values(&[
77 "expect_valid_public_key",
78 "unknown-validator-connection-attempt",
79 ])
80 .inc();
81 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
82 };
83 request.extensions_mut().insert(peer);
84 Ok(next.run(request).await)
85}
86
87#[derive(Debug)]
89pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
90
91impl<S> FromRequest<S> for LenDelimProtobuf
92where
93 S: Send + Sync,
94{
95 type Rejection = (StatusCode, String);
96
97 async fn from_request(
98 req: Request<axum::body::Body>,
99 state: &S,
100 ) -> Result<Self, Self::Rejection> {
101 let should_be_snappy = req
102 .headers()
103 .get(CONTENT_ENCODING)
104 .map(|v| v.as_bytes() == b"snappy")
105 .unwrap_or(false);
106
107 let body = Bytes::from_request(req, state).await.map_err(|e| {
108 let msg = format!("error extracting bytes; {e}");
109 error!(msg);
110 MIDDLEWARE_OPS
111 .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
112 .inc();
113 (e.status(), msg)
114 })?;
115
116 let intermediate = if should_be_snappy {
117 let mut s = snap::raw::Decoder::new();
118 let decompressed = s.decompress_vec(&body).map_err(|e| {
119 let msg = format!("unable to decode snappy encoded protobufs; {e}");
120 error!(msg);
121 MIDDLEWARE_OPS
122 .with_label_values(&[
123 "LenDelimProtobuf_decompress_vec",
124 "unable-to-decode-snappy",
125 ])
126 .inc();
127 (StatusCode::BAD_REQUEST, msg)
128 })?;
129 Bytes::from(decompressed).reader()
130 } else {
131 body.reader()
132 };
133
134 let mut decoder = ProtobufDecoder::new(intermediate);
135 let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
136 let msg = format!("unable to decode len deliminated protobufs; {e}");
137 error!(msg);
138 MIDDLEWARE_OPS
139 .with_label_values(&[
140 "LenDelimProtobuf_from_request",
141 "unable-to-decode-protobufs",
142 ])
143 .inc();
144 (StatusCode::BAD_REQUEST, msg)
145 })?;
146 Ok(Self(decoded))
147 }
148}