sui_json_rpc/
traffic_control.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use axum::extract::ConnectInfo;
use futures::FutureExt;
use jsonrpsee::server::middleware::rpc::RpcServiceT;
use jsonrpsee::types::{ErrorCode, ErrorObject, Id};
use jsonrpsee::MethodResponse;
use std::net::IpAddr;
use std::time::SystemTime;
use std::{net::SocketAddr, sync::Arc};
use sui_core::traffic_controller::{parse_ip, policies::TrafficTally, TrafficController};
use sui_json_rpc_api::TRANSACTION_EXECUTION_CLIENT_ERROR_CODE;
use sui_types::traffic_control::ClientIdSource;
use sui_types::traffic_control::Weight;
use tracing::error;

/// Error message sent back to clients whose requests are rejected by the traffic controller.
const TOO_MANY_REQUESTS_MSG: &str = "Too many requests";

/// JSON-RPC middleware that wraps an inner RPC service with an optional [`TrafficController`].
///
/// When a controller is configured, each request is checked against it before being forwarded
/// to `inner`, and the resulting response is tallied with it afterwards. When
/// `traffic_controller` is `None`, requests are passed straight through to `inner`.
#[derive(Clone)]
pub struct TrafficControllerService<S> {
    inner: S,
    traffic_controller: Option<Arc<TrafficController>>,
}

impl<S> TrafficControllerService<S> {
    /// Wraps `service`, optionally gating every call through `traffic_controller`.
    pub fn new(service: S, traffic_controller: Option<Arc<TrafficController>>) -> Self {
        let inner = service;
        Self {
            traffic_controller,
            inner,
        }
    }
}

impl<'a, S> RpcServiceT<'a> for TrafficControllerService<S>
where
    S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
    S::Future: 'a,
{
    type Future = futures::future::BoxFuture<'a, jsonrpsee::MethodResponse>;

    fn call(&self, req: jsonrpsee::types::Request<'a>) -> Self::Future {
        let service = self.inner.clone();
        let traffic_controller = self.traffic_controller.clone();

        async move {
            if let Some(traffic_controller) = traffic_controller {
                let client = req.extensions().get::<IpAddr>().cloned();
                if let Err(response) = handle_traffic_req(&traffic_controller, &client).await {
                    response
                } else {
                    let response = service.call(req).await;
                    handle_traffic_resp(&traffic_controller, client, &response);
                    response
                }
            } else {
                service.call(req).await
            }
        }
        .boxed()
    }
}

async fn handle_traffic_req(
    traffic_controller: &TrafficController,
    client: &Option<IpAddr>,
) -> Result<(), MethodResponse> {
    if !traffic_controller.check(client, &None).await {
        // Entity in blocklist
        let err_obj =
            ErrorObject::borrowed(ErrorCode::ServerIsBusy.code(), TOO_MANY_REQUESTS_MSG, None);
        Err(MethodResponse::error(Id::Null, err_obj))
    } else {
        Ok(())
    }
}

/// Reports the outcome of a served request to the traffic controller.
///
/// Every request is tallied against `client` with a spam weight of one. If `response` carries
/// an error code, the tally also includes `(weight from normalize, code as string)`.
fn handle_traffic_resp(
    traffic_controller: &TrafficController,
    client: Option<IpAddr>,
    response: &MethodResponse,
) {
    let error_info = response
        .as_error_code()
        .map(ErrorCode::from)
        .map(|code| {
            let error_type = code.to_string();
            (normalize(code), error_type)
        });

    let tally = TrafficTally {
        direct: client,
        through_fullnode: None,
        error_info,
        // For now, every request counts as spam with equal weight on the rpc node side,
        // including gas-charging endpoints such as `sui_executeTransactionBlock`. This lets
        // node operators rate limit their transaction traffic and nudges high-volume clients
        // toward a suitable rpc provider (or running their own). A per-method weight
        // distribution may be introduced later.
        spam_weight: Weight::one(),
        timestamp: SystemTime::now(),
    };
    traffic_controller.tally(tally);
}

/// Maps a JSON-RPC error code to the error weight tallied by the traffic controller.
///
/// Errors treated as the client's fault weigh one: invalid requests, invalid params, and the
/// transaction-execution client error code (e.g. an invalid client signature). Every other
/// error weighs zero.
// TODO: refine error matching here
fn normalize(err: ErrorCode) -> Weight {
    let caused_by_client = match err {
        ErrorCode::InvalidRequest | ErrorCode::InvalidParams => true,
        ErrorCode::ServerError(code) => code == TRANSACTION_EXECUTION_CLIENT_ERROR_CODE,
        _ => false,
    };
    if caused_by_client {
        Weight::one()
    } else {
        Weight::zero()
    }
}

pub fn determine_client_ip<T>(
    client_id_source: ClientIdSource,
    request: &mut axum::http::Request<T>,
) {
    let headers = request.headers();
    let client = match client_id_source {
        ClientIdSource::SocketAddr => request
            .extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map(|info| info.0.ip()),
        ClientIdSource::XForwardedFor(num_hops) => {
            let do_header_parse = |header: &axum::http::HeaderValue| match header.to_str() {
                Ok(header_val) => {
                    let header_contents = header_val.split(',').map(str::trim).collect::<Vec<_>>();
                    if num_hops == 0 {
                        error!(
                                "x-forwarded-for: 0 specified. x-forwarded-for contents: {:?}. Please assign nonzero value for \
                                number of hops here, or use `socket-addr` client-id-source type if requests are not being proxied \
                                to this node. Skipping traffic controller request handling.",
                                header_contents,
                            );
                        return None;
                    }
                    let contents_len = header_contents.len();
                    let Some(client_ip) = header_contents.get(contents_len - num_hops) else {
                        error!(
                                "x-forwarded-for header value of {:?} contains {} values, but {} hops were specificed. \
                                Expected {} values. Skipping traffic controller request handling.",
                                header_contents,
                                contents_len,
                                num_hops,
                                num_hops + 1,
                            );
                        return None;
                    };
                    parse_ip(client_ip)
                }
                Err(e) => {
                    error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e);
                    None
                }
            };
            if let Some(header) = headers.get("x-forwarded-for") {
                do_header_parse(header)
            } else if let Some(header) = headers.get("X-Forwarded-For") {
                do_header_parse(header)
            } else {
                error!("x-forwarded-for header not present for request despite node configuring x-forwarded-for tracking type");
                None
            }
        }
    };

    if let Some(ip) = client {
        request.extensions_mut().insert(ip);
    }
}