1use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use bytes::Bytes;
7use dashmap::DashMap;
8use futures::future::BoxFuture;
9use serde::{Deserialize, Serialize};
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12use sui_types::messages_checkpoint::VersionedFullCheckpointContents;
13use sui_types::{
14 digests::{CheckpointContentsDigest, CheckpointDigest},
15 messages_checkpoint::{
16 CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
17 VerifiedCheckpoint,
18 },
19 storage::WriteStore,
20};
21use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
22
/// Selects which checkpoint summary a peer is asking for via `get_checkpoint_summary`.
///
/// NOTE(review): this type is (de)serialized over the network, so variant order is
/// presumably wire-significant — do not reorder variants.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum GetCheckpointSummaryRequest {
    /// The highest checkpoint this node has synced.
    Latest,
    /// The checkpoint whose summary has the given digest.
    ByDigest(CheckpointDigest),
    /// The checkpoint at the given sequence number.
    BySequenceNumber(CheckpointSequenceNumber),
}
29
/// Response to `get_checkpoint_availability`: which checkpoints this node reports it can serve.
///
/// NOTE(review): serialized over the network; field order is presumably wire-significant.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetCheckpointAvailabilityResponse {
    /// Summary of the highest checkpoint this node has synced.
    pub(crate) highest_synced_checkpoint: Checkpoint,
    /// Sequence number of the lowest checkpoint the store reports as still available.
    pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
}
35
/// RPC handler for the state-sync service, generic over the checkpoint store `S`.
pub(super) struct Server<S> {
    /// Store queried for checkpoint summaries and contents.
    pub(super) store: S,
    /// Per-peer announced heights plus buffered checkpoint summaries pushed by peers.
    pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
    /// Weak handle to the state-sync event loop's channel. Being weak, it does not by
    /// itself keep the channel open.
    pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
    /// How far beyond our highest verified checkpoint a pushed summary may be and
    /// still be buffered in `peer_heights`.
    pub(super) max_checkpoint_lookahead: u64,
}
42
43#[anemo::async_trait]
44impl<S> StateSync for Server<S>
45where
46 S: WriteStore + Send + Sync + 'static,
47{
48 async fn push_checkpoint_summary(
49 &self,
50 request: Request<Checkpoint>,
51 ) -> Result<Response<()>, Status> {
52 let peer_id = request
53 .peer_id()
54 .copied()
55 .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
56
57 let checkpoint = request.into_inner();
58 let checkpoint_seq = *checkpoint.sequence_number();
59
60 let highest_verified_checkpoint = *self
61 .store
62 .get_highest_verified_checkpoint()
63 .map_err(|e| Status::internal(e.to_string()))?
64 .sequence_number();
65
66 {
67 let mut peer_heights = self.peer_heights.write().unwrap();
68
69 let peer_on_same_chain = peer_heights.update_peer_height(peer_id, checkpoint_seq, None);
72 if !peer_on_same_chain {
73 return Ok(Response::new(()));
74 }
75
76 if checkpoint_seq
77 <= highest_verified_checkpoint.saturating_add(self.max_checkpoint_lookahead)
78 {
79 peer_heights.insert_checkpoint(checkpoint);
80 } else {
81 tracing::debug!(
82 peer_id = ?peer_id,
83 checkpoint_seq = checkpoint_seq,
84 highest_verified = highest_verified_checkpoint,
85 max_lookahead = self.max_checkpoint_lookahead,
86 "not storing checkpoint summary that exceeds max lookahead"
87 );
88 }
89 }
90
91 if checkpoint_seq > highest_verified_checkpoint
94 && let Some(sender) = self.sender.upgrade()
95 {
96 sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
97 }
98
99 Ok(Response::new(()))
100 }
101
102 async fn get_checkpoint_summary(
103 &self,
104 request: Request<GetCheckpointSummaryRequest>,
105 ) -> Result<Response<Option<Checkpoint>>, Status> {
106 let checkpoint = match request.inner() {
107 GetCheckpointSummaryRequest::Latest => self
108 .store
109 .get_highest_synced_checkpoint()
110 .map(Some)
111 .map_err(|e| Status::internal(e.to_string()))?,
112 GetCheckpointSummaryRequest::ByDigest(digest) => {
113 self.store.get_checkpoint_by_digest(digest)
114 }
115 GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
116 .store
117 .get_checkpoint_by_sequence_number(*sequence_number),
118 }
119 .map(VerifiedCheckpoint::into_inner);
120
121 Ok(Response::new(checkpoint))
122 }
123
124 async fn get_checkpoint_availability(
125 &self,
126 _request: Request<()>,
127 ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
128 let highest_synced_checkpoint = self
129 .store
130 .get_highest_synced_checkpoint()
131 .map_err(|e| Status::internal(e.to_string()))
132 .map(VerifiedCheckpoint::into_inner)?;
133 let lowest_available_checkpoint = self
134 .store
135 .get_lowest_available_checkpoint()
136 .map_err(|e| Status::internal(e.to_string()))?;
137
138 Ok(Response::new(GetCheckpointAvailabilityResponse {
139 highest_synced_checkpoint,
140 lowest_available_checkpoint,
141 }))
142 }
143
144 async fn get_checkpoint_contents(
145 &self,
146 request: Request<CheckpointContentsDigest>,
147 ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
148 let contents = self
149 .store
150 .get_full_checkpoint_contents(None, request.inner());
151 Ok(Response::new(contents.map(|v| v.into_v1())))
152 }
153
154 async fn get_checkpoint_contents_v2(
155 &self,
156 request: Request<CheckpointContentsDigest>,
157 ) -> Result<Response<Option<VersionedFullCheckpointContents>>, Status> {
158 let contents = self
159 .store
160 .get_full_checkpoint_contents(None, request.inner());
161 Ok(Response::new(contents))
162 }
163}
164
/// `tower` layer that wraps a service in [`SizeLimit`] to reject oversized request bodies.
#[derive(Clone)]
pub(super) struct SizeLimitLayer {
    /// Maximum accepted request body length, in bytes.
    max_size: usize,
}
170
171impl SizeLimitLayer {
172 pub(super) fn new(max_size: usize) -> Self {
173 Self { max_size }
174 }
175}
176
177impl<S> tower::layer::Layer<S> for SizeLimitLayer {
178 type Service = SizeLimit<S>;
179
180 fn layer(&self, inner: S) -> Self::Service {
181 SizeLimit {
182 inner,
183 max_size: self.max_size,
184 }
185 }
186}
187
/// Service produced by [`SizeLimitLayer`]. It forwards a request to `inner` only when the
/// request body is at most `max_size` bytes long.
#[derive(Clone)]
pub(super) struct SizeLimit<S> {
    /// The wrapped service.
    inner: S,
    /// Maximum accepted request body length, in bytes.
    max_size: usize,
}
194
195impl<S> tower::Service<Request<Bytes>> for SizeLimit<S>
196where
197 S: tower::Service<Request<Bytes>, Response = Response<Bytes>> + 'static + Clone + Send,
198 <S as tower::Service<Request<Bytes>>>::Future: Send,
199{
200 type Response = Response<Bytes>;
201 type Error = S::Error;
202 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
203
204 #[inline]
205 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206 self.inner.poll_ready(cx)
207 }
208
209 fn call(&mut self, req: Request<Bytes>) -> Self::Future {
210 let body_size = req.body().len();
211 let max_size = self.max_size;
212 if body_size > max_size {
213 let peer_id = req.peer_id().copied();
214 tracing::info!(
215 ?peer_id,
216 body_size,
217 max_size,
218 "rejecting request that exceeds max size"
219 );
220 return Box::pin(async move { Ok(Response::new(Bytes::new())) });
221 }
222 let mut inner = self.inner.clone();
223 Box::pin(async move { inner.call(req).await })
224 }
225}
226
/// `tower` layer that caps how many downloads of the same checkpoint contents may be
/// in flight at once.
///
/// Every service this layer produces shares the same map of per-digest semaphores.
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimitLayer {
    /// One semaphore per requested contents digest, created on first request for that digest.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Number of permits each newly created per-digest semaphore starts with.
    max_inflight_per_checkpoint: usize,
}
234
235impl CheckpointContentsDownloadLimitLayer {
236 pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
237 Self {
238 inflight_per_checkpoint: Arc::new(DashMap::new()),
239 max_inflight_per_checkpoint,
240 }
241 }
242
243 pub(super) fn maybe_prune_map(&self) {
244 const PRUNE_THRESHOLD: usize = 5000;
245 if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
246 self.inflight_per_checkpoint.retain(|_, semaphore| {
247 semaphore.available_permits() < self.max_inflight_per_checkpoint
248 });
249 }
250 }
251}
252
253impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
254 type Service = CheckpointContentsDownloadLimit<S>;
255
256 fn layer(&self, inner: S) -> Self::Service {
257 CheckpointContentsDownloadLimit {
258 inner,
259 inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
260 max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
261 }
262 }
263}
264
/// Service produced by [`CheckpointContentsDownloadLimitLayer`]. It rejects a contents
/// request with `TooManyRequests` when its digest's semaphore has no permits left.
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimit<S> {
    /// The wrapped service.
    inner: S,
    /// Per-digest semaphores shared with the layer that created this service.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Permit count used when a semaphore for a new digest is created.
    max_inflight_per_checkpoint: usize,
}
273
274impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
275where
276 S: tower::Service<
277 Request<CheckpointContentsDigest>,
278 Response = Response<Option<FullCheckpointContents>>,
279 Error = Status,
280 >
281 + 'static
282 + Clone
283 + Send,
284 <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
285 Request<CheckpointContentsDigest>: 'static + Send + Sync,
286{
287 type Response = Response<Option<FullCheckpointContents>>;
288 type Error = S::Error;
289 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
290
291 #[inline]
292 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
293 self.inner.poll_ready(cx)
294 }
295
296 fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
297 let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
298 let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
299 let mut inner = self.inner.clone();
300
301 let fut = async move {
302 let semaphore = {
303 let semaphore_entry = inflight_per_checkpoint
304 .entry(*req.body())
305 .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
306 semaphore_entry.value().clone()
307 };
308 let permit = semaphore.try_acquire_owned().map_err(|e| match e {
309 tokio::sync::TryAcquireError::Closed => {
310 anemo::rpc::Status::new(StatusCode::InternalServerError)
311 }
312 tokio::sync::TryAcquireError::NoPermits => {
313 anemo::rpc::Status::new(StatusCode::TooManyRequests)
314 }
315 })?;
316
317 struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
318 inner.call(req).await.map(move |mut response| {
319 response
321 .extensions_mut()
322 .insert(Arc::new(SemaphoreExtension(permit)));
323 response
324 })
325 };
326 Box::pin(fut)
327 }
328}