// sui_indexer_alt_framework/ingestion/streaming_client.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

4use anyhow::Context;
5use async_trait::async_trait;
6use futures::{Stream, StreamExt};
7use std::pin::Pin;
8use sui_rpc::proto::sui::rpc::v2::{
9    SubscribeCheckpointsRequest, subscription_service_client::SubscriptionServiceClient,
10};
11use sui_rpc_api::client::checkpoint_data_field_mask;
12use tonic::{Status, transport::Uri};
13
14use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
15use crate::ingestion::error::{Error, Result};
16use crate::types::full_checkpoint_content::Checkpoint;
17
18/// Type alias for a stream of checkpoint data.
19pub type CheckpointStream = Pin<Box<dyn Stream<Item = Result<Checkpoint>> + Send>>;
20
21/// Trait representing a client for streaming checkpoint data.
22#[async_trait]
23pub trait CheckpointStreamingClient {
24    async fn connect(&mut self) -> Result<CheckpointStream>;
25}
26
// Command-line / environment configuration for checkpoint streaming.
// Plain `//` comments are used deliberately: clap turns `///` doc comments into `--help`
// text, so the only user-facing description here is the field doc below.
#[derive(clap::Args, Clone, Debug, Default)]
pub struct StreamingClientArgs {
    /// gRPC endpoint for streaming checkpoints
    // `long` exposes this as a flag and `env` also lets it come from an environment variable
    // (by clap's default naming, presumably `--streaming-url` / `STREAMING_URL` — confirm).
    // Optional: the derived `Default` leaves it as `None`.
    #[clap(long, env)]
    pub streaming_url: Option<Uri>,
}
33
34/// gRPC-based implementation of the CheckpointStreamingClient trait.
35pub struct GrpcStreamingClient {
36    uri: Uri,
37}
38
39impl GrpcStreamingClient {
40    pub fn new(uri: Uri) -> Self {
41        Self { uri }
42    }
43}
44
45#[async_trait]
46impl CheckpointStreamingClient for GrpcStreamingClient {
47    async fn connect(&mut self) -> Result<CheckpointStream> {
48        let mut client = SubscriptionServiceClient::connect(self.uri.clone())
49            .await
50            .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
51            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
52
53        let mut request = SubscribeCheckpointsRequest::default();
54        request.read_mask = Some(checkpoint_data_field_mask());
55
56        let stream = client
57            .subscribe_checkpoints(request)
58            .await
59            .map_err(Error::RpcClientError)?
60            .into_inner();
61
62        let converted_stream = stream.map(|result| match result {
63            Ok(response) => response
64                .checkpoint
65                .context("Checkpoint data missing in response")
66                .and_then(|checkpoint| {
67                    sui_types::full_checkpoint_content::Checkpoint::try_from(&checkpoint)
68                        .context("Failed to parse checkpoint")
69                })
70                .map_err(Error::StreamingError),
71            Err(e) => Err(Error::RpcClientError(e)),
72        });
73
74        Ok(Box::pin(converted_stream))
75    }
76}
77
#[cfg(test)]
pub mod test_utils {
    //! In-memory test doubles for [`CheckpointStreamingClient`].

    use super::*;
    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Scripted stream items, shared between a [`MockStreamingClient`] and every stream it
    /// hands out. `Ok(seq)` is turned into a test checkpoint built from `seq`; `Err(e)` is
    /// yielded to the consumer as-is.
    ///
    /// A `VecDeque` makes popping the front O(1); `Vec::remove(0)` shifts every remaining item.
    type CheckpointQueue = Arc<Mutex<VecDeque<Result<u64>>>>;

    /// Stream returned by [`MockStreamingClient::connect`]; drains the shared queue front-first.
    struct MockStreamState {
        checkpoints: CheckpointQueue,
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        // Never returns `Pending`: the next queued item is yielded immediately, and an empty
        // queue ends the stream.
        fn poll_next(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            // Pop in its own statement so the lock is released before the checkpoint is built.
            let next = self.checkpoints.lock().unwrap().pop_front();
            std::task::Poll::Ready(next.map(|item| {
                item.map(|seq| TestCheckpointBuilder::new(seq).build_checkpoint())
            }))
        }
    }

    /// Mock streaming client for testing with predefined checkpoints.
    ///
    /// Every stream returned by `connect` reads from the same queue, so items consumed through
    /// one connection are not replayed on the next, and items inserted later are visible to any
    /// stream that is still being polled.
    pub struct MockStreamingClient {
        checkpoints: CheckpointQueue,
        /// Number of upcoming `connect` calls that fail before one succeeds.
        connection_failures_remaining: usize,
    }

    impl MockStreamingClient {
        /// Create a client whose queue initially holds one checkpoint per sequence number in
        /// `checkpoint_range`, in iteration order.
        pub fn new<I>(checkpoint_range: I) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            Self {
                checkpoints: Arc::new(Mutex::new(checkpoint_range.into_iter().map(Ok).collect())),
                connection_failures_remaining: 0,
            }
        }

        /// Make `connect` fail for the next N calls
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Insert an error at the back of the queue.
        pub fn insert_error(&mut self) {
            self.checkpoints
                .lock()
                .unwrap()
                .push_back(Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock streaming error"
                ))));
        }

        /// Insert a checkpoint at the back of the queue.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Append one checkpoint per sequence number in `checkpoint_range` to the back of the
        /// queue, in iteration order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.checkpoints
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(Ok));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Fails with a mock [`Error::StreamingError`] while injected connection failures
        /// remain (consuming one per call); otherwise returns a stream over the shared queue.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream = MockStreamState {
                checkpoints: Arc::clone(&self.checkpoints),
            };

            Ok(Box::pin(stream))
        }
    }
}