sui_network/state_sync/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use dashmap::DashMap;
7use futures::future::BoxFuture;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11use sui_types::messages_checkpoint::VersionedFullCheckpointContents;
12use sui_types::{
13    digests::{CheckpointContentsDigest, CheckpointDigest},
14    messages_checkpoint::{
15        CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
16        VerifiedCheckpoint,
17    },
18    storage::WriteStore,
19};
20use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
21
22#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
23pub enum GetCheckpointSummaryRequest {
24    Latest,
25    ByDigest(CheckpointDigest),
26    BySequenceNumber(CheckpointSequenceNumber),
27}
28
29#[derive(Clone, Debug, Serialize, Deserialize)]
30pub struct GetCheckpointAvailabilityResponse {
31    pub(crate) highest_synced_checkpoint: Checkpoint,
32    pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
33}
34
35pub(super) struct Server<S> {
36    pub(super) store: S,
37    pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
38    pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
39}
40
41#[anemo::async_trait]
42impl<S> StateSync for Server<S>
43where
44    S: WriteStore + Send + Sync + 'static,
45{
46    async fn push_checkpoint_summary(
47        &self,
48        request: Request<Checkpoint>,
49    ) -> Result<Response<()>, Status> {
50        let peer_id = request
51            .peer_id()
52            .copied()
53            .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
54
55        let checkpoint = request.into_inner();
56        if !self
57            .peer_heights
58            .write()
59            .unwrap()
60            .update_peer_info(peer_id, checkpoint.clone(), None)
61        {
62            return Ok(Response::new(()));
63        }
64
65        let highest_verified_checkpoint = *self
66            .store
67            .get_highest_verified_checkpoint()
68            .map_err(|e| Status::internal(e.to_string()))?
69            .sequence_number();
70
71        // If this checkpoint is higher than our highest verified checkpoint notify the
72        // event loop to potentially sync it
73        if *checkpoint.sequence_number() > highest_verified_checkpoint
74            && let Some(sender) = self.sender.upgrade()
75        {
76            sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
77        }
78
79        Ok(Response::new(()))
80    }
81
82    async fn get_checkpoint_summary(
83        &self,
84        request: Request<GetCheckpointSummaryRequest>,
85    ) -> Result<Response<Option<Checkpoint>>, Status> {
86        let checkpoint = match request.inner() {
87            GetCheckpointSummaryRequest::Latest => self
88                .store
89                .get_highest_synced_checkpoint()
90                .map(Some)
91                .map_err(|e| Status::internal(e.to_string()))?,
92            GetCheckpointSummaryRequest::ByDigest(digest) => {
93                self.store.get_checkpoint_by_digest(digest)
94            }
95            GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
96                .store
97                .get_checkpoint_by_sequence_number(*sequence_number),
98        }
99        .map(VerifiedCheckpoint::into_inner);
100
101        Ok(Response::new(checkpoint))
102    }
103
104    async fn get_checkpoint_availability(
105        &self,
106        _request: Request<()>,
107    ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
108        let highest_synced_checkpoint = self
109            .store
110            .get_highest_synced_checkpoint()
111            .map_err(|e| Status::internal(e.to_string()))
112            .map(VerifiedCheckpoint::into_inner)?;
113        let lowest_available_checkpoint = self
114            .store
115            .get_lowest_available_checkpoint()
116            .map_err(|e| Status::internal(e.to_string()))?;
117
118        Ok(Response::new(GetCheckpointAvailabilityResponse {
119            highest_synced_checkpoint,
120            lowest_available_checkpoint,
121        }))
122    }
123
124    async fn get_checkpoint_contents(
125        &self,
126        request: Request<CheckpointContentsDigest>,
127    ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
128        let contents = self
129            .store
130            .get_full_checkpoint_contents(None, request.inner());
131        Ok(Response::new(contents.map(|v| v.into_v1())))
132    }
133
134    async fn get_checkpoint_contents_v2(
135        &self,
136        request: Request<CheckpointContentsDigest>,
137    ) -> Result<Response<Option<VersionedFullCheckpointContents>>, Status> {
138        let contents = self
139            .store
140            .get_full_checkpoint_contents(None, request.inner());
141        Ok(Response::new(contents))
142    }
143}
144
145/// [`Layer`] for adding a per-checkpoint limit to the number of inflight GetCheckpointContent
146/// requests.
147#[derive(Clone)]
148pub(super) struct CheckpointContentsDownloadLimitLayer {
149    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
150    max_inflight_per_checkpoint: usize,
151}
152
153impl CheckpointContentsDownloadLimitLayer {
154    pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
155        Self {
156            inflight_per_checkpoint: Arc::new(DashMap::new()),
157            max_inflight_per_checkpoint,
158        }
159    }
160
161    pub(super) fn maybe_prune_map(&self) {
162        const PRUNE_THRESHOLD: usize = 5000;
163        if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
164            self.inflight_per_checkpoint.retain(|_, semaphore| {
165                semaphore.available_permits() < self.max_inflight_per_checkpoint
166            });
167        }
168    }
169}
170
171impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
172    type Service = CheckpointContentsDownloadLimit<S>;
173
174    fn layer(&self, inner: S) -> Self::Service {
175        CheckpointContentsDownloadLimit {
176            inner,
177            inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
178            max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
179        }
180    }
181}
182
183/// Middleware for adding a per-checkpoint limit to the number of inflight GetCheckpointContent
184/// requests.
185#[derive(Clone)]
186pub(super) struct CheckpointContentsDownloadLimit<S> {
187    inner: S,
188    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
189    max_inflight_per_checkpoint: usize,
190}
191
192impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
193where
194    S: tower::Service<
195            Request<CheckpointContentsDigest>,
196            Response = Response<Option<FullCheckpointContents>>,
197            Error = Status,
198        >
199        + 'static
200        + Clone
201        + Send,
202    <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
203    Request<CheckpointContentsDigest>: 'static + Send + Sync,
204{
205    type Response = Response<Option<FullCheckpointContents>>;
206    type Error = S::Error;
207    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
208
209    #[inline]
210    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211        self.inner.poll_ready(cx)
212    }
213
214    fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
215        let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
216        let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
217        let mut inner = self.inner.clone();
218
219        let fut = async move {
220            let semaphore = {
221                let semaphore_entry = inflight_per_checkpoint
222                    .entry(*req.body())
223                    .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
224                semaphore_entry.value().clone()
225            };
226            let permit = semaphore.try_acquire_owned().map_err(|e| match e {
227                tokio::sync::TryAcquireError::Closed => {
228                    anemo::rpc::Status::new(StatusCode::InternalServerError)
229                }
230                tokio::sync::TryAcquireError::NoPermits => {
231                    anemo::rpc::Status::new(StatusCode::TooManyRequests)
232                }
233            })?;
234
235            struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
236            inner.call(req).await.map(move |mut response| {
237                // Insert permit as extension so it's not dropped until the response is sent.
238                response
239                    .extensions_mut()
240                    .insert(Arc::new(SemaphoreExtension(permit)));
241                response
242            })
243        };
244        Box::pin(fut)
245    }
246}