sui_indexer_alt_framework/ingestion/streaming_client.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

4use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::Context;
8use async_trait::async_trait;
9use futures::Stream;
10use futures::StreamExt;
11use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
12use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
13use sui_rpc_api::client::checkpoint_data_field_mask;
14use tonic::Status;
15use tonic::transport::Endpoint;
16use tonic::transport::Uri;
17
18use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
19use crate::ingestion::error::Error;
20use crate::ingestion::error::Result;
21use crate::types::full_checkpoint_content::Checkpoint;
22
23/// Type alias for a stream of checkpoint data.
24pub type CheckpointStream = Pin<Box<dyn Stream<Item = Result<Checkpoint>> + Send>>;
25
26/// Trait representing a client for streaming checkpoint data.
27#[async_trait]
28pub trait CheckpointStreamingClient {
29    async fn connect(&mut self) -> Result<CheckpointStream>;
30}
31
32#[derive(clap::Args, Clone, Debug, Default)]
33pub struct StreamingClientArgs {
34    /// gRPC endpoint for streaming checkpoints
35    #[clap(long, env)]
36    pub streaming_url: Option<Uri>,
37}
38
39/// gRPC-based implementation of the CheckpointStreamingClient trait.
40pub struct GrpcStreamingClient {
41    uri: Uri,
42    connection_timeout: Duration,
43}
44
45impl GrpcStreamingClient {
46    pub fn new(uri: Uri, connection_timeout: Duration) -> Self {
47        Self {
48            uri,
49            connection_timeout,
50        }
51    }
52}
53
54#[async_trait]
55impl CheckpointStreamingClient for GrpcStreamingClient {
56    async fn connect(&mut self) -> Result<CheckpointStream> {
57        let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
58
59        let mut client = SubscriptionServiceClient::connect(endpoint)
60            .await
61            .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
62            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
63
64        let mut request = SubscribeCheckpointsRequest::default();
65        request.read_mask = Some(checkpoint_data_field_mask());
66
67        let stream = client
68            .subscribe_checkpoints(request)
69            .await
70            .map_err(Error::RpcClientError)?
71            .into_inner();
72
73        let converted_stream = stream.map(|result| match result {
74            Ok(response) => response
75                .checkpoint
76                .context("Checkpoint data missing in response")
77                .and_then(|checkpoint| {
78                    sui_types::full_checkpoint_content::Checkpoint::try_from(&checkpoint)
79                        .context("Failed to parse checkpoint")
80                })
81                .map_err(Error::StreamingError),
82            Err(e) => Err(Error::RpcClientError(e)),
83        });
84
85        Ok(Box::pin(converted_stream))
86    }
87}
88
/// Scriptable in-memory [`CheckpointStreamingClient`] for tests.
#[cfg(test)]
pub mod test_utils {
    use std::collections::VecDeque;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Poll;
    use std::time::Duration;
    use std::time::Instant;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// A single scripted step performed by [`MockStreamState`] when it is polled.
    enum StreamAction {
        /// Yield a test checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield an `Error::StreamingError`.
        Error,
        /// Return `Poll::Pending` until `duration` has passed. `deadline` starts as `None` and is
        /// armed (to now + `duration`) by the first poll that reaches this action.
        Timeout {
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// The stream handed out by [`MockStreamingClient::connect`]. It consumes actions from the
    /// queue it shares with the client, front to back.
    struct MockStreamState {
        // A `VecDeque` so consuming the front action is O(1), unlike the O(n) shift performed
        // by `Vec::remove(0)`.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        /// Performs the action at the front of the queue, or ends the stream if the queue is
        /// empty.
        ///
        /// NOTE(review): no waker is registered before returning `Poll::Pending` (`_cx` is unused).
        /// A stalled stream therefore only makes progress if something else polls it again. That
        /// can be a surrounding `tokio::time::timeout` firing, or a new stream from `connect`
        /// that sees the same, already-armed deadline in the shared queue.
        fn poll_next(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let mut actions = self.actions.lock().unwrap();

            // Loop instead of recursing, so consecutive expired timeouts are skipped under a
            // single lock acquisition.
            loop {
                let Some(action) = actions.front_mut() else {
                    return Poll::Ready(None);
                };

                match action {
                    StreamAction::Checkpoint(seq) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    StreamAction::Error => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    StreamAction::Timeout { deadline, duration } => match *deadline {
                        None => {
                            // First poll to reach this timeout: arm it and stall.
                            *deadline = Some(Instant::now() + *duration);
                            return Poll::Pending;
                        }
                        Some(expires_at) if Instant::now() < expires_at => return Poll::Pending,
                        Some(_) => {
                            // Expired: discard it and continue with the next action.
                            actions.pop_front();
                        }
                    },
                }
            }
        }
    }

    /// Mock streaming client for testing with predefined checkpoints.
    ///
    /// All streams returned by `connect` share this client's action queue. An action consumed
    /// by one stream is therefore not replayed to a stream opened after a reconnect.
    pub struct MockStreamingClient {
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Number of upcoming `connect` calls that fail immediately.
        connection_failures_remaining: usize,
        /// Number of upcoming `connect` calls that sleep for `timeout_duration` and then fail.
        connection_timeouts_remaining: usize,
        /// Delay used for simulated connection timeouts and by `insert_timeout`.
        timeout_duration: Duration,
    }

    impl MockStreamingClient {
        /// Creates a client whose queue initially yields one checkpoint per item of
        /// `checkpoint_range`, in order. `timeout_duration` defaults to 5 seconds.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                timeout_duration: timeout_duration.unwrap_or(Duration::from_secs(5)),
            }
        }

        /// Make `connect` fail for the next N calls
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Make `connect` time out for the next N calls. Each such call sleeps for the
        /// configured timeout duration and then returns an error. These are consumed before
        /// any failures set via `fail_connection_times`.
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Insert an error at the back of the queue.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Insert a timeout at the back of the queue, using the client's configured timeout
        /// duration. Polling the stream returns `Pending` until the timeout expires.
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Insert a timeout with custom duration.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions.lock().unwrap().push_back(StreamAction::Timeout {
                deadline: None,
                duration,
            });
        }

        /// Insert a checkpoint at the back of the queue.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Insert one checkpoint per item of `checkpoint_range` at the back of the queue, in
        /// order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Consumes any scripted connection timeouts first, then scripted failures. Otherwise
        /// it returns a stream over the shared action queue.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                // Simulate a connection timeout
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream = MockStreamState {
                actions: Arc::clone(&self.actions),
            };

            Ok(Box::pin(stream))
        }
    }
}