sui_indexer_alt_framework/ingestion/
streaming_client.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::time::Duration;
5
6use anyhow::Context;
7use anyhow::anyhow;
8use async_trait::async_trait;
9use futures::StreamExt;
10use futures::stream::BoxStream;
11use sui_rpc::headers::X_SUI_CHAIN_ID;
12use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
13use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
14use sui_types::digests::ChainIdentifier;
15use sui_types::messages_checkpoint::CheckpointDigest;
16use tokio_stream::adapters::Peekable;
17use tonic::Status;
18use tonic::transport::Endpoint;
19use tonic::transport::Uri;
20
21use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
22use crate::ingestion::error::Error;
23use crate::ingestion::error::Result;
24use crate::types::full_checkpoint_content::Checkpoint;
25
26pub struct CheckpointStream {
27    pub stream: Peekable<BoxStream<'static, Result<Checkpoint>>>,
28    pub chain_id: ChainIdentifier,
29}
30
31/// Trait representing a client for streaming checkpoint data.
32#[async_trait]
33pub trait CheckpointStreamingClient {
34    /// Returns the CheckpointStream and chain id.
35    async fn connect(&mut self) -> Result<CheckpointStream>;
36}
37
// Command-line / environment arguments configuring the checkpoint streaming client.
// NOTE: the `///` text on the field below is consumed by the clap derive as help output, so it
// is part of the program's user-facing behavior rather than a plain comment.
#[derive(clap::Args, Clone, Debug, Default)]
pub struct StreamingClientArgs {
    /// gRPC endpoint for streaming checkpoints
    // Optional: `None` when no streaming URL was supplied via flag or environment.
    #[clap(long, env)]
    pub streaming_url: Option<Uri>,
}
44
45/// gRPC-based implementation of the CheckpointStreamingClient trait.
46#[derive(Clone)]
47pub struct GrpcStreamingClient {
48    uri: Uri,
49    connection_timeout: Duration,
50    statement_timeout: Duration,
51}
52
53impl GrpcStreamingClient {
54    pub fn new(uri: Uri, connection_timeout: Duration, statement_timeout: Duration) -> Self {
55        Self {
56            uri,
57            connection_timeout,
58            statement_timeout,
59        }
60    }
61}
62
63#[async_trait]
64impl CheckpointStreamingClient for GrpcStreamingClient {
65    async fn connect(&mut self) -> Result<CheckpointStream> {
66        let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
67
68        let mut client = SubscriptionServiceClient::connect(endpoint)
69            .await
70            .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
71            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
72
73        let mut request = SubscribeCheckpointsRequest::default();
74        request.read_mask = Some(Checkpoint::proto_field_mask());
75
76        let response = client
77            .subscribe_checkpoints(request)
78            .await
79            .map_err(Error::RpcClientError)?;
80
81        let chain_id_value = response.metadata().get(X_SUI_CHAIN_ID).ok_or_else(|| {
82            Error::StreamingError(anyhow!("Chain ID not found in response metadata"))
83        })?;
84        let chain_id: ChainIdentifier = chain_id_value
85            .to_str()
86            .map_err(|e| Error::StreamingError(anyhow!("Chain ID is not valid ASCII: {e}")))?
87            .parse::<CheckpointDigest>()
88            .map_err(|e| Error::StreamingError(anyhow!("Chain ID parse error: {e}")))?
89            .into();
90
91        let stream = response.into_inner().map(|result| match result {
92            Ok(response) => response
93                .checkpoint
94                .context("Checkpoint data missing in response")
95                .and_then(|checkpoint| {
96                    Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
97                })
98                .map_err(Error::StreamingError),
99            Err(e) => Err(Error::RpcClientError(e)),
100        });
101        let stream = wrap_stream(stream, self.statement_timeout);
102
103        Ok(CheckpointStream { stream, chain_id })
104    }
105}
106
107/// Wraps a stream with a per-item timeout. Converts the resulting `Err(Elapsed)` into
108/// `Err(StreamingError)` if it occurs.
109fn wrap_stream(
110    stream: impl futures::Stream<Item = Result<Checkpoint>> + Send + 'static,
111    statement_timeout: Duration,
112) -> Peekable<BoxStream<'static, Result<Checkpoint>>> {
113    let stream = tokio_stream::StreamExt::timeout(stream, statement_timeout)
114        .map(move |result| match result {
115            Err(_elapsed) => Err(Error::StreamingError(anyhow!(
116                "Statement timeout after {statement_timeout:?}"
117            ))),
118            Ok(result) => result,
119        })
120        .boxed();
121    tokio_stream::StreamExt::peekable(stream)
122}
123
124#[cfg(test)]
125pub mod test_utils {
126    use std::pin::Pin;
127    use std::sync::Arc;
128    use std::sync::Mutex;
129    use std::time::Duration;
130    use std::time::Instant;
131
132    use futures::Stream;
133
134    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;
135
136    use super::*;
137
138    enum StreamAction {
139        Checkpoint(u64),
140        Error,
141        Timeout {
142            deadline: Option<Instant>,
143            duration: Duration,
144        },
145    }
146
147    struct MockStreamState {
148        actions: Arc<Mutex<Vec<StreamAction>>>,
149    }
150
151    impl Stream for MockStreamState {
152        type Item = Result<Checkpoint>;
153
154        fn poll_next(
155            self: Pin<&mut Self>,
156            _cx: &mut std::task::Context<'_>,
157        ) -> std::task::Poll<Option<Self::Item>> {
158            let mut actions = self.actions.lock().unwrap();
159            if actions.is_empty() {
160                return std::task::Poll::Ready(None);
161            }
162
163            match &actions[0] {
164                StreamAction::Checkpoint(seq) => {
165                    let seq = *seq;
166                    actions.remove(0);
167                    let mut builder = TestCheckpointBuilder::new(seq);
168                    std::task::Poll::Ready(Some(Ok(builder.build_checkpoint())))
169                }
170                StreamAction::Error => {
171                    actions.remove(0);
172                    std::task::Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
173                        "Mock streaming error"
174                    )))))
175                }
176                StreamAction::Timeout { deadline, duration } => match deadline {
177                    None => {
178                        let deadline = Instant::now() + *duration;
179                        actions[0] = StreamAction::Timeout {
180                            deadline: Some(deadline),
181                            duration: *duration,
182                        };
183                        std::task::Poll::Pending
184                    }
185                    Some(deadline_instant) => {
186                        if Instant::now() >= *deadline_instant {
187                            actions.remove(0);
188                            drop(actions);
189                            self.poll_next(_cx)
190                        } else {
191                            std::task::Poll::Pending
192                        }
193                    }
194                },
195            }
196        }
197    }
198
199    /// Mock streaming client for testing with predefined checkpoints.
200    pub struct MockStreamingClient {
201        actions: Arc<Mutex<Vec<StreamAction>>>,
202        connection_failures_remaining: usize,
203        connection_timeouts_remaining: usize,
204        /// How long mock timeout actions hang (must be > statement_timeout for timeouts to fire).
205        timeout_duration: Duration,
206        /// Statement timeout applied to the stream wrapper.
207        statement_timeout: Duration,
208    }
209
210    impl MockStreamingClient {
211        pub fn mock_chain_id() -> ChainIdentifier {
212            CheckpointDigest::new([1; 32]).into()
213        }
214
215        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
216        where
217            I: IntoIterator<Item = u64>,
218        {
219            let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(5));
220            Self {
221                actions: Arc::new(Mutex::new(
222                    checkpoint_range
223                        .into_iter()
224                        .map(StreamAction::Checkpoint)
225                        .collect(),
226                )),
227                connection_failures_remaining: 0,
228                connection_timeouts_remaining: 0,
229                statement_timeout: timeout_duration / 2,
230                timeout_duration,
231            }
232        }
233
234        /// Make `connect` fail for the next N calls
235        pub fn fail_connection_times(mut self, times: usize) -> Self {
236            self.connection_failures_remaining = times;
237            self
238        }
239
240        /// Make `connect` timeout for the next N calls
241        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
242            self.connection_timeouts_remaining = times;
243            self
244        }
245
246        /// Insert an error at the back of the queue.
247        pub fn insert_error(&mut self) {
248            self.actions.lock().unwrap().push(StreamAction::Error);
249        }
250
251        /// Insert a timeout at the back of the queue (causes poll_next to return Pending).
252        pub fn insert_timeout(&mut self) {
253            self.insert_timeout_with_duration(self.timeout_duration)
254        }
255
256        /// Insert a timeout with custom duration.
257        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
258            self.actions.lock().unwrap().push(StreamAction::Timeout {
259                deadline: None,
260                duration,
261            });
262        }
263
264        /// Insert a checkpoint at the back of the queue.
265        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
266            self.insert_checkpoint_range([sequence_number])
267        }
268
269        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
270        where
271            I: IntoIterator<Item = u64>,
272        {
273            let mut actions = self.actions.lock().unwrap();
274            for sequence_number in checkpoint_range {
275                actions.push(StreamAction::Checkpoint(sequence_number));
276            }
277        }
278    }
279
280    #[async_trait]
281    impl CheckpointStreamingClient for MockStreamingClient {
282        async fn connect(&mut self) -> Result<CheckpointStream> {
283            if self.connection_timeouts_remaining > 0 {
284                self.connection_timeouts_remaining -= 1;
285                // Simulate a connection timeout
286                tokio::time::sleep(self.timeout_duration).await;
287                return Err(Error::StreamingError(anyhow::anyhow!(
288                    "Mock connection timeout"
289                )));
290            }
291            if self.connection_failures_remaining > 0 {
292                self.connection_failures_remaining -= 1;
293                return Err(Error::StreamingError(anyhow::anyhow!(
294                    "Mock connection failure"
295                )));
296            }
297            let stream_state = MockStreamState {
298                actions: Arc::clone(&self.actions),
299            };
300            Ok(CheckpointStream {
301                stream: wrap_stream(stream_state, self.statement_timeout),
302                chain_id: Self::mock_chain_id(),
303            })
304        }
305    }
306}