sui_indexer_alt_framework/ingestion/streaming_client.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::Context;
8use async_trait::async_trait;
9use futures::Stream;
10use futures::StreamExt;
11use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
12use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
13use tonic::Status;
14use tonic::transport::Endpoint;
15use tonic::transport::Uri;
16
17use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
18use crate::ingestion::error::Error;
19use crate::ingestion::error::Result;
20use crate::types::full_checkpoint_content::Checkpoint;
21
22/// Type alias for a stream of checkpoint data.
23pub type CheckpointStream = Pin<Box<dyn Stream<Item = Result<Checkpoint>> + Send>>;
24
25/// Trait representing a client for streaming checkpoint data.
26#[async_trait]
27pub trait CheckpointStreamingClient {
28    async fn connect(&mut self) -> Result<CheckpointStream>;
29}
30
31#[derive(clap::Args, Clone, Debug, Default)]
32pub struct StreamingClientArgs {
33    /// gRPC endpoint for streaming checkpoints
34    #[clap(long, env)]
35    pub streaming_url: Option<Uri>,
36}
37
38/// gRPC-based implementation of the CheckpointStreamingClient trait.
39pub struct GrpcStreamingClient {
40    uri: Uri,
41    connection_timeout: Duration,
42}
43
44impl GrpcStreamingClient {
45    pub fn new(uri: Uri, connection_timeout: Duration) -> Self {
46        Self {
47            uri,
48            connection_timeout,
49        }
50    }
51}
52
53#[async_trait]
54impl CheckpointStreamingClient for GrpcStreamingClient {
55    async fn connect(&mut self) -> Result<CheckpointStream> {
56        let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
57
58        let mut client = SubscriptionServiceClient::connect(endpoint)
59            .await
60            .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
61            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
62
63        let mut request = SubscribeCheckpointsRequest::default();
64        request.read_mask = Some(Checkpoint::proto_field_mask());
65
66        let stream = client
67            .subscribe_checkpoints(request)
68            .await
69            .map_err(Error::RpcClientError)?
70            .into_inner();
71
72        let converted_stream = stream.map(|result| match result {
73            Ok(response) => response
74                .checkpoint
75                .context("Checkpoint data missing in response")
76                .and_then(|checkpoint| {
77                    Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
78                })
79                .map_err(Error::StreamingError),
80            Err(e) => Err(Error::RpcClientError(e)),
81        });
82
83        Ok(Box::pin(converted_stream))
84    }
85}
86
#[cfg(test)]
pub mod test_utils {
    use std::collections::VecDeque;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::time::Duration;
    use std::time::Instant;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// One scripted step replayed by [`MockStreamState`].
    enum StreamAction {
        /// Yield a synthetic checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield a `StreamingError`.
        Error,
        /// Stall (return `Poll::Pending`) until `duration` has elapsed since the first poll that
        /// reached this action.
        Timeout {
            /// `None` until the first poll reaches this action, which arms it.
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// Stream handed out by [`MockStreamingClient::connect`].
    ///
    /// It pops actions from the queue it shares with the client, so actions consumed through one
    /// connection are not replayed by the next.
    struct MockStreamState {
        // Queue (front = next action); `VecDeque` keeps popping the front O(1).
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        /// Replays the next scripted action.
        ///
        /// NOTE: `Pending` is returned without registering the waker from `_cx`. A stalled
        /// stream is therefore only polled again if something else wakes the task, for example
        /// an enclosing timeout firing and the caller polling again or reconnecting.
        fn poll_next(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            use std::task::Poll;

            let mut actions = self.actions.lock().unwrap();

            // Loop (instead of re-locking and recursing) so that any run of expired timeouts is
            // skipped under a single lock acquisition.
            loop {
                let Some(action) = actions.front_mut() else {
                    return Poll::Ready(None);
                };

                match action {
                    StreamAction::Checkpoint(seq) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    StreamAction::Error => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    StreamAction::Timeout { deadline, duration } => match *deadline {
                        // First poll reaching this action: arm the deadline and always stall,
                        // even for a zero duration.
                        None => {
                            *deadline = Some(Instant::now() + *duration);
                            return Poll::Pending;
                        }
                        // Deadline not reached yet: keep stalling.
                        Some(expires_at) if Instant::now() < expires_at => return Poll::Pending,
                        // Deadline passed: discard the timeout and move on to the next action.
                        Some(_) => {
                            actions.pop_front();
                        }
                    },
                }
            }
        }
    }

    /// Mock streaming client for testing with predefined checkpoints.
    ///
    /// All streams returned by `connect` drain one shared queue of scripted actions.
    pub struct MockStreamingClient {
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Upcoming `connect` calls that fail immediately.
        connection_failures_remaining: usize,
        /// Upcoming `connect` calls that sleep for `timeout_duration` and then fail.
        /// These are checked before `connection_failures_remaining`.
        connection_timeouts_remaining: usize,
        /// Delay used by simulated connection timeouts and by `insert_timeout`.
        timeout_duration: Duration,
    }

    impl MockStreamingClient {
        /// Creates a client whose queue initially yields one checkpoint per sequence number in
        /// `checkpoint_range`, in iteration order. `timeout_duration` defaults to 5 seconds.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                timeout_duration: timeout_duration.unwrap_or(Duration::from_secs(5)),
            }
        }

        /// Make `connect` fail for the next N calls
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Make `connect` timeout for the next N calls
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Insert an error at the back of the queue.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Insert a timeout at the back of the queue (causes poll_next to return Pending).
        ///
        /// Uses the client's `timeout_duration`.
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Insert a timeout with custom duration.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions.lock().unwrap().push_back(StreamAction::Timeout {
                deadline: None,
                duration,
            });
        }

        /// Insert a checkpoint at the back of the queue.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Append one checkpoint per sequence number in `checkpoint_range`, in iteration order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Simulates, in priority order, a connection timeout, then an immediate connection
        /// failure, and otherwise returns a stream over the shared action queue.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                // Simulate a connection timeout
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream = MockStreamState {
                actions: Arc::clone(&self.actions),
            };

            Ok(Box::pin(stream))
        }
    }
}