sui_indexer_alt_framework/ingestion/
streaming_client.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::time::Duration;
5
6use anyhow::Context;
7use anyhow::anyhow;
8use async_trait::async_trait;
9use futures::StreamExt;
10use futures::stream::BoxStream;
11use sui_rpc::headers::X_SUI_CHAIN_ID;
12use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
13use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
14use sui_types::digests::ChainIdentifier;
15use sui_types::messages_checkpoint::CheckpointDigest;
16use tokio_stream::adapters::Peekable;
17use tonic::Status;
18use tonic::transport::Endpoint;
19use tonic::transport::Uri;
20
21use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
22use crate::ingestion::error::Error;
23use crate::ingestion::error::Result;
24use crate::types::full_checkpoint_content::Checkpoint;
25
26pub struct CheckpointStream {
27    pub stream: Peekable<BoxStream<'static, Result<Checkpoint>>>,
28    pub chain_id: ChainIdentifier,
29}
30
31/// Trait representing a client for streaming checkpoint data.
32#[async_trait]
33pub trait CheckpointStreamingClient {
34    /// Returns the CheckpointStream and chain id.
35    async fn connect(&mut self) -> Result<CheckpointStream>;
36
37    /// Returns the latest checkpoint number available from the streaming source.
38    async fn latest_checkpoint_number(&mut self) -> Result<u64> {
39        let mut stream = self.connect().await?;
40
41        match stream.stream.next().await {
42            Some(Ok(checkpoint)) => Ok(checkpoint.summary.sequence_number),
43            Some(Err(e)) => Err(e),
44            None => Err(Error::StreamingError(anyhow!("Stream ended unexpectedly"))),
45        }
46    }
47}
48
49#[derive(clap::Args, Clone, Debug, Default)]
50pub struct StreamingClientArgs {
51    /// gRPC endpoint for streaming checkpoints
52    #[clap(long, env)]
53    pub streaming_url: Option<Uri>,
54}
55
56/// gRPC-based implementation of the CheckpointStreamingClient trait.
57#[derive(Clone)]
58pub struct GrpcStreamingClient {
59    uri: Uri,
60    connection_timeout: Duration,
61    statement_timeout: Duration,
62}
63
64impl GrpcStreamingClient {
65    pub fn new(uri: Uri, connection_timeout: Duration, statement_timeout: Duration) -> Self {
66        Self {
67            uri,
68            connection_timeout,
69            statement_timeout,
70        }
71    }
72}
73
74#[async_trait]
75impl CheckpointStreamingClient for GrpcStreamingClient {
76    async fn connect(&mut self) -> Result<CheckpointStream> {
77        let endpoint = Endpoint::from(self.uri.clone())
78            .connect_timeout(self.connection_timeout)
79            .timeout(self.connection_timeout);
80
81        let mut client = SubscriptionServiceClient::connect(endpoint)
82            .await
83            .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
84            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
85
86        let mut request = SubscribeCheckpointsRequest::default();
87        request.read_mask = Some(Checkpoint::proto_field_mask());
88
89        let response = client
90            .subscribe_checkpoints(request)
91            .await
92            .map_err(Error::RpcClientError)?;
93
94        let chain_id_value = response.metadata().get(X_SUI_CHAIN_ID).ok_or_else(|| {
95            Error::StreamingError(anyhow!("Chain ID not found in response metadata"))
96        })?;
97        let chain_id: ChainIdentifier = chain_id_value
98            .to_str()
99            .map_err(|e| Error::StreamingError(anyhow!("Chain ID is not valid ASCII: {e}")))?
100            .parse::<CheckpointDigest>()
101            .map_err(|e| Error::StreamingError(anyhow!("Chain ID parse error: {e}")))?
102            .into();
103
104        let stream = response
105            .into_inner()
106            .map(|result| async move {
107                match result {
108                    Ok(response) => {
109                        let checkpoint = response
110                            .checkpoint
111                            .context("Checkpoint data missing in response")
112                            .map_err(Error::StreamingError)?;
113                        // Proto -> Checkpoint conversion is multi-ms of CPU work;
114                        // offload to the blocking pool so it doesn't stall the reactor.
115                        // Combined with `.buffered(4)` below, up to 4 decodes can run
116                        // concurrently while new bytes keep flowing from gRPC.
117                        tokio::task::spawn_blocking(move || {
118                            Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
119                        })
120                        .await
121                        .map_err(|e| Error::StreamingError(anyhow!("decode task panicked: {e}")))?
122                        .map_err(Error::StreamingError)
123                    }
124                    Err(e) => Err(Error::RpcClientError(e)),
125                }
126            })
127            .buffered(4);
128        let stream = wrap_stream(stream, self.statement_timeout);
129
130        Ok(CheckpointStream { stream, chain_id })
131    }
132}
133
134/// Wraps a stream with a per-item timeout. Converts the resulting `Err(Elapsed)` into
135/// `Err(StreamingError)` if it occurs.
136fn wrap_stream(
137    stream: impl futures::Stream<Item = Result<Checkpoint>> + Send + 'static,
138    statement_timeout: Duration,
139) -> Peekable<BoxStream<'static, Result<Checkpoint>>> {
140    let stream = tokio_stream::StreamExt::timeout(stream, statement_timeout)
141        .map(move |result| match result {
142            Err(_elapsed) => Err(Error::StreamingError(anyhow!(
143                "Statement timeout after {statement_timeout:?}"
144            ))),
145            Ok(result) => result,
146        })
147        .boxed();
148    tokio_stream::StreamExt::peekable(stream)
149}
150
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::time::Duration;
    use std::time::Instant;

    use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
    use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsResponse;
    use sui_rpc::proto::sui::rpc::v2::subscription_service_server::SubscriptionService;
    use sui_rpc::proto::sui::rpc::v2::subscription_service_server::SubscriptionServiceServer;
    use tonic::transport::Server;

    use super::*;

    /// Subscription service whose `subscribe_checkpoints` handler never completes: the TCP
    /// connection is accepted, but the RPC itself stalls forever.
    struct HangingSubscriptionService;

    #[tonic::async_trait]
    impl SubscriptionService for HangingSubscriptionService {
        async fn subscribe_checkpoints(
            &self,
            _request: tonic::Request<SubscribeCheckpointsRequest>,
        ) -> std::result::Result<
            tonic::Response<
                BoxStream<'static, std::result::Result<SubscribeCheckpointsResponse, Status>>,
            >,
            Status,
        > {
            // Never resolves, so the client must rely on its own timeout.
            futures::future::pending().await
        }
    }

    /// `connect()` must give up on a server that accepts the connection but never answers the
    /// subscription request, rather than waiting indefinitely.
    #[tokio::test]
    async fn subscribe_checkpoints_times_out_on_stalled_server() {
        // Bind to an ephemeral port so the test never collides with another listener.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr: SocketAddr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let connections = tokio_stream::wrappers::TcpListenerStream::new(listener);
            let service = SubscriptionServiceServer::new(HangingSubscriptionService);
            Server::builder()
                .add_service(service)
                .serve_with_incoming(connections)
                .await
                .unwrap();
        });

        let short_timeout = Duration::from_millis(200);
        let uri: Uri = format!("http://{server_addr}").parse().unwrap();
        let mut client = GrpcStreamingClient::new(uri, short_timeout, short_timeout);

        let started = Instant::now();
        let outcome = client.connect().await;
        let elapsed = started.elapsed();

        assert!(outcome.is_err(), "expected timeout error");
        // Generous upper bound: only guards against an unbounded wait.
        assert!(
            elapsed < Duration::from_secs(5),
            "connect() took {elapsed:?}, should have timed out in ~200ms"
        );
    }
}
212
#[cfg(test)]
pub mod test_utils {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Poll;
    use std::time::Duration;
    use std::time::Instant;

    use futures::Stream;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// One scripted step of the mock stream.
    enum StreamAction {
        /// Yield a checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield a streaming error.
        Error,
        /// Yield nothing until `duration` has passed since the action was first polled.
        Timeout {
            /// Set on first poll; `None` means the stall has not started yet.
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// Stream that plays back the queue of actions shared with its [`MockStreamingClient`].
    struct MockStreamState {
        /// Shared with the client, so actions inserted after `connect` are still observed and
        /// unconsumed actions survive a reconnect.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Timer that wakes the polling task when the current stall is due to end.
        wake_timer: Option<Pin<Box<tokio::time::Sleep>>>,
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        /// Executes the action at the front of the queue.
        ///
        /// Checkpoints and errors are consumed and yielded immediately. A timeout action keeps
        /// the stream pending until its deadline, after which it is discarded and the next
        /// action is processed. An empty queue ends the stream.
        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // `MockStreamState` is `Unpin`, so plain mutable access is fine.
            let this = self.get_mut();
            let mut actions = this.actions.lock().unwrap();

            loop {
                let Some(action) = actions.front_mut() else {
                    return Poll::Ready(None);
                };

                match action {
                    StreamAction::Checkpoint(seq) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    StreamAction::Error => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    StreamAction::Timeout { deadline, duration } => {
                        // The stall starts the first time this action is polled.
                        let due = *deadline.get_or_insert_with(|| Instant::now() + *duration);
                        let now = Instant::now();

                        if now >= due {
                            // Stall is over: drop it and move on to the next action.
                            actions.pop_front();
                            this.wake_timer = None;
                            continue;
                        }

                        // Returning `Pending` without arranging a wake-up would break the
                        // `Stream` contract and leave the task asleep until something unrelated
                        // polls it again, so register the waker with a timer for the deadline.
                        let mut timer = Box::pin(tokio::time::sleep(due - now));
                        match std::future::Future::poll(timer.as_mut(), cx) {
                            Poll::Pending => this.wake_timer = Some(timer),
                            Poll::Ready(()) => {
                                // The timer fired before the wall-clock deadline was observed;
                                // ask to be polled again so the deadline is re-checked.
                                this.wake_timer = None;
                                cx.waker().wake_by_ref();
                            }
                        }
                        return Poll::Pending;
                    }
                }
            }
        }
    }

    /// Mock streaming client for testing with predefined checkpoints.
    pub struct MockStreamingClient {
        /// Scripted actions, consumed front to back by the streams this client hands out.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Number of upcoming `connect` calls that fail immediately.
        connection_failures_remaining: usize,
        /// Number of upcoming `connect` calls that hang for `timeout_duration` and then fail.
        connection_timeouts_remaining: usize,
        /// How long mock timeout actions hang (must be > statement_timeout for timeouts to fire).
        timeout_duration: Duration,
        /// Statement timeout applied to the stream wrapper.
        statement_timeout: Duration,
    }

    impl MockStreamingClient {
        /// Chain identifier reported by every mock connection.
        pub fn mock_chain_id() -> ChainIdentifier {
            CheckpointDigest::new([1; 32]).into()
        }

        /// Creates a client that streams the given checkpoint sequence numbers in order.
        ///
        /// `timeout_duration` defaults to 5 seconds; the statement timeout applied to returned
        /// streams is half of it, so a default-length timeout action always outlasts it.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(5));
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                statement_timeout: timeout_duration / 2,
                timeout_duration,
            }
        }

        /// Make `connect` fail for the next N calls
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Make `connect` timeout for the next N calls
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Insert an error at the back of the queue.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Insert a timeout at the back of the queue (causes poll_next to return Pending).
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Insert a timeout with custom duration.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions
                .lock()
                .unwrap()
                .push_back(StreamAction::Timeout {
                    deadline: None,
                    duration,
                });
        }

        /// Insert a checkpoint at the back of the queue.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Insert a run of checkpoints at the back of the queue, in iteration order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Returns a stream over the shared action queue, unless a scripted connection timeout
        /// or failure is still pending; timeouts are consumed before plain failures.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                // Simulate a connection timeout
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream_state = MockStreamState {
                actions: Arc::clone(&self.actions),
                wake_timer: None,
            };
            Ok(CheckpointStream {
                stream: wrap_stream(stream_state, self.statement_timeout),
                chain_id: Self::mock_chain_id(),
            })
        }
    }
}