sui_indexer_alt_framework/ingestion/
streaming_client.rs1use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::Context;
8use async_trait::async_trait;
9use futures::Stream;
10use futures::StreamExt;
11use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
12use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
13use sui_rpc_api::client::checkpoint_data_field_mask;
14use tonic::Status;
15use tonic::transport::Endpoint;
16use tonic::transport::Uri;
17
18use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
19use crate::ingestion::error::Error;
20use crate::ingestion::error::Result;
21use crate::types::full_checkpoint_content::Checkpoint;
22
23pub type CheckpointStream = Pin<Box<dyn Stream<Item = Result<Checkpoint>> + Send>>;
25
26#[async_trait]
28pub trait CheckpointStreamingClient {
29 async fn connect(&mut self) -> Result<CheckpointStream>;
30}
31
32#[derive(clap::Args, Clone, Debug, Default)]
33pub struct StreamingClientArgs {
34 #[clap(long, env)]
36 pub streaming_url: Option<Uri>,
37}
38
39pub struct GrpcStreamingClient {
41 uri: Uri,
42 connection_timeout: Duration,
43}
44
45impl GrpcStreamingClient {
46 pub fn new(uri: Uri, connection_timeout: Duration) -> Self {
47 Self {
48 uri,
49 connection_timeout,
50 }
51 }
52}
53
54#[async_trait]
55impl CheckpointStreamingClient for GrpcStreamingClient {
56 async fn connect(&mut self) -> Result<CheckpointStream> {
57 let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
58
59 let mut client = SubscriptionServiceClient::connect(endpoint)
60 .await
61 .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
62 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
63
64 let mut request = SubscribeCheckpointsRequest::default();
65 request.read_mask = Some(checkpoint_data_field_mask());
66
67 let stream = client
68 .subscribe_checkpoints(request)
69 .await
70 .map_err(Error::RpcClientError)?
71 .into_inner();
72
73 let converted_stream = stream.map(|result| match result {
74 Ok(response) => response
75 .checkpoint
76 .context("Checkpoint data missing in response")
77 .and_then(|checkpoint| {
78 sui_types::full_checkpoint_content::Checkpoint::try_from(&checkpoint)
79 .context("Failed to parse checkpoint")
80 })
81 .map_err(Error::StreamingError),
82 Err(e) => Err(Error::RpcClientError(e)),
83 });
84
85 Ok(Box::pin(converted_stream))
86 }
87}
88
#[cfg(test)]
pub mod test_utils {
    //! Test double for [`CheckpointStreamingClient`] that replays a scripted queue of
    //! checkpoints, errors and stalls.

    use std::collections::VecDeque;
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Poll;
    use std::time::Duration;

    use tokio::time::Instant;
    use tokio::time::Sleep;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// One scripted step replayed by [`MockStreamState`].
    enum StreamAction {
        /// Yield a test checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield an `Error::StreamingError`.
        Error,
        /// Stall (return `Pending`) until `deadline`. The deadline is set to `now + duration` by
        /// the first poll that reaches this step.
        Timeout {
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// Stream handed out by [`MockStreamingClient`]'s `connect`.
    ///
    /// The action queue is shared with the client and with every other stream it handed out,
    /// so a step consumed here is gone for later connections.
    struct MockStreamState {
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Timer that wakes this task when the current stall's deadline passes.
        timer: Option<Pin<Box<Sleep>>>,
    }

    /// Polls a timer for `deadline`, registering `cx`'s waker with it.
    ///
    /// The timer is (re)created when `timer` is empty or armed for a different deadline.
    fn poll_timer(
        timer: &mut Option<Pin<Box<Sleep>>>,
        deadline: Instant,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<()> {
        if timer.as_ref().is_some_and(|t| t.deadline() != deadline) {
            // Left over from an earlier stall; replace it.
            *timer = None;
        }
        timer
            .get_or_insert_with(|| Box::pin(tokio::time::sleep_until(deadline)))
            .as_mut()
            .poll(cx)
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // Every field is `Unpin`, so the pinned reference can be unwrapped.
            let this = self.get_mut();
            let mut actions = this.actions.lock().unwrap();

            while let Some(action) = actions.front_mut() {
                match action {
                    StreamAction::Checkpoint(seq) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    StreamAction::Error => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    StreamAction::Timeout {
                        deadline: slot,
                        duration,
                    } => {
                        // The deadline lives in the shared queue, so it survives reconnects.
                        let deadline = *slot.get_or_insert_with(|| Instant::now() + *duration);
                        // Returning `Pending` requires arranging a wakeup; the timer registers
                        // the waker for the deadline instead of relying on the caller's timers.
                        let stalled = Instant::now() < deadline
                            && poll_timer(&mut this.timer, deadline, cx).is_pending();
                        if stalled {
                            return Poll::Pending;
                        }
                        // The stall is over: drop it and continue with the next step.
                        actions.pop_front();
                        this.timer = None;
                    }
                }
            }

            Poll::Ready(None)
        }
    }

    /// Scripted [`CheckpointStreamingClient`] for tests.
    ///
    /// Streams returned by `connect` replay the queued actions in insertion order. Connection
    /// attempts can be made to fail immediately ([`Self::fail_connection_times`]) or after
    /// sleeping for the timeout duration ([`Self::fail_connection_with_timeout`]).
    pub struct MockStreamingClient {
        /// Actions shared with every stream handed out by `connect`.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        connection_failures_remaining: usize,
        connection_timeouts_remaining: usize,
        /// Used for simulated connection timeouts and by [`Self::insert_timeout`].
        timeout_duration: Duration,
    }

    impl MockStreamingClient {
        /// Creates a client whose streams yield one checkpoint per sequence number in
        /// `checkpoint_range`, in order.
        ///
        /// `timeout_duration` defaults to 5 seconds when `None`.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                timeout_duration: timeout_duration.unwrap_or(Duration::from_secs(5)),
            }
        }

        /// Makes the next `times` calls to `connect` fail immediately. These failures are
        /// consumed only after any pending simulated connection timeouts.
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Makes the next `times` calls to `connect` sleep for the timeout duration, then fail.
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Queues a mock streaming error.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Queues a stall lasting the client's timeout duration.
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Queues a stall lasting `duration`, measured from the first poll that reaches it.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions.lock().unwrap().push_back(StreamAction::Timeout {
                deadline: None,
                duration,
            });
        }

        /// Queues a single checkpoint.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Queues one checkpoint per sequence number in `checkpoint_range`, in order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                // Simulate a connection attempt that hangs before failing.
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }

            Ok(Box::pin(MockStreamState {
                actions: Arc::clone(&self.actions),
                timer: None,
            }))
        }
    }
}