sui_proxy/
middleware.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
use crate::{consumer::ProtobufDecoder, peers::SuiNodeProvider};
use axum::{
    async_trait,
    body::Body,
    body::Bytes,
    extract::{Extension, FromRequest},
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};
use axum_extra::headers::{ContentLength, ContentType};
use axum_extra::typed_header::TypedHeader;
use bytes::Buf;
use hyper::header::CONTENT_ENCODING;
use once_cell::sync::Lazy;
use prometheus::{proto::MetricFamily, register_counter_vec, CounterVec};
use std::sync::Arc;
use sui_tls::TlsConnectionInfo;
use tracing::error;

/// Counts middleware outcomes, labelled by the `operation` (which middleware
/// or extractor ran) and the `status` it ended with (e.g. a rejection reason
/// such as `invalid-content-type`).
static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
    let label_names = &["operation", "status"];
    register_counter_vec!(
        "middleware_operations",
        "Operations counters and status for axum middleware.",
        label_names
    )
    .unwrap()
});

/// Counts HTTP header values observed by the middleware, labelled by the
/// `header` name and the observed `value`.
static MIDDLEWARE_HEADERS: Lazy<CounterVec> = Lazy::new(|| {
    register_counter_vec!(
        "middleware_headers",
        // Previously a copy of the `middleware_operations` help text, which did
        // not describe this header/value metric.
        "Counts of HTTP header values observed by axum middleware.",
        &["header", "value"]
    )
    .unwrap()
});

/// Middleware that expects sui-node to send a `Content-Length` header.
///
/// The header is obtained through the `TypedHeader<ContentLength>` extractor.
/// For each request that reaches this body, the counter
/// `middleware_headers{header="content-length", value=<length>}` is
/// incremented. The request is then passed unchanged to the next handler.
/// This function itself never returns `Err`.
///
/// NOTE(review): the raw length is used as a label value, so every distinct
/// body size creates a new time series. Consider bucketing if cardinality
/// becomes a problem.
pub async fn expect_content_length(
    TypedHeader(content_length): TypedHeader<ContentLength>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, (StatusCode, &'static str)> {
    // `with_label_values` only looks up (or creates) the child counter; the
    // `.inc()` is what actually records the observation.
    MIDDLEWARE_HEADERS
        .with_label_values(&["content-length", &content_length.0.to_string()])
        .inc();
    Ok(next.run(request).await)
}

/// we expect sui-node to send us an http header content-type encoding.
pub async fn expect_mysten_proxy_header(
    TypedHeader(content_type): TypedHeader<ContentType>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, (StatusCode, &'static str)> {
    match format!("{content_type}").as_str() {
        prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
        ct => {
            error!("invalid content-type; {ct}");
            MIDDLEWARE_OPS
                .with_label_values(&["expect_mysten_proxy_header", "invalid-content-type"])
                .inc();
            Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
        }
    }
}

/// we expect that calling sui-nodes are known on the blockchain and we enforce
/// their pub key tls creds here
pub async fn expect_valid_public_key(
    Extension(allower): Extension<Arc<SuiNodeProvider>>,
    Extension(tls_connect_info): Extension<TlsConnectionInfo>,
    mut request: Request<Body>,
    next: Next,
) -> Result<Response, (StatusCode, &'static str)> {
    let Some(public_key) = tls_connect_info.public_key() else {
        error!("unable to obtain public key from connecting client");
        MIDDLEWARE_OPS
            .with_label_values(&["expect_valid_public_key", "missing-public-key"])
            .inc();
        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
    };
    let Some(peer) = allower.get(public_key) else {
        error!("node with unknown pub key tried to connect {}", public_key);
        MIDDLEWARE_OPS
            .with_label_values(&[
                "expect_valid_public_key",
                "unknown-validator-connection-attempt",
            ])
            .inc();
        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
    };
    request.extensions_mut().insert(peer);
    Ok(next.run(request).await)
}

/// Axum extractor that consumes the whole request body and decodes it as a
/// sequence of length-delimited protobuf [`MetricFamily`] messages.
///
/// If the request carries `Content-Encoding: snappy` (matched
/// case-insensitively), the body is first decompressed with raw (unframed)
/// snappy. A missing or any other `Content-Encoding` is decoded as-is.
///
/// NOTE(review): unsupported encodings (e.g. gzip) are not rejected up front;
/// they will most likely surface as a protobuf decode error instead.
#[derive(Debug)]
pub struct LenDelimProtobuf(pub Vec<MetricFamily>);

#[async_trait]
impl<S> FromRequest<S> for LenDelimProtobuf
where
    S: Send + Sync,
{
    /// Status code plus a human-readable reason. Every rejection is also logged
    /// and counted in `middleware_operations`.
    type Rejection = (StatusCode, String);

    async fn from_request(
        req: Request<axum::body::Body>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Read the header before `Bytes::from_request` consumes `req`.
        // Content codings are case-insensitive (RFC 9110 §8.4.1), so
        // `Snappy`/`SNAPPY` must be treated the same as `snappy`.
        let should_be_snappy = req
            .headers()
            .get(CONTENT_ENCODING)
            .map(|v| v.as_bytes().eq_ignore_ascii_case(b"snappy"))
            .unwrap_or(false);

        let body = Bytes::from_request(req, state).await.map_err(|e| {
            let msg = format!("error extracting bytes; {e}");
            error!(msg);
            MIDDLEWARE_OPS
                .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
                .inc();
            (e.status(), msg)
        })?;

        let intermediate = if should_be_snappy {
            let mut s = snap::raw::Decoder::new();
            let decompressed = s.decompress_vec(&body).map_err(|e| {
                let msg = format!("unable to decode snappy encoded protobufs; {e}");
                error!(msg);
                MIDDLEWARE_OPS
                    .with_label_values(&[
                        "LenDelimProtobuf_decompress_vec",
                        "unable-to-decode-snappy",
                    ])
                    .inc();
                (StatusCode::BAD_REQUEST, msg)
            })?;
            Bytes::from(decompressed).reader()
        } else {
            body.reader()
        };

        let mut decoder = ProtobufDecoder::new(intermediate);
        let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
            let msg = format!("unable to decode len deliminated protobufs; {e}");
            error!(msg);
            MIDDLEWARE_OPS
                .with_label_values(&[
                    "LenDelimProtobuf_from_request",
                    "unable-to-decode-protobufs",
                ])
                .inc();
            (StatusCode::BAD_REQUEST, msg)
        })?;
        Ok(Self(decoded))
    }
}