sui_network/state_sync/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::{PeerHeights, StateSync, StateSyncMessage};
5use anemo::{Request, Response, Result, rpc::Status, types::response::StatusCode};
6use dashmap::DashMap;
7use futures::future::BoxFuture;
8use serde::{Deserialize, Serialize};
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11use sui_types::{
12    digests::{CheckpointContentsDigest, CheckpointDigest},
13    messages_checkpoint::{
14        CertifiedCheckpointSummary as Checkpoint, CheckpointSequenceNumber, FullCheckpointContents,
15        VerifiedCheckpoint,
16    },
17    storage::WriteStore,
18};
19use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
20
/// Selects which checkpoint summary a `get_checkpoint_summary` request asks for.
// NOTE: variant declaration order determines the derived `Ord` and the variant index serde
// emits for each variant; do not reorder.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum GetCheckpointSummaryRequest {
    /// The responder's highest synced checkpoint.
    Latest,
    /// The checkpoint whose summary has the given digest.
    ByDigest(CheckpointDigest),
    /// The checkpoint with the given sequence number.
    BySequenceNumber(CheckpointSequenceNumber),
}
27
/// Reply to `get_checkpoint_availability`: what the responder reports as its highest synced
/// checkpoint and its lowest still-available checkpoint.
// NOTE: field order is part of the serialized layout; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetCheckpointAvailabilityResponse {
    /// Summary of the responder's highest synced checkpoint.
    pub(crate) highest_synced_checkpoint: Checkpoint,
    /// Sequence number of the lowest checkpoint the responder's store reports as available.
    pub(crate) lowest_available_checkpoint: CheckpointSequenceNumber,
}
33
/// RPC handler implementing [`StateSync`] on top of a checkpoint store.
pub(super) struct Server<S> {
    /// Checkpoint store queried to answer requests.
    pub(super) store: S,
    /// Per-peer checkpoint information, updated when peers push checkpoint summaries.
    pub(super) peer_heights: Arc<RwLock<PeerHeights>>,
    /// Handle to the state-sync event loop's channel. A `WeakSender` does not by itself keep
    /// the channel open and must be upgraded before it can send.
    pub(super) sender: mpsc::WeakSender<StateSyncMessage>,
}
39
40#[anemo::async_trait]
41impl<S> StateSync for Server<S>
42where
43    S: WriteStore + Send + Sync + 'static,
44{
45    async fn push_checkpoint_summary(
46        &self,
47        request: Request<Checkpoint>,
48    ) -> Result<Response<()>, Status> {
49        let peer_id = request
50            .peer_id()
51            .copied()
52            .ok_or_else(|| Status::internal("unable to query sender's PeerId"))?;
53
54        let checkpoint = request.into_inner();
55        if !self
56            .peer_heights
57            .write()
58            .unwrap()
59            .update_peer_info(peer_id, checkpoint.clone(), None)
60        {
61            return Ok(Response::new(()));
62        }
63
64        let highest_verified_checkpoint = *self
65            .store
66            .get_highest_verified_checkpoint()
67            .map_err(|e| Status::internal(e.to_string()))?
68            .sequence_number();
69
70        // If this checkpoint is higher than our highest verified checkpoint notify the
71        // event loop to potentially sync it
72        if *checkpoint.sequence_number() > highest_verified_checkpoint
73            && let Some(sender) = self.sender.upgrade()
74        {
75            sender.send(StateSyncMessage::StartSyncJob).await.unwrap();
76        }
77
78        Ok(Response::new(()))
79    }
80
81    async fn get_checkpoint_summary(
82        &self,
83        request: Request<GetCheckpointSummaryRequest>,
84    ) -> Result<Response<Option<Checkpoint>>, Status> {
85        let checkpoint = match request.inner() {
86            GetCheckpointSummaryRequest::Latest => self
87                .store
88                .get_highest_synced_checkpoint()
89                .map(Some)
90                .map_err(|e| Status::internal(e.to_string()))?,
91            GetCheckpointSummaryRequest::ByDigest(digest) => {
92                self.store.get_checkpoint_by_digest(digest)
93            }
94            GetCheckpointSummaryRequest::BySequenceNumber(sequence_number) => self
95                .store
96                .get_checkpoint_by_sequence_number(*sequence_number),
97        }
98        .map(VerifiedCheckpoint::into_inner);
99
100        Ok(Response::new(checkpoint))
101    }
102
103    async fn get_checkpoint_availability(
104        &self,
105        _request: Request<()>,
106    ) -> Result<Response<GetCheckpointAvailabilityResponse>, Status> {
107        let highest_synced_checkpoint = self
108            .store
109            .get_highest_synced_checkpoint()
110            .map_err(|e| Status::internal(e.to_string()))
111            .map(VerifiedCheckpoint::into_inner)?;
112        let lowest_available_checkpoint = self
113            .store
114            .get_lowest_available_checkpoint()
115            .map_err(|e| Status::internal(e.to_string()))?;
116
117        Ok(Response::new(GetCheckpointAvailabilityResponse {
118            highest_synced_checkpoint,
119            lowest_available_checkpoint,
120        }))
121    }
122
123    async fn get_checkpoint_contents(
124        &self,
125        request: Request<CheckpointContentsDigest>,
126    ) -> Result<Response<Option<FullCheckpointContents>>, Status> {
127        let contents = self
128            .store
129            .get_full_checkpoint_contents(None, request.inner());
130        Ok(Response::new(contents))
131    }
132}
133
/// [`Layer`](tower::layer::Layer) for adding a per-checkpoint limit to the number of inflight
/// GetCheckpointContents requests.
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimitLayer {
    /// One semaphore per requested contents digest. The map sits behind an `Arc` and is
    /// handed to every service this layer wraps, so all of them share the same limits.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Number of permits each per-digest semaphore is created with.
    max_inflight_per_checkpoint: usize,
}
141
142impl CheckpointContentsDownloadLimitLayer {
143    pub(super) fn new(max_inflight_per_checkpoint: usize) -> Self {
144        Self {
145            inflight_per_checkpoint: Arc::new(DashMap::new()),
146            max_inflight_per_checkpoint,
147        }
148    }
149
150    pub(super) fn maybe_prune_map(&self) {
151        const PRUNE_THRESHOLD: usize = 5000;
152        if self.inflight_per_checkpoint.len() >= PRUNE_THRESHOLD {
153            self.inflight_per_checkpoint.retain(|_, semaphore| {
154                semaphore.available_permits() < self.max_inflight_per_checkpoint
155            });
156        }
157    }
158}
159
160impl<S> tower::layer::Layer<S> for CheckpointContentsDownloadLimitLayer {
161    type Service = CheckpointContentsDownloadLimit<S>;
162
163    fn layer(&self, inner: S) -> Self::Service {
164        CheckpointContentsDownloadLimit {
165            inner,
166            inflight_per_checkpoint: self.inflight_per_checkpoint.clone(),
167            max_inflight_per_checkpoint: self.max_inflight_per_checkpoint,
168        }
169    }
170}
171
/// Middleware for adding a per-checkpoint limit to the number of inflight
/// GetCheckpointContents requests.
#[derive(Clone)]
pub(super) struct CheckpointContentsDownloadLimit<S> {
    /// The wrapped service that actually serves the request.
    inner: S,
    /// Per-digest semaphores, shared with the layer that built this service.
    inflight_per_checkpoint: Arc<DashMap<CheckpointContentsDigest, Arc<Semaphore>>>,
    /// Number of permits a newly created per-digest semaphore receives.
    max_inflight_per_checkpoint: usize,
}
180
181impl<S> tower::Service<Request<CheckpointContentsDigest>> for CheckpointContentsDownloadLimit<S>
182where
183    S: tower::Service<
184            Request<CheckpointContentsDigest>,
185            Response = Response<Option<FullCheckpointContents>>,
186            Error = Status,
187        >
188        + 'static
189        + Clone
190        + Send,
191    <S as tower::Service<Request<CheckpointContentsDigest>>>::Future: Send,
192    Request<CheckpointContentsDigest>: 'static + Send + Sync,
193{
194    type Response = Response<Option<FullCheckpointContents>>;
195    type Error = S::Error;
196    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
197
198    #[inline]
199    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200        self.inner.poll_ready(cx)
201    }
202
203    fn call(&mut self, req: Request<CheckpointContentsDigest>) -> Self::Future {
204        let inflight_per_checkpoint = self.inflight_per_checkpoint.clone();
205        let max_inflight_per_checkpoint = self.max_inflight_per_checkpoint;
206        let mut inner = self.inner.clone();
207
208        let fut = async move {
209            let semaphore = {
210                let semaphore_entry = inflight_per_checkpoint
211                    .entry(*req.body())
212                    .or_insert_with(|| Arc::new(Semaphore::new(max_inflight_per_checkpoint)));
213                semaphore_entry.value().clone()
214            };
215            let permit = semaphore.try_acquire_owned().map_err(|e| match e {
216                tokio::sync::TryAcquireError::Closed => {
217                    anemo::rpc::Status::new(StatusCode::InternalServerError)
218                }
219                tokio::sync::TryAcquireError::NoPermits => {
220                    anemo::rpc::Status::new(StatusCode::TooManyRequests)
221                }
222            })?;
223
224            struct SemaphoreExtension(#[allow(unused)] OwnedSemaphorePermit);
225            inner.call(req).await.map(move |mut response| {
226                // Insert permit as extension so it's not dropped until the response is sent.
227                response
228                    .extensions_mut()
229                    .insert(Arc::new(SemaphoreExtension(permit)));
230                response
231            })
232        };
233        Box::pin(fut)
234    }
235}