1use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use dashmap::DashMap;
7use futures::future::BoxFuture;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11use sui_types::messages_checkpoint::VersionedFullCheckpointContents;
12use sui_types::{
13 digests::{CheckpointContentsDigest, CheckpointDigest},
14 messages_checkpoint::{
15 CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
16 VerifiedCheckpoint,
17 },
18 storage::WriteStore,
19};
20use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
21
22#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
23pub enum GetCheckpointSummaryRequest {
24 Latest,
25 ByDigest(CheckpointDigest),
26 BySequenceNumber(CheckpointSequenceNumber),
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct GetCheckpointAvailabilityResponse {
31 pub(crate) highest_synced_checkpoint: Checkpoint,
32 pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
33}
34
35pub(super) struct Server<S> {
36 pub(super) store: S,
37 pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
38 pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
39}
40
41#[anemo::async_trait]
42impl<S> StateSync for Server<S>
43where
44 S: WriteStore + Send + Sync + 'static,
45{
46 async fn push_checkpoint_summary(
47 &self,
48 request: Request<Checkpoint>,
49 ) -> Result<Response<()>, Status> {
50 let peer_id = request
51 .peer_id()
52 .copied()
53 .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
54
55 let checkpoint = request.into_inner();
56 if !self
57 .peer_heights
58 .write()
59 .unwrap()
60 .update_peer_info(peer_id, checkpoint.clone(), None)
61 {
62 return Ok(Response::new(()));
63 }
64
65 let highest_verified_checkpoint = *self
66 .store
67 .get_highest_verified_checkpoint()
68 .map_err(|e| Status::internal(e.to_string()))?
69 .sequence_number();
70
71 if *checkpoint.sequence_number() > highest_verified_checkpoint
74 && let Some(sender) = self.sender.upgrade()
75 {
76 sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
77 }
78
79 Ok(Response::new(()))
80 }
81
82 async fn get_checkpoint_summary(
83 &self,
84 request: Request<GetCheckpointSummaryRequest>,
85 ) -> Result<Response<Option<Checkpoint>>, Status> {
86 let checkpoint = match request.inner() {
87 GetCheckpointSummaryRequest::Latest => self
88 .store
89 .get_highest_synced_checkpoint()
90 .map(Some)
91 .map_err(|e| Status::internal(e.to_string()))?,
92 GetCheckpointSummaryRequest::ByDigest(digest) => {
93 self.store.get_checkpoint_by_digest(digest)
94 }
95 GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
96 .store
97 .get_checkpoint_by_sequence_number(*sequence_number),
98 }
99 .map(VerifiedCheckpoint::into_inner);
100
101 Ok(Response::new(checkpoint))
102 }
103
104 async fn get_checkpoint_availability(
105 &self,
106 _request: Request<()>,
107 ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
108 let highest_synced_checkpoint = self
109 .store
110 .get_highest_synced_checkpoint()
111 .map_err(|e| Status::internal(e.to_string()))
112 .map(VerifiedCheckpoint::into_inner)?;
113 let lowest_available_checkpoint = self
114 .store
115 .get_lowest_available_checkpoint()
116 .map_err(|e| Status::internal(e.to_string()))?;
117
118 Ok(Response::new(GetCheckpointAvailabilityResponse {
119 highest_synced_checkpoint,
120 lowest_available_checkpoint,
121 }))
122 }
123
124 async fn get_checkpoint_contents(
125 &self,
126 request: Request<CheckpointContentsDigest>,
127 ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
128 let contents = self
129 .store
130 .get_full_checkpoint_contents(None, request.inner());
131 Ok(Response::new(contents.map(|v| v.into_v1())))
132 }
133
134 async fn get_checkpoint_contents_v2(
135 &self,
136 request: Request<CheckpointContentsDigest>,
137 ) -> Result<Response<Option<VersionedFullCheckpointContents>>, Status> {
138 let contents = self
139 .store
140 .get_full_checkpoint_contents(None, request.inner());
141 Ok(Response::new(contents))
142 }
143}
144
145#[derive(Clone)]
148pub(super) struct CheckpointContentsDownloadLimitLayer {
149 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
150 max_inflight_per_checkpoint: usize,
151}
152
153impl CheckpointContentsDownloadLimitLayer {
154 pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
155 Self {
156 inflight_per_checkpoint: Arc::new(DashMap::new()),
157 max_inflight_per_checkpoint,
158 }
159 }
160
161 pub(super) fn maybe_prune_map(&self) {
162 const PRUNE_THRESHOLD: usize = 5000;
163 if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
164 self.inflight_per_checkpoint.retain(|_, semaphore| {
165 semaphore.available_permits() < self.max_inflight_per_checkpoint
166 });
167 }
168 }
169}
170
171impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
172 type Service = CheckpointContentsDownloadLimit<S>;
173
174 fn layer(&self, inner: S) -> Self::Service {
175 CheckpointContentsDownloadLimit {
176 inner,
177 inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
178 max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
179 }
180 }
181}
182
183#[derive(Clone)]
186pub(super) struct CheckpointContentsDownloadLimit<S> {
187 inner: S,
188 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
189 max_inflight_per_checkpoint: usize,
190}
191
192impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
193where
194 S: tower::Service<
195 Request<CheckpointContentsDigest>,
196 Response = Response<Option<FullCheckpointContents>>,
197 Error = Status,
198 >
199 + 'static
200 + Clone
201 + Send,
202 <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
203 Request<CheckpointContentsDigest>: 'static + Send + Sync,
204{
205 type Response = Response<Option<FullCheckpointContents>>;
206 type Error = S::Error;
207 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
208
209 #[inline]
210 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211 self.inner.poll_ready(cx)
212 }
213
214 fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
215 let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
216 let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
217 let mut inner = self.inner.clone();
218
219 let fut = async move {
220 let semaphore = {
221 let semaphore_entry = inflight_per_checkpoint
222 .entry(*req.body())
223 .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
224 semaphore_entry.value().clone()
225 };
226 let permit = semaphore.try_acquire_owned().map_err(|e| match e {
227 tokio::sync::TryAcquireError::Closed => {
228 anemo::rpc::Status::new(StatusCode::InternalServerError)
229 }
230 tokio::sync::TryAcquireError::NoPermits => {
231 anemo::rpc::Status::new(StatusCode::TooManyRequests)
232 }
233 })?;
234
235 struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
236 inner.call(req).await.map(move |mut response| {
237 response
239 .extensions_mut()
240 .insert(Arc::new(SemaphoreExtension(permit)));
241 response
242 })
243 };
244 Box::pin(fut)
245 }
246}