sui_indexer_alt_framework/ingestion/
streaming_client.rs1use std::time::Duration;
5
6use anyhow::Context;
7use anyhow::anyhow;
8use async_trait::async_trait;
9use futures::StreamExt;
10use futures::stream::BoxStream;
11use sui_rpc::headers::X_SUI_CHAIN_ID;
12use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
13use sui_rpc::proto::sui::rpc::v2::subscription_service_client::SubscriptionServiceClient;
14use sui_types::digests::ChainIdentifier;
15use sui_types::messages_checkpoint::CheckpointDigest;
16use tokio_stream::adapters::Peekable;
17use tonic::Status;
18use tonic::transport::Endpoint;
19use tonic::transport::Uri;
20
21use crate::ingestion::MAX_GRPC_MESSAGE_SIZE_BYTES;
22use crate::ingestion::error::Error;
23use crate::ingestion::error::Result;
24use crate::types::full_checkpoint_content::Checkpoint;
25
26pub struct CheckpointStream {
27 pub stream: Peekable<BoxStream<'static, Result<Checkpoint>>>,
28 pub chain_id: ChainIdentifier,
29}
30
31#[async_trait]
33pub trait CheckpointStreamingClient {
34 async fn connect(&mut self) -> Result<CheckpointStream>;
36
37 async fn latest_checkpoint_number(&mut self) -> Result<u64> {
39 let mut stream = self.connect().await?;
40
41 match stream.stream.next().await {
42 Some(Ok(checkpoint)) => Ok(checkpoint.summary.sequence_number),
43 Some(Err(e)) => Err(e),
44 None => Err(Error::StreamingError(anyhow!("Stream ended unexpectedly"))),
45 }
46 }
47}
48
// Command-line / environment configuration for checkpoint streaming.
//
// NOTE: plain `//` comments are used deliberately. clap turns `///` doc comments on derived
// arguments into `--help` text, so adding them here would change CLI output.
#[derive(clap::Args, Clone, Debug, Default)]
pub struct StreamingClientArgs {
    // Streaming endpoint URL, settable via `--streaming-url` or (because of `env`) the
    // `STREAMING_URL` environment variable; `None` when neither is given.
    // NOTE(review): presumably this is the URI handed to `GrpcStreamingClient::new` — confirm at
    // the call site.
    #[clap(long, env)]
    pub streaming_url: Option<Uri>,
}
55
56#[derive(Clone)]
58pub struct GrpcStreamingClient {
59 uri: Uri,
60 connection_timeout: Duration,
61 statement_timeout: Duration,
62}
63
64impl GrpcStreamingClient {
65 pub fn new(uri: Uri, connection_timeout: Duration, statement_timeout: Duration) -> Self {
66 Self {
67 uri,
68 connection_timeout,
69 statement_timeout,
70 }
71 }
72}
73
74#[async_trait]
75impl CheckpointStreamingClient for GrpcStreamingClient {
76 async fn connect(&mut self) -> Result<CheckpointStream> {
77 let endpoint = Endpoint::from(self.uri.clone())
78 .connect_timeout(self.connection_timeout)
79 .timeout(self.connection_timeout);
80
81 let mut client = SubscriptionServiceClient::connect(endpoint)
82 .await
83 .map_err(|err| Error::RpcClientError(Status::from_error(err.into())))?
84 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE_BYTES);
85
86 let mut request = SubscribeCheckpointsRequest::default();
87 request.read_mask = Some(Checkpoint::proto_field_mask());
88
89 let response = client
90 .subscribe_checkpoints(request)
91 .await
92 .map_err(Error::RpcClientError)?;
93
94 let chain_id_value = response.metadata().get(X_SUI_CHAIN_ID).ok_or_else(|| {
95 Error::StreamingError(anyhow!("Chain ID not found in response metadata"))
96 })?;
97 let chain_id: ChainIdentifier = chain_id_value
98 .to_str()
99 .map_err(|e| Error::StreamingError(anyhow!("Chain ID is not valid ASCII: {e}")))?
100 .parse::<CheckpointDigest>()
101 .map_err(|e| Error::StreamingError(anyhow!("Chain ID parse error: {e}")))?
102 .into();
103
104 let stream = response
105 .into_inner()
106 .map(|result| async move {
107 match result {
108 Ok(response) => {
109 let checkpoint = response
110 .checkpoint
111 .context("Checkpoint data missing in response")
112 .map_err(Error::StreamingError)?;
113 tokio::task::spawn_blocking(move || {
118 Checkpoint::try_from(&checkpoint).context("Failed to parse checkpoint")
119 })
120 .await
121 .map_err(|e| Error::StreamingError(anyhow!("decode task panicked: {e}")))?
122 .map_err(Error::StreamingError)
123 }
124 Err(e) => Err(Error::RpcClientError(e)),
125 }
126 })
127 .buffered(4);
128 let stream = wrap_stream(stream, self.statement_timeout);
129
130 Ok(CheckpointStream { stream, chain_id })
131 }
132}
133
134fn wrap_stream(
137 stream: impl futures::Stream<Item = Result<Checkpoint>> + Send + 'static,
138 statement_timeout: Duration,
139) -> Peekable<BoxStream<'static, Result<Checkpoint>>> {
140 let stream = tokio_stream::StreamExt::timeout(stream, statement_timeout)
141 .map(move |result| match result {
142 Err(_elapsed) => Err(Error::StreamingError(anyhow!(
143 "Statement timeout after {statement_timeout:?}"
144 ))),
145 Ok(result) => result,
146 })
147 .boxed();
148 tokio_stream::StreamExt::peekable(stream)
149}
150
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::time::Duration;

    use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsRequest;
    use sui_rpc::proto::sui::rpc::v2::SubscribeCheckpointsResponse;
    use sui_rpc::proto::sui::rpc::v2::subscription_service_server::SubscriptionService;
    use sui_rpc::proto::sui::rpc::v2::subscription_service_server::SubscriptionServiceServer;
    use tonic::transport::Server;

    use super::*;

    /// Service that accepts a subscription but never answers it.
    struct HangingSubscriptionService;

    #[tonic::async_trait]
    impl SubscriptionService for HangingSubscriptionService {
        async fn subscribe_checkpoints(
            &self,
            _request: tonic::Request<SubscribeCheckpointsRequest>,
        ) -> std::result::Result<
            tonic::Response<
                BoxStream<'static, std::result::Result<SubscribeCheckpointsResponse, Status>>,
            >,
            Status,
        > {
            futures::future::pending().await
        }
    }

    /// `connect` must give up on a server that accepts the connection but never responds to the
    /// subscribe call, instead of hanging forever.
    #[tokio::test]
    async fn subscribe_checkpoints_times_out_on_stalled_server() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr: SocketAddr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
            Server::builder()
                .add_service(SubscriptionServiceServer::new(HangingSubscriptionService))
                .serve_with_incoming(incoming)
                .await
                .unwrap();
        });

        let timeout = Duration::from_millis(200);
        let uri: Uri = format!("http://{addr}").parse().unwrap();
        let mut client = GrpcStreamingClient::new(uri, timeout, timeout);

        let start = std::time::Instant::now();
        let result = client.connect().await;
        let elapsed = start.elapsed();

        // A stalled subscribe call surfaces as an RPC error from the request timeout. Checking
        // the variant (not just `is_err`) keeps unrelated failures from satisfying the test.
        assert!(
            matches!(result, Err(Error::RpcClientError(_))),
            "expected timeout error surfaced as Error::RpcClientError"
        );
        assert!(
            elapsed < Duration::from_secs(5),
            "connect() took {elapsed:?}, should have timed out in ~200ms"
        );
    }
}
212
#[cfg(test)]
pub mod test_utils {
    //! Scriptable in-memory checkpoint streaming client for tests.

    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Poll;
    use std::time::Duration;
    use std::time::Instant;

    use futures::Stream;

    use crate::types::test_checkpoint_data_builder::TestCheckpointBuilder;

    use super::*;

    /// One scripted step replayed by the mock stream.
    enum StreamAction {
        /// Yield a synthetic checkpoint with this sequence number.
        Checkpoint(u64),
        /// Yield a `StreamingError`.
        Error,
        /// Stall for `duration` before moving on to the next action. `deadline` is fixed the
        /// first time a stream reaches this action; it lives in the shared queue, so it
        /// survives reconnects.
        Timeout {
            deadline: Option<Instant>,
            duration: Duration,
        },
    }

    /// Stream returned by [`MockStreamingClient`]'s `connect`; consumes the shared action queue.
    struct MockStreamState {
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Timer used to wake this stream's task when a `Timeout` action's deadline passes.
        wakeup: Option<Pin<Box<tokio::time::Sleep>>>,
    }

    impl MockStreamState {
        /// Arms the wake-up timer for `deadline` and polls it, registering `cx`'s waker.
        ///
        /// Returns `true` when the caller should re-check the queue immediately: the timer has
        /// fired and the wall clock has reached `deadline`.
        fn wait_until(&mut self, deadline: Instant, cx: &mut std::task::Context<'_>) -> bool {
            let wake_at = tokio::time::Instant::from_std(deadline);
            let sleep = self
                .wakeup
                .get_or_insert_with(|| Box::pin(tokio::time::sleep_until(wake_at)));
            if sleep.deadline() != wake_at {
                sleep.as_mut().reset(wake_at);
            }
            if std::future::Future::poll(sleep.as_mut(), cx).is_pending() {
                return false;
            }
            // The timer can fire before the wall clock reaches `deadline` (e.g. when tokio time is
            // paused and auto-advanced). Don't spin in that case; wait to be re-polled instead.
            Instant::now() >= deadline
        }
    }

    impl Stream for MockStreamState {
        type Item = Result<Checkpoint>;

        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // Every field is `Unpin`, so the state can be mutated through a plain `&mut`.
            let this = self.get_mut();

            loop {
                let mut actions = this.actions.lock().unwrap();
                let Some(action) = actions.front_mut() else {
                    return Poll::Ready(None);
                };

                let deadline = match action {
                    StreamAction::Checkpoint(seq) => {
                        let seq = *seq;
                        actions.pop_front();
                        let mut builder = TestCheckpointBuilder::new(seq);
                        return Poll::Ready(Some(Ok(builder.build_checkpoint())));
                    }
                    StreamAction::Error => {
                        actions.pop_front();
                        return Poll::Ready(Some(Err(Error::StreamingError(anyhow::anyhow!(
                            "Mock streaming error"
                        )))));
                    }
                    StreamAction::Timeout { deadline, duration } => {
                        *deadline.get_or_insert_with(|| Instant::now() + *duration)
                    }
                };

                if Instant::now() >= deadline {
                    // The stall is over; continue with the next scripted action.
                    actions.pop_front();
                    continue;
                }
                drop(actions);

                // Register a wake-up for the deadline instead of returning `Pending` with no
                // waker, which would leave the stall to end only when something else re-polls.
                if !this.wait_until(deadline, cx) {
                    return Poll::Pending;
                }
            }
        }
    }

    /// Scriptable [`CheckpointStreamingClient`]. Every stream it returns replays one shared queue
    /// of actions (checkpoints, errors, stalls). Connection attempts can be made to fail.
    pub struct MockStreamingClient {
        /// Scripted actions shared by every stream from `connect`, so actions consumed by one
        /// stream are not replayed by the next.
        actions: Arc<Mutex<VecDeque<StreamAction>>>,
        /// Remaining `connect` calls that fail immediately.
        connection_failures_remaining: usize,
        /// Remaining `connect` calls that sleep for `timeout_duration` and then fail.
        connection_timeouts_remaining: usize,
        /// Stall length used by `insert_timeout` and by simulated connection timeouts.
        timeout_duration: Duration,
        /// Per-item timeout applied to returned streams; half of `timeout_duration`.
        statement_timeout: Duration,
    }

    impl MockStreamingClient {
        /// Chain identifier reported by every stream from this client.
        pub fn mock_chain_id() -> ChainIdentifier {
            CheckpointDigest::new([1; 32]).into()
        }

        /// Creates a client whose streams yield the checkpoints in `checkpoint_range`, in order.
        ///
        /// `timeout_duration` defaults to 5 seconds. The streams' statement timeout is half of it.
        pub fn new<I>(checkpoint_range: I, timeout_duration: Option<Duration>) -> Self
        where
            I: IntoIterator<Item = u64>,
        {
            let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(5));
            Self {
                actions: Arc::new(Mutex::new(
                    checkpoint_range
                        .into_iter()
                        .map(StreamAction::Checkpoint)
                        .collect(),
                )),
                connection_failures_remaining: 0,
                connection_timeouts_remaining: 0,
                statement_timeout: timeout_duration / 2,
                timeout_duration,
            }
        }

        /// Makes the next `times` connection attempts fail immediately.
        pub fn fail_connection_times(mut self, times: usize) -> Self {
            self.connection_failures_remaining = times;
            self
        }

        /// Makes the next `times` connection attempts sleep for `timeout_duration` and then
        /// fail. These are consumed before any immediate failures.
        pub fn fail_connection_with_timeout(mut self, times: usize) -> Self {
            self.connection_timeouts_remaining = times;
            self
        }

        /// Appends a scripted streaming error.
        pub fn insert_error(&mut self) {
            self.actions.lock().unwrap().push_back(StreamAction::Error);
        }

        /// Appends a stall of `timeout_duration`, which is longer than the statement timeout.
        pub fn insert_timeout(&mut self) {
            self.insert_timeout_with_duration(self.timeout_duration)
        }

        /// Appends a stall of `duration`. The clock starts when a stream first reaches it.
        pub fn insert_timeout_with_duration(&mut self, duration: Duration) {
            self.actions.lock().unwrap().push_back(StreamAction::Timeout {
                deadline: None,
                duration,
            });
        }

        /// Appends a single checkpoint.
        pub fn insert_checkpoint(&mut self, sequence_number: u64) {
            self.insert_checkpoint_range([sequence_number])
        }

        /// Appends the checkpoints in `checkpoint_range`, in order.
        pub fn insert_checkpoint_range<I>(&mut self, checkpoint_range: I)
        where
            I: IntoIterator<Item = u64>,
        {
            self.actions
                .lock()
                .unwrap()
                .extend(checkpoint_range.into_iter().map(StreamAction::Checkpoint));
        }
    }

    #[async_trait]
    impl CheckpointStreamingClient for MockStreamingClient {
        /// Simulates a connection attempt. Scripted timeouts are consumed first, then scripted
        /// failures. Otherwise returns a stream over the shared action queue.
        async fn connect(&mut self) -> Result<CheckpointStream> {
            if self.connection_timeouts_remaining > 0 {
                self.connection_timeouts_remaining -= 1;
                tokio::time::sleep(self.timeout_duration).await;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection timeout"
                )));
            }
            if self.connection_failures_remaining > 0 {
                self.connection_failures_remaining -= 1;
                return Err(Error::StreamingError(anyhow::anyhow!(
                    "Mock connection failure"
                )));
            }
            let stream_state = MockStreamState {
                actions: Arc::clone(&self.actions),
                wakeup: None,
            };
            Ok(CheckpointStream {
                stream: wrap_stream(stream_state, self.statement_timeout),
                chain_id: Self::mock_chain_id(),
            })
        }
    }
}