sui_network/state_sync/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use bytes::Bytes;
7use dashmap::DashMap;
8use futures::future::BoxFuture;
9use serde::{Deserialize, Serialize};
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12use sui_types::messages_checkpoint::VersionedFullCheckpointContents;
13use sui_types::{
14    digests::{CheckpointContentsDigest, CheckpointDigest},
15    messages_checkpoint::{
16        CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
17        VerifiedCheckpoint,
18    },
19    storage::WriteStore,
20};
21use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
22
/// Request payload for `get_checkpoint_summary`, selecting which checkpoint summary the
/// server should look up.
///
/// NOTE(review): the variant order is part of the derived `Ord` and (presumably) of the
/// serialized encoding — do not reorder without checking wire compatibility.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum GetCheckpointSummaryRequest {
    /// The store's highest synced checkpoint.
    Latest,
    /// The checkpoint with the given digest, if the store has it.
    ByDigest(CheckpointDigest),
    /// The checkpoint with the given sequence number, if the store has it.
    BySequenceNumber(CheckpointSequenceNumber),
}
29
/// Response payload for `get_checkpoint_availability`, describing the range of checkpoints
/// the responding node reports from its store.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetCheckpointAvailabilityResponse {
    /// The store's highest synced checkpoint summary.
    pub(crate) highest_synced_checkpoint: Checkpoint,
    /// The lowest checkpoint sequence number the store reports as available.
    pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
}
35
/// RPC handler state backing the `StateSync` service implementation.
pub(super) struct Server<S> {
    /// Checkpoint store that requests are answered from.
    pub(super) store: S,
    /// Shared record of peer heights and pushed checkpoint summaries; updated whenever a
    /// peer pushes a checkpoint summary.
    pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
    /// Weak handle to the state-sync message channel, used to request a sync job. Being
    /// weak, it does not keep the channel alive on its own.
    pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
    /// Maximum distance beyond our highest verified checkpoint at which a pushed checkpoint
    /// summary is still cached.
    pub(super) max_checkpoint_lookahead: u64,
}
42
43#[anemo::async_trait]
44impl<S> StateSync for Server<S>
45where
46    S: WriteStore + Send + Sync + 'static,
47{
48    async fn push_checkpoint_summary(
49        &self,
50        request: Request<Checkpoint>,
51    ) -> Result<Response<()>, Status> {
52        let peer_id = request
53            .peer_id()
54            .copied()
55            .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
56
57        let checkpoint = request.into_inner();
58        let checkpoint_seq = *checkpoint.sequence_number();
59
60        let highest_verified_checkpoint = *self
61            .store
62            .get_highest_verified_checkpoint()
63            .map_err(|e| Status::internal(e.to_string()))?
64            .sequence_number();
65
66        {
67            let mut peer_heights = self.peer_heights.write().unwrap();
68
69            // Always update the peer's height so we know what they claim to have,
70            // even if we don't store the checkpoint itself.
71            let peer_on_same_chain = peer_heights.update_peer_height(peer_id, checkpoint_seq, None);
72            if !peer_on_same_chain {
73                return Ok(Response::new(()));
74            }
75
76            if checkpoint_seq
77                <= highest_verified_checkpoint.saturating_add(self.max_checkpoint_lookahead)
78            {
79                peer_heights.insert_checkpoint(checkpoint);
80            } else {
81                tracing::debug!(
82                    peer_id = ?peer_id,
83                    checkpoint_seq = checkpoint_seq,
84                    highest_verified = highest_verified_checkpoint,
85                    max_lookahead = self.max_checkpoint_lookahead,
86                    "not storing checkpoint summary that exceeds max lookahead"
87                );
88            }
89        }
90
91        // If this checkpoint is higher than our highest verified checkpoint notify the
92        // event loop to potentially sync it
93        if checkpoint_seq > highest_verified_checkpoint
94            && let Some(sender) = self.sender.upgrade()
95        {
96            sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
97        }
98
99        Ok(Response::new(()))
100    }
101
102    async fn get_checkpoint_summary(
103        &self,
104        request: Request<GetCheckpointSummaryRequest>,
105    ) -> Result<Response<Option<Checkpoint>>, Status> {
106        let checkpoint = match request.inner() {
107            GetCheckpointSummaryRequest::Latest => self
108                .store
109                .get_highest_synced_checkpoint()
110                .map(Some)
111                .map_err(|e| Status::internal(e.to_string()))?,
112            GetCheckpointSummaryRequest::ByDigest(digest) => {
113                self.store.get_checkpoint_by_digest(digest)
114            }
115            GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
116                .store
117                .get_checkpoint_by_sequence_number(*sequence_number),
118        }
119        .map(VerifiedCheckpoint::into_inner);
120
121        Ok(Response::new(checkpoint))
122    }
123
124    async fn get_checkpoint_availability(
125        &self,
126        _request: Request<()>,
127    ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
128        let highest_synced_checkpoint = self
129            .store
130            .get_highest_synced_checkpoint()
131            .map_err(|e| Status::internal(e.to_string()))
132            .map(VerifiedCheckpoint::into_inner)?;
133        let lowest_available_checkpoint = self
134            .store
135            .get_lowest_available_checkpoint()
136            .map_err(|e| Status::internal(e.to_string()))?;
137
138        Ok(Response::new(GetCheckpointAvailabilityResponse {
139            highest_synced_checkpoint,
140            lowest_available_checkpoint,
141        }))
142    }
143
144    async fn get_checkpoint_contents(
145        &self,
146        request: Request<CheckpointContentsDigest>,
147    ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
148        let contents = self
149            .store
150            .get_full_checkpoint_contents(None, request.inner());
151        Ok(Response::new(contents.map(|v| v.into_v1())))
152    }
153
154    async fn get_checkpoint_contents_v2(
155        &self,
156        request: Request<CheckpointContentsDigest>,
157    ) -> Result<Response<Option<VersionedFullCheckpointContents>>, Status> {
158        let contents = self
159            .store
160            .get_full_checkpoint_contents(None, request.inner());
161        Ok(Response::new(contents))
162    }
163}
164
/// [`Layer`] for limiting the size of incoming requests.
///
/// Wraps a service in [`SizeLimit`], passing along the configured maximum body size.
///
/// [`Layer`]: tower::layer::Layer
#[derive(Clone)]
pub(super) struct SizeLimitLayer {
    /// Largest request body, in bytes, that is forwarded to the wrapped service.
    max_size: usize,
}
170
171impl SizeLimitLayer {
172    pub(super) fn new(max_size: usize) -> Self {
173        Self { max_size }
174    }
175}
176
177impl<S> tower::layer::Layer<S> for SizeLimitLayer {
178    type Service = SizeLimit<S>;
179
180    fn layer(&self, inner: S) -> Self::Service {
181        SizeLimit {
182            inner,
183            max_size: self.max_size,
184        }
185    }
186}
187
/// Middleware for limiting the size of incoming requests.
///
/// Requests whose body is longer than `max_size` bytes are answered directly and never
/// reach the wrapped service.
#[derive(Clone)]
pub(super) struct SizeLimit<S> {
    /// The wrapped service that receives requests within the size limit.
    inner: S,
    /// Largest request body, in bytes, that is forwarded to `inner`.
    max_size: usize,
}
194
195impl<S> tower::Service<Request<Bytes>> for SizeLimit<S>
196where
197    S: tower::Service<Request<Bytes>, Response = Response<Bytes>> + 'static + Clone + Send,
198    <S as tower::Service<Request<Bytes>>>::Future: Send,
199{
200    type Response = Response<Bytes>;
201    type Error = S::Error;
202    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
203
204    #[inline]
205    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206        self.inner.poll_ready(cx)
207    }
208
209    fn call(&mut self, req: Request<Bytes>) -> Self::Future {
210        let body_size = req.body().len();
211        let max_size = self.max_size;
212        if body_size > max_size {
213            let peer_id = req.peer_id().copied();
214            tracing::info!(
215                ?peer_id,
216                body_size,
217                max_size,
218                "rejecting request that exceeds max size"
219            );
220            return Box::pin(async move { Ok(Response::new(Bytes::new())) });
221        }
222        let mut inner = self.inner.clone();
223        Box::pin(async move { inner.call(req).await })
224    }
225}
226
/// [`Layer`] for adding a per-checkpoint limit to the number of inflight GetCheckpointContent
/// requests.
///
/// Every service produced by this layer shares the same semaphore map.
///
/// [`Layer`]: tower::layer::Layer
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimitLayer {
    /// One semaphore per requested checkpoint contents digest, created on first use.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Number of permits each per-digest semaphore is created with.
    max_inflight_per_checkpoint: usize,
}
234
235impl CheckpointContentsDownloadLimitLayer {
236    pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
237        Self {
238            inflight_per_checkpoint: Arc::new(DashMap::new()),
239            max_inflight_per_checkpoint,
240        }
241    }
242
243    pub(super) fn maybe_prune_map(&self) {
244        const PRUNE_THRESHOLD: usize = 5000;
245        if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
246            self.inflight_per_checkpoint.retain(|_, semaphore| {
247                semaphore.available_permits() < self.max_inflight_per_checkpoint
248            });
249        }
250    }
251}
252
253impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
254    type Service = CheckpointContentsDownloadLimit<S>;
255
256    fn layer(&self, inner: S) -> Self::Service {
257        CheckpointContentsDownloadLimit {
258            inner,
259            inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
260            max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
261        }
262    }
263}
264
/// Middleware for adding a per-checkpoint limit to the number of inflight GetCheckpointContent
/// requests.
///
/// A request that cannot immediately obtain a permit for its digest is rejected instead of
/// being queued.
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimit<S> {
    /// The wrapped service that serves requests holding a permit.
    inner: S,
    /// One semaphore per requested checkpoint contents digest, created on first use.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Number of permits each per-digest semaphore is created with.
    max_inflight_per_checkpoint: usize,
}
273
274impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
275where
276    S: tower::Service<
277            Request<CheckpointContentsDigest>,
278            Response = Response<Option<FullCheckpointContents>>,
279            Error = Status,
280        >
281        + 'static
282        + Clone
283        + Send,
284    <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
285    Request<CheckpointContentsDigest>: 'static + Send + Sync,
286{
287    type Response = Response<Option<FullCheckpointContents>>;
288    type Error = S::Error;
289    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
290
291    #[inline]
292    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
293        self.inner.poll_ready(cx)
294    }
295
296    fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
297        let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
298        let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
299        let mut inner = self.inner.clone();
300
301        let fut = async move {
302            let semaphore = {
303                let semaphore_entry = inflight_per_checkpoint
304                    .entry(*req.body())
305                    .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
306                semaphore_entry.value().clone()
307            };
308            let permit = semaphore.try_acquire_owned().map_err(|e| match e {
309                tokio::sync::TryAcquireError::Closed => {
310                    anemo::rpc::Status::new(StatusCode::InternalServerError)
311                }
312                tokio::sync::TryAcquireError::NoPermits => {
313                    anemo::rpc::Status::new(StatusCode::TooManyRequests)
314                }
315            })?;
316
317            struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
318            inner.call(req).await.map(move |mut response| {
319                // Insert permit as extension so it's not dropped until the response is sent.
320                response
321                    .extensions_mut()
322                    .insert(Arc::new(SemaphoreExtension(permit)));
323                response
324            })
325        };
326        Box::pin(fut)
327    }
328}