sui_indexer_alt_framework/ingestion/
streaming_client.rs1use std::time::Duration;
5
6use anyhow::Context;
7use anyhow::anyhow;
8use async_trait::async_trait;
9use futures::StreamExt;
10use futures::stream::BoxStream;
11use sui_rpc::headers::X_SUI_CHAIN_ID;
12use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
13use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
14use sui_types::digests::ChainIdentifier;
15use sui_types::messages_checkpoint::CheckpointDigest;
16use tokio_stream::adapters::Peekable;
17use tonic::Status;
18use tonic::transport::Endpoint;
19use tonic::transport::Uri;
20
21use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
22use crate::ingestion::error::Error;
23use crate::ingestion::error::Result;
24use crate::types::full_checkpoint_content::Checkpoint;
25
/// An open checkpoint stream, paired with the identifier of the chain it was opened against.
pub struct CheckpointStream {
    /// Stream of checkpoints (or errors). It is `Peekable` so a consumer can look at the next item
    /// without consuming it.
    pub stream: Peekable<BoxStream<'static, Result<Checkpoint>>>,
    /// Chain identifier reported by the source when the stream was opened.
    pub chain_id: ChainIdentifier,
}
30
#[async_trait]
/// A source of checkpoint streams.
pub trait CheckpointStreamingClient {
    /// Opens a new [`CheckpointStream`] and returns it together with the chain identifier.
    ///
    /// Takes `&mut self`, so implementations may keep state across repeated calls.
    ///
    /// # Errors
    ///
    /// Returns an error if a stream could not be opened.
    async fn connect(&mut self) -> Result<CheckpointStream>;
}
37
// Command-line / environment configuration for checkpoint streaming.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on derived
// `Args` items into `--help` text, and these notes are for maintainers only.
#[derive(clap::Args, Clone, Debug, Default)]
pub struct StreamingClientArgs {
    // Set via `--streaming-url` or the `STREAMING_URL` environment variable.
    // Optional; NOTE(review): presumably streaming is disabled when unset and the value is
    // handed to `GrpcStreamingClient::new` — confirm against callers.
    #[clap(long, env)]
    pub streaming_url: Option<Uri>,
}
44
/// [`CheckpointStreamingClient`] backed by the gRPC `SubscriptionService`.
#[derive(Clone)]
pub struct GrpcStreamingClient {
    /// Endpoint dialed on every call to `connect`.
    uri: Uri,
    /// Bound on establishing the connection, applied through `Endpoint::connect_timeout`.
    connection_timeout: Duration,
    /// Maximum wait for each item of an opened stream, enforced by `wrap_stream`.
    statement_timeout: Duration,
}
52
53impl GrpcStreamingClient {
54 pub fn new(uri: Uri, connection_timeout: Duration, statement_timeout: Duration) -> Self {
55 Self {
56 uri,
57 connection_timeout,
58 statement_timeout,
59 }
60 }
61}
62
63#[async_trait]
64impl CheckpointStreamingClient for GrpcStreamingClient {
65 async fn connect(&mut self) -> Result<CheckpointStream> {
66 let endpoint = Endpoint::from(self.uri.clone()).connect_timeout(self.connection_timeout);
67
68 let mut client = SubscriptionServiceClient::connect(endpoint)
69 .await
70 .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
71 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
72
73 let mut request = SubscribeCheckpointsRequest::default();
74 request.read_mask = Some(Checkpoint::proto_field_mask());
75
76 let response = client
77 .subscribe_checkpoints(request)
78 .await
79 .map_err(Error::RpcClientError)?;
80
81 let chain_id_value = response.metadata().get(X_SUI_CHAIN_ID).ok_or_else(|| {
82 Error::StreamingError(anyhow!("Chain ID not found in response metadata"))
83 })?;
84 let chain_id: ChainIdentifier = chain_id_value
85 .to_str()
86 .map_err(|e| Error::StreamingError(anyhow!("Chain ID is not valid ASCII: {e}")))?
87 .parse::<CheckpointDigest>()
88 .map_err(|e| Error::StreamingError(anyhow!("Chain ID parse error: {e}")))?
89 .into();
90
91 let stream = response.into_inner().map(|result| match result {
92 Ok(response) => response
93 .checkpoint
94 .context("Checkpoint data missing in response")
95 .and_then(|checkpoint| {
96 Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
97 })
98 .map_err(Error::StreamingError),
99 Err(e) => Err(Error::RpcClientError(e)),
100 });
101 let stream = wrap_stream(stream, self.statement_timeout);
102
103 Ok(CheckpointStream { stream, chain_id })
104 }
105}
106
107fn wrap_stream(
110 stream: impl futures::Stream<Item = Result<Checkpoint>> + Send + 'static,
111 statement_timeout: Duration,
112) -> Peekable<BoxStream<'static, Result<Checkpoint>>> {
113 let stream = tokio_stream::StreamExt::timeout(stream, statement_timeout)
114 .map(move |result| match result {
115 Err(_elapsed) => Err(Error::StreamingError(anyhow!(
116 "Statement timeout after {statement_timeout:?}"
117 ))),
118 Ok(result) => result,
119 })
120 .boxed();
121 tokio_stream::StreamExt::peekable(stream)
122}
123
#[cfg(test)]
pub mod test_utils {
    //! Scripted, in-memory [`CheckpointStreamingClient`] for tests.

    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Poll;
    use std::task::Waker;
    use std::time::Duration;
    use std::time::Instant;

    use futures::Stream;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// One scripted step of the mock stream; steps are consumed front to back.
    enum StreamAction {
        /// Yield a synthetic checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield a `StreamingError`.
        Error,
        /// Stay pending for `duration`. `deadline` is fixed the first time the step is polled.
        Timeout {
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// Stream over the shared action queue. Each `connect` creates a new instance, but all
    /// instances drain the same queue.
    struct MockStreamState {
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
    }

    /// Arranges for `waker` to be woken once `deadline` has passed.
    ///
    /// `Stream::poll_next` must not return `Pending` without scheduling a wake-up. Without this,
    /// the mock depends on the consumer's own timer to poll it again. For example, the
    /// statement-timeout adapter stops watching its deadline after reporting one elapsed
    /// timeout, which could leave the task hanging. A std thread is used so the wake-up follows
    /// the same wall clock as the std `Instant` deadlines, whatever the tokio clock is doing.
    fn wake_at(waker: &Waker, deadline: Instant) {
        let waker = waker.clone();
        std::thread::spawn(move || {
            std::thread::sleep(deadline.saturating_duration_since(Instant::now()));
            waker.wake();
        });
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let mut actions = self.actions.lock().unwrap();
            loop {
                let deadline = match actions.front_mut() {
                    None => return Poll::Ready(None),
                    Some(StreamAction::Checkpoint(seq)) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    Some(StreamAction::Error) => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    Some(StreamAction::Timeout { deadline: slot, duration }) => match *slot {
                        Some(at) => at,
                        None => {
                            // First poll of this step: start the clock and stay pending.
                            let expires = Instant::now() + *duration;
                            *slot = Some(expires);
                            wake_at(cx.waker(), expires);
                            return Poll::Pending;
                        }
                    },
                };

                if Instant::now() < deadline {
                    wake_at(cx.waker(), deadline);
                    return Poll::Pending;
                }

                // The timeout has run its course; continue with the next scripted step.
                actions.pop_front();
            }
        }
    }

    /// Test double for [`CheckpointStreamingClient`] driven by a queue of scripted actions.
    pub struct MockStreamingClient {
        /// Scripted steps, shared by every stream handed out by `connect`.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Number of upcoming `connect` calls that fail immediately.
        connection_failures_remaining: usize,
        /// Number of upcoming `connect` calls that sleep for `timeout_duration` and then fail.
        /// These are checked before `connection_failures_remaining`.
        connection_timeouts_remaining: usize,
        /// Sleep used for simulated connection timeouts, and the duration used by
        /// `insert_timeout`.
        timeout_duration: Duration,
        /// Per-item timeout passed to `wrap_stream`; set to half of `timeout_duration`.
        statement_timeout: Duration,
    }

    impl MockStreamingClient {
        /// Chain identifier reported by every mock stream.
        pub fn mock_chain_id() -> ChainIdentifier {
            CheckpointDigest::new([1; 32]).into()
        }

        /// Creates a client that will stream `checkpoint_range` in order.
        ///
        /// `timeout_duration` defaults to 5s. The statement timeout is half of it.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(5));
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                statement_timeout: timeout_duration / 2,
                timeout_duration,
            }
        }

        /// Makes the next `times` calls to `connect` fail immediately.
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Makes the next `times` calls to `connect` sleep for the timeout duration, then fail.
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Appends a step that yields a streaming error.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Appends a pause lasting the client's timeout duration.
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Appends a pause of `duration`.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions.lock().unwrap().push_back(StreamAction::Timeout {
                deadline: None,
                duration,
            });
        }

        /// Appends a single checkpoint.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Appends checkpoints in iteration order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Consumes one scripted connection timeout or failure if any remain. Otherwise returns
        /// a stream over the shared action queue.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream_state = MockStreamState {
                actions: Arc::clone(&self.actions),
            };
            Ok(CheckpointStream {
                stream: wrap_stream(stream_state, self.statement_timeout),
                chain_id: Self::mock_chain_id(),
            })
        }
    }
}