sui_edge_proxy/
handlers.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::config::{LoggingConfig, PeerConfig};
5use crate::metrics::AppMetrics;
6use axum::{
7    body::Body,
8    extract::{Request, State},
9    http::StatusCode,
10    http::request::Parts,
11    response::Response,
12};
13use bytes::Bytes;
14use rand::Rng;
15use std::time::Instant;
16use tracing::{debug, warn};
17
/// Which upstream peer a proxied request is routed to.
#[derive(Debug)]
enum PeerRole {
    /// Everything that is not a `sui_executeTransactionBlock` call.
    Read,
    /// `sui_executeTransactionBlock` calls.
    Execution,
}

impl PeerRole {
    /// Lowercase label for this role, used as the `peer_type` metric label.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Execution => "execution",
            Self::Read => "read",
        }
    }
}
32
33#[derive(Clone)]
34pub struct AppState {
35    client: reqwest::Client,
36    read_peer: PeerConfig,
37    execution_peer: PeerConfig,
38    metrics: AppMetrics,
39    logging_config: LoggingConfig,
40}
41
42impl AppState {
43    pub fn new(
44        client: reqwest::Client,
45        read_peer: PeerConfig,
46        execution_peer: PeerConfig,
47        metrics: AppMetrics,
48        logging_config: LoggingConfig,
49    ) -> Self {
50        Self {
51            client,
52            read_peer,
53            execution_peer,
54            metrics,
55            logging_config,
56        }
57    }
58}
59
60pub async fn proxy_handler(
61    State(state): State<AppState>,
62    request: Request<Body>,
63) -> Result<Response, (StatusCode, String)> {
64    let (parts, body) = request.into_parts();
65    let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
66        Ok(bytes) => bytes,
67        Err(e) => {
68            warn!("Failed to read request body: {}", e);
69            state
70                .metrics
71                .request_body_read_failures
72                .with_label_values(&[])
73                .inc();
74            return Ok(Response::builder()
75                .status(StatusCode::INTERNAL_SERVER_ERROR)
76                .body(Body::from("Failed to read request body"))
77                .unwrap());
78        }
79    };
80
81    match parts
82        .headers
83        .get("Client-Request-Method")
84        .and_then(|h| h.to_str().ok())
85    {
86        Some("sui_executeTransactionBlock") => {
87            debug!("Using execution peer");
88            proxy_request(state, parts, body_bytes, PeerRole::Execution).await
89        }
90        _ => {
91            let json_body = match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
92                Ok(json_body) => json_body,
93                Err(_) => {
94                    debug!("Failed to parse request body as JSON");
95                    return proxy_request(state, parts, body_bytes, PeerRole::Read).await;
96                }
97            };
98            if let Some("sui_executeTransactionBlock") =
99                json_body.get("method").and_then(|m| m.as_str())
100            {
101                proxy_request(state, parts, body_bytes, PeerRole::Execution).await
102            } else {
103                proxy_request(state, parts, body_bytes, PeerRole::Read).await
104            }
105        }
106    }
107}
108
109async fn proxy_request(
110    state: AppState,
111    parts: Parts,
112    body_bytes: Bytes,
113    peer_type: PeerRole,
114) -> Result<Response, (StatusCode, String)> {
115    debug!(
116        "Proxying request: method={:?}, uri={:?}, headers={:?}, body_len={}, peer_type={:?}",
117        parts.method,
118        parts.uri,
119        parts.headers,
120        body_bytes.len(),
121        peer_type
122    );
123    if matches!(peer_type, PeerRole::Read) {
124        let user_agent = parts
125            .headers
126            .get("user-agent")
127            .and_then(|h| h.to_str().ok());
128        let is_health_check = matches!(user_agent, Some(ua) if ua.contains("GoogleHC/1.0"));
129        let is_grafana_agent = matches!(user_agent, Some(ua) if ua.contains("GrafanaAgent"));
130        let is_grpc = parts
131            .headers
132            .get("content-type")
133            .and_then(|h| h.to_str().ok())
134            .map(|ct| ct.contains("grpc"))
135            .unwrap_or(false);
136
137        let should_sample = !is_health_check && !is_grafana_agent && !is_grpc;
138        let rate = state.logging_config.read_request_sample_rate;
139        if should_sample && rand::thread_rng().r#gen::<f64>() < rate {
140            tracing::info!(
141                headers = ?parts.headers,
142                body = ?body_bytes,
143                peer_type = ?peer_type,
144                "Sampled read request"
145            );
146        }
147    }
148
149    let metrics = &state.metrics;
150    let peer_type_str = peer_type.as_str();
151
152    let timer_histogram = metrics.request_latency.with_label_values(&[peer_type_str]);
153    let _timer = timer_histogram.start_timer();
154
155    metrics
156        .request_size_bytes
157        .with_label_values(&[peer_type_str])
158        .observe(body_bytes.len() as f64);
159
160    let peer_config = match peer_type {
161        PeerRole::Read => &state.read_peer,
162        PeerRole::Execution => &state.execution_peer,
163    };
164
165    let mut target_url = peer_config.address.clone();
166    target_url.set_path(parts.uri.path());
167    if let Some(query) = parts.uri.query() {
168        target_url.set_query(Some(query));
169    }
170
171    // remove host header to avoid interfering with reqwest auto-host header
172    let mut headers = parts.headers.clone();
173    headers.remove("host");
174    let request_builder = state
175        .client
176        .request(parts.method.clone(), target_url)
177        .headers(headers)
178        .body(body_bytes.clone());
179    debug!("Request builder: {:?}", request_builder);
180
181    let upstream_start = Instant::now();
182    let response = match request_builder.send().await {
183        Ok(response) => {
184            let status = response.status().as_u16().to_string();
185            metrics
186                .upstream_response_latency
187                .with_label_values(&[peer_type_str, &status])
188                .observe(upstream_start.elapsed().as_secs_f64());
189            metrics
190                .requests_total
191                .with_label_values(&[peer_type_str, &status])
192                .inc();
193            debug!("Response: {:?}", response);
194            response
195        }
196        Err(e) => {
197            warn!("Failed to send request: {}", e);
198            let error_type = if e.is_timeout() {
199                metrics
200                    .timeouts_total
201                    .with_label_values(&[peer_type_str])
202                    .inc();
203                "timeout"
204            } else {
205                "send_failure"
206            };
207            metrics
208                .upstream_request_failures
209                .with_label_values(&[peer_type_str, error_type])
210                .inc();
211            return Err((StatusCode::BAD_GATEWAY, format!("Request failed: {}", e)));
212        }
213    };
214
215    let response_headers = response.headers().clone();
216    let response_bytes = match response.bytes().await {
217        Ok(bytes) => bytes,
218        Err(e) => {
219            warn!("Failed to read response body: {}", e);
220            metrics
221                .error_counts
222                .with_label_values(&[peer_type_str, "response_body_read"])
223                .inc();
224            return Err((
225                StatusCode::INTERNAL_SERVER_ERROR,
226                "Failed to read response body".to_string(),
227            ));
228        }
229    };
230    metrics
231        .response_size_bytes
232        .with_label_values(&[peer_type_str])
233        .observe(response_bytes.len() as f64);
234
235    let mut resp = Response::new(response_bytes.into());
236    for (name, value) in response_headers {
237        if let Some(name) = name {
238            resp.headers_mut().insert(name, value);
239        }
240    }
241
242    Ok(resp)
243}