sui_proxy/
middleware.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3use crate::{consumer::ProtobufDecoder, peers::SuiNodeProvider};
4use axum::{
5    body::Body,
6    body::Bytes,
7    extract::{Extension, FromRequest},
8    http::{Request, StatusCode},
9    middleware::Next,
10    response::Response,
11};
12use axum_extra::headers::{ContentLength, ContentType};
13use axum_extra::typed_header::TypedHeader;
14use bytes::Buf;
15use hyper::header::CONTENT_ENCODING;
16use once_cell::sync::Lazy;
17use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
18use std::sync::Arc;
19use sui_tls::TlsConnectionInfo;
20use tracing::error;
21
22static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
23    register_counter_vec!(
24        "middleware_operations",
25        "Operations counters and status for axum middleware.",
26        &["operation", "status"]
27    )
28    .unwrap()
29});
30
31/// we expect sui-node to send us an http header content-length encoding.
32pub async fn expect_content_length(
33    TypedHeader(_content_length): TypedHeader<ContentLength>,
34    request: Request<Body>,
35    next: Next,
36) -> Result<Response, (StatusCode, &'static str)> {
37    Ok(next.run(request).await)
38}
39
40/// we expect sui-node to send us an http header content-type encoding.
41pub async fn expect_mysten_proxy_header(
42    TypedHeader(content_type): TypedHeader<ContentType>,
43    request: Request<Body>,
44    next: Next,
45) -> Result<Response, (StatusCode, &'static str)> {
46    match format!("{content_type}").as_str() {
47        prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
48        ct => {
49            error!("invalid content-type; {ct}");
50            MIDDLEWARE_OPS
51                .with_label_values(&["expect_mysten_proxy_header", "invalid-content-type"])
52                .inc();
53            Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
54        }
55    }
56}
57
58/// we expect that calling sui-nodes are known on the blockchain and we enforce
59/// their pub key tls creds here
60pub async fn expect_valid_public_key(
61    Extension(allower): Extension<Arc<SuiNodeProvider>>,
62    Extension(tls_connect_info): Extension<TlsConnectionInfo>,
63    mut request: Request<Body>,
64    next: Next,
65) -> Result<Response, (StatusCode, &'static str)> {
66    let Some(public_key) = tls_connect_info.public_key() else {
67        error!("unable to obtain public key from connecting client");
68        MIDDLEWARE_OPS
69            .with_label_values(&["expect_valid_public_key", "missing-public-key"])
70            .inc();
71        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
72    };
73    let Some(peer) = allower.get(public_key) else {
74        error!("node with unknown pub key tried to connect {}", public_key);
75        MIDDLEWARE_OPS
76            .with_label_values(&[
77                "expect_valid_public_key",
78                "unknown-validator-connection-attempt",
79            ])
80            .inc();
81        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
82    };
83    request.extensions_mut().insert(peer);
84    Ok(next.run(request).await)
85}
86
87// extractor that shows how to consume the request body upfront
88#[derive(Debug)]
89pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
90
91impl<S> FromRequest<S> for LenDelimProtobuf
92where
93    S: Send + Sync,
94{
95    type Rejection = (StatusCode, String);
96
97    async fn from_request(
98        req: Request<axum::body::Body>,
99        state: &S,
100    ) -> Result<Self, Self::Rejection> {
101        let should_be_snappy = req
102            .headers()
103            .get(CONTENT_ENCODING)
104            .map(|v| v.as_bytes() == b"snappy")
105            .unwrap_or(false);
106
107        let body = Bytes::from_request(req, state).await.map_err(|e| {
108            let msg = format!("error extracting bytes; {e}");
109            error!(msg);
110            MIDDLEWARE_OPS
111                .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
112                .inc();
113            (e.status(), msg)
114        })?;
115
116        let intermediate = if should_be_snappy {
117            let mut s = snap::raw::Decoder::new();
118            let decompressed = s.decompress_vec(&body).map_err(|e| {
119                let msg = format!("unable to decode snappy encoded protobufs; {e}");
120                error!(msg);
121                MIDDLEWARE_OPS
122                    .with_label_values(&[
123                        "LenDelimProtobuf_decompress_vec",
124                        "unable-to-decode-snappy",
125                    ])
126                    .inc();
127                (StatusCode::BAD_REQUEST, msg)
128            })?;
129            Bytes::from(decompressed).reader()
130        } else {
131            body.reader()
132        };
133
134        let mut decoder = ProtobufDecoder::new(intermediate);
135        let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
136            let msg = format!("unable to decode len deliminated protobufs; {e}");
137            error!(msg);
138            MIDDLEWARE_OPS
139                .with_label_values(&[
140                    "LenDelimProtobuf_from_request",
141                    "unable-to-decode-protobufs",
142                ])
143                .inc();
144            (StatusCode::BAD_REQUEST, msg)
145        })?;
146        Ok(Self(decoded))
147    }
148}