sui_indexer_alt_framework/ingestion/
streaming_client.rs1use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::Context;
8use async_trait::async_trait;
9use futures::Stream;
10use futures::StreamExt;
11use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
12use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
13use tonic::Status;
14use tonic::transport::Endpoint;
15use tonic::transport::Uri;
16
17use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
18use crate::ingestion::error::Error;
19use crate::ingestion::error::Result;
20use crate::types::full_checkpoint_content::Checkpoint;
21
22pub type CheckpointStream = Pin<Box<dyn Stream<Item = Result<Checkpoint>> + Send>>;
24
25#[async_trait]
27pub trait CheckpointStreamingClient {
28 async fn connect(&mut self) -> Result<CheckpointStream>;
29}
30
31#[derive(clap::Args, Clone, Debug, Default)]
32pub struct StreamingClientArgs {
33 #[clap(long, env)]
35 pub streaming_url: Option<Uri>,
36}
37
38pub struct GrpcStreamingClient {
40 uri: Uri,
41 connection_timeout: Duration,
42}
43
44impl GrpcStreamingClient {
45 pub fn new(uri: Uri, connection_timeout: Duration) -> Self {
46 Self {
47 uri,
48 connection_timeout,
49 }
50 }
51}
52
53#[async_trait]
54impl CheckpointStreamingClient for GrpcStreamingClient {
55 async fn connect(&mut self) -> Result<CheckpointStream> {
56 let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
57
58 let mut client = SubscriptionServiceClient::connect(endpoint)
59 .await
60 .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
61 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
62
63 let mut request = SubscribeCheckpointsRequest::default();
64 request.read_mask = Some(Checkpoint::proto_field_mask());
65
66 let stream = client
67 .subscribe_checkpoints(request)
68 .await
69 .map_err(Error::RpcClientError)?
70 .into_inner();
71
72 let converted_stream = stream.map(|result| match result {
73 Ok(response) => response
74 .checkpoint
75 .context("Checkpoint data missing in response")
76 .and_then(|checkpoint| {
77 Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
78 })
79 .map_err(Error::StreamingError),
80 Err(e) => Err(Error::RpcClientError(e)),
81 });
82
83 Ok(Box::pin(converted_stream))
84 }
85}
86
87#[cfg(test)]
88pub mod test_utils {
89 use std::sync::Arc;
90 use std::sync::Mutex;
91 use std::time::Duration;
92 use std::time::Instant;
93
94 use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;
95
96 use super::*;
97
/// One scripted step that the mock stream carries out when it is polled.
enum StreamAction {
    /// Yield a checkpoint built for this sequence number.
    Checkpoint(u64),
    /// Yield a single mock streaming error.
    Error,
    /// Report `Pending` until `duration` has elapsed, measured from the first poll that
    /// reaches this action.
    Timeout {
        /// When the stall ends; `None` until the action has been polled once.
        deadline: Option<Instant>,
        /// Length of the stall.
        duration: Duration,
    },
}
106
107 struct MockStreamState {
108 actions: Arc<Mutex<Vec<StreamAction>>>,
109 }
110
111 impl Stream for MockStreamState {
112 type Item = Result<Checkpoint>;
113
114 fn poll_next(
115 self: Pin<&mut Self>,
116 _cx: &mut std::task::Context<'_>,
117 ) -> std::task::Poll<Option<Self::Item>> {
118 let mut actions = self.actions.lock().unwrap();
119 if actions.is_empty() {
120 return std::task::Poll::Ready(None);
121 }
122
123 match &actions[0] {
124 StreamAction::Checkpoint(seq) => {
125 let seq = *seq;
126 actions.remove(0);
127 let mut builder = TestCheckpointBuilder::new(seq);
128 std::task::Poll::Ready(Some(Ok(builder.build_checkpoint())))
129 }
130 StreamAction::Error => {
131 actions.remove(0);
132 std::task::Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
133 "Mock streaming error"
134 )))))
135 }
136 StreamAction::Timeout { deadline, duration } => match deadline {
137 None => {
138 let deadline = Instant::now() + *duration;
139 actions[0] = StreamAction::Timeout {
140 deadline: Some(deadline),
141 duration: *duration,
142 };
143 std::task::Poll::Pending
144 }
145 Some(deadline_instant) => {
146 if Instant::now() >= *deadline_instant {
147 actions.remove(0);
148 drop(actions);
149 self.poll_next(_cx)
150 } else {
151 std::task::Poll::Pending
152 }
153 }
154 },
155 }
156 }
157 }
158
159 pub struct MockStreamingClient {
161 actions: Arc<Mutex<Vec<StreamAction>>>,
162 connection_failures_remaining: usize,
163 connection_timeouts_remaining: usize,
164 timeout_duration: Duration,
165 }
166
167 impl MockStreamingClient {
168 pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
169 where
170 I: IntoIterator<Item = u64>,
171 {
172 Self {
173 actions: Arc::new(Mutex::new(
174 checkpoint_range
175 .into_iter()
176 .map(StreamAction::Checkpoint)
177 .collect(),
178 )),
179 connection_failures_remaining: 0,
180 connection_timeouts_remaining: 0,
181 timeout_duration: timeout_duration.unwrap_or(Duration::from_secs(5)),
182 }
183 }
184
185 pub fn fail_connection_times(mut self, times: usize) -> Self {
187 self.connection_failures_remaining = times;
188 self
189 }
190
191 pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
193 self.connection_timeouts_remaining = times;
194 self
195 }
196
197 pub fn insert_error(&mut self) {
199 self.actions.lock().unwrap().push(StreamAction::Error);
200 }
201
202 pub fn insert_timeout(&mut self) {
204 self.insert_timeout_with_duration(self.timeout_duration)
205 }
206
207 pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
209 self.actions.lock().unwrap().push(StreamAction::Timeout {
210 deadline: None,
211 duration,
212 });
213 }
214
215 pub fn insert_checkpoint(&mut self, sequence_number: u64) {
217 self.insert_checkpoint_range([sequence_number])
218 }
219
220 pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
221 where
222 I: IntoIterator<Item = u64>,
223 {
224 let mut actions = self.actions.lock().unwrap();
225 for sequence_number in checkpoint_range {
226 actions.push(StreamAction::Checkpoint(sequence_number));
227 }
228 }
229 }
230
231 #[async_trait]
232 impl CheckpointStreamingClient for MockStreamingClient {
233 async fn connect(&mut self) -> Result<CheckpointStream> {
234 if self.connection_timeouts_remaining > 0 {
235 self.connection_timeouts_remaining -= 1;
236 tokio::time::sleep(self.timeout_duration).await;
238 return Err(Error::StreamingError(anyhow::anyhow!(
239 "Mock connection timeout"
240 )));
241 }
242 if self.connection_failures_remaining > 0 {
243 self.connection_failures_remaining -= 1;
244 return Err(Error::StreamingError(anyhow::anyhow!(
245 "Mock connection failure"
246 )));
247 }
248 let stream = MockStreamState {
249 actions: Arc::clone(&self.actions),
250 };
251
252 Ok(Box::pin(stream))
253 }
254 }
255}