sui_network/state_sync/
server.rs1use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use dashmap::DashMap;
7use futures::future::BoxFuture;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11use sui_types::{
12 digests::{CheckpointContentsDigest, CheckpointDigest},
13 messages_checkpoint::{
14 CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
15 VerifiedCheckpoint,
16 },
17 storage::WriteStore,
18};
19use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
20
21#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
22pub enum GetCheckpointSummaryRequest {
23 Latest,
24 ByDigest(CheckpointDigest),
25 BySequenceNumber(CheckpointSequenceNumber),
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct GetCheckpointAvailabilityResponse {
30 pub(crate) highest_synced_checkpoint: Checkpoint,
31 pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
32}
33
34pub(super) struct Server<S> {
35 pub(super) store: S,
36 pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
37 pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
38}
39
40#[anemo::async_trait]
41impl<S> StateSync for Server<S>
42where
43 S: WriteStore + Send + Sync + 'static,
44{
45 async fn push_checkpoint_summary(
46 &self,
47 request: Request<Checkpoint>,
48 ) -> Result<Response<()>, Status> {
49 let peer_id = request
50 .peer_id()
51 .copied()
52 .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
53
54 let checkpoint = request.into_inner();
55 if !self
56 .peer_heights
57 .write()
58 .unwrap()
59 .update_peer_info(peer_id, checkpoint.clone(), None)
60 {
61 return Ok(Response::new(()));
62 }
63
64 let highest_verified_checkpoint = *self
65 .store
66 .get_highest_verified_checkpoint()
67 .map_err(|e| Status::internal(e.to_string()))?
68 .sequence_number();
69
70 if *checkpoint.sequence_number() > highest_verified_checkpoint
73 && let Some(sender) = self.sender.upgrade()
74 {
75 sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
76 }
77
78 Ok(Response::new(()))
79 }
80
81 async fn get_checkpoint_summary(
82 &self,
83 request: Request<GetCheckpointSummaryRequest>,
84 ) -> Result<Response<Option<Checkpoint>>, Status> {
85 let checkpoint = match request.inner() {
86 GetCheckpointSummaryRequest::Latest => self
87 .store
88 .get_highest_synced_checkpoint()
89 .map(Some)
90 .map_err(|e| Status::internal(e.to_string()))?,
91 GetCheckpointSummaryRequest::ByDigest(digest) => {
92 self.store.get_checkpoint_by_digest(digest)
93 }
94 GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
95 .store
96 .get_checkpoint_by_sequence_number(*sequence_number),
97 }
98 .map(VerifiedCheckpoint::into_inner);
99
100 Ok(Response::new(checkpoint))
101 }
102
103 async fn get_checkpoint_availability(
104 &self,
105 _request: Request<()>,
106 ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
107 let highest_synced_checkpoint = self
108 .store
109 .get_highest_synced_checkpoint()
110 .map_err(|e| Status::internal(e.to_string()))
111 .map(VerifiedCheckpoint::into_inner)?;
112 let lowest_available_checkpoint = self
113 .store
114 .get_lowest_available_checkpoint()
115 .map_err(|e| Status::internal(e.to_string()))?;
116
117 Ok(Response::new(GetCheckpointAvailabilityResponse {
118 highest_synced_checkpoint,
119 lowest_available_checkpoint,
120 }))
121 }
122
123 async fn get_checkpoint_contents(
124 &self,
125 request: Request<CheckpointContentsDigest>,
126 ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
127 let contents = self
128 .store
129 .get_full_checkpoint_contents(None, request.inner());
130 Ok(Response::new(contents))
131 }
132}
133
134#[derive(Clone)]
137pub(super) struct CheckpointContentsDownloadLimitLayer {
138 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
139 max_inflight_per_checkpoint: usize,
140}
141
142impl CheckpointContentsDownloadLimitLayer {
143 pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
144 Self {
145 inflight_per_checkpoint: Arc::new(DashMap::new()),
146 max_inflight_per_checkpoint,
147 }
148 }
149
150 pub(super) fn maybe_prune_map(&self) {
151 const PRUNE_THRESHOLD: usize = 5000;
152 if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
153 self.inflight_per_checkpoint.retain(|_, semaphore| {
154 semaphore.available_permits() < self.max_inflight_per_checkpoint
155 });
156 }
157 }
158}
159
160impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
161 type Service = CheckpointContentsDownloadLimit<S>;
162
163 fn layer(&self, inner: S) -> Self::Service {
164 CheckpointContentsDownloadLimit {
165 inner,
166 inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
167 max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
168 }
169 }
170}
171
172#[derive(Clone)]
175pub(super) struct CheckpointContentsDownloadLimit<S> {
176 inner: S,
177 inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
178 max_inflight_per_checkpoint: usize,
179}
180
181impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
182where
183 S: tower::Service<
184 Request<CheckpointContentsDigest>,
185 Response = Response<Option<FullCheckpointContents>>,
186 Error = Status,
187 >
188 + 'static
189 + Clone
190 + Send,
191 <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
192 Request<CheckpointContentsDigest>: 'static + Send + Sync,
193{
194 type Response = Response<Option<FullCheckpointContents>>;
195 type Error = S::Error;
196 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
197
198 #[inline]
199 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200 self.inner.poll_ready(cx)
201 }
202
203 fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
204 let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
205 let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
206 let mut inner = self.inner.clone();
207
208 let fut = async move {
209 let semaphore = {
210 let semaphore_entry = inflight_per_checkpoint
211 .entry(*req.body())
212 .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
213 semaphore_entry.value().clone()
214 };
215 let permit = semaphore.try_acquire_owned().map_err(|e| match e {
216 tokio::sync::TryAcquireError::Closed => {
217 anemo::rpc::Status::new(StatusCode::InternalServerError)
218 }
219 tokio::sync::TryAcquireError::NoPermits => {
220 anemo::rpc::Status::new(StatusCode::TooManyRequests)
221 }
222 })?;
223
224 struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
225 inner.call(req).await.map(move |mut response| {
226 response
228 .extensions_mut()
229 .insert(Arc::new(SemaphoreExtension(permit)));
230 response
231 })
232 };
233 Box::pin(fut)
234 }
235}