mysten_metrics/
metered_channel.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5use async_trait::async_trait;
6use std::future::Future;
// TODO: complete tests - this module wraps most of tokio::mpsc::{Sender, Receiver}; without tests, it will be fragile to maintain.
8use futures::{FutureExt, Stream, TryFutureExt};
9use prometheus::{IntCounter, IntGauge};
10use std::task::{Context, Poll};
11use tokio::sync::mpsc::{
12    self,
13    error::{SendError, TryRecvError, TrySendError},
14};
15
16#[cfg(test)]
17#[path = "tests/metered_channel_tests.rs"]
18mod metered_channel_tests;
19
/// An [`mpsc::Sender`] with an [`IntGauge`]
/// counting the number of currently queued items.
///
/// The gauge is incremented by the sending/reserving methods and decremented
/// when a [`Receiver`] takes an item out or an unused [`Permit`] is dropped.
#[derive(Debug)]
pub struct Sender<T> {
    /// The wrapped tokio sender.
    inner: mpsc::Sender<T>,
    /// Occupancy gauge; the constructors hand a clone of the same gauge to the receiver.
    gauge: IntGauge,
}
27
28impl<T> Clone for Sender<T> {
29    fn clone(&self) -> Self {
30        Self {
31            inner: self.inner.clone(),
32            gauge: self.gauge.clone(),
33        }
34    }
35}
36
37impl<T> Sender<T> {
38    pub fn downgrade(&self) -> WeakSender<T> {
39        let sender = self.inner.downgrade();
40        WeakSender {
41            inner: sender,
42            gauge: self.gauge.clone(),
43        }
44    }
45}
46
/// An [`mpsc::WeakSender`] with an [`IntGauge`]
/// counting the number of currently queued items.
///
/// Obtained via [`Sender::downgrade`]; [`WeakSender::upgrade`] turns it back
/// into a [`Sender`] that reports to the same gauge.
#[derive(Debug)]
pub struct WeakSender<T> {
    /// The wrapped tokio weak sender.
    inner: mpsc::WeakSender<T>,
    /// Gauge handed to any [`Sender`] produced by [`WeakSender::upgrade`].
    gauge: IntGauge,
}
54
55impl<T> Clone for WeakSender<T> {
56    fn clone(&self) -> Self {
57        Self {
58            inner: self.inner.clone(),
59            gauge: self.gauge.clone(),
60        }
61    }
62}
63
64impl<T> WeakSender<T> {
65    pub fn upgrade(&self) -> Option<Sender<T>> {
66        self.inner.upgrade().map(|s| Sender {
67            inner: s,
68            gauge: self.gauge.clone(),
69        })
70    }
71}
72
/// An [`mpsc::Receiver`] with an [`IntGauge`]
/// counting the number of currently queued items.
#[derive(Debug)]
pub struct Receiver<T> {
    /// The wrapped tokio receiver.
    inner: mpsc::Receiver<T>,
    /// Occupancy gauge; decremented once per successfully received item.
    gauge: IntGauge,
    /// Optional counter incremented once per successfully received item
    /// (`Some` when built by `channel_with_total`, `None` when built by `channel`).
    total: Option<IntCounter>,
}
81
82impl<T> Receiver<T> {
83    /// Receives the next value for this receiver.
84    /// Decrements the gauge in case of a successful `recv`.
85    pub async fn recv(&mut self) -> Option<T> {
86        self.inner
87            .recv()
88            .inspect(|opt| {
89                if opt.is_some() {
90                    self.gauge.dec();
91                    if let Some(total_gauge) = &self.total {
92                        total_gauge.inc();
93                    }
94                }
95            })
96            .await
97    }
98
99    /// Attempts to receive the next value for this receiver.
100    /// Decrements the gauge in case of a successful `try_recv`.
101    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
102        self.inner.try_recv().inspect(|_| {
103            self.gauge.dec();
104            if let Some(total_gauge) = &self.total {
105                total_gauge.inc();
106            }
107        })
108    }
109
110    pub fn blocking_recv(&mut self) -> Option<T> {
111        self.inner.blocking_recv().inspect(|_| {
112            self.gauge.dec();
113            if let Some(total_gauge) = &self.total {
114                total_gauge.inc();
115            }
116        })
117    }
118
119    /// Closes the receiving half of a channel without dropping it.
120    pub fn close(&mut self) {
121        self.inner.close()
122    }
123
124    /// Polls to receive the next message on this channel.
125    /// Decrements the gauge in case of a successful `poll_recv`.
126    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
127        match self.inner.poll_recv(cx) {
128            res @ Poll::Ready(Some(_)) => {
129                self.gauge.dec();
130                if let Some(total_gauge) = &self.total {
131                    total_gauge.inc();
132                }
133                res
134            }
135            s => s,
136        }
137    }
138}
139
// `ReceiverStream::poll_next` reaches `self.inner` through `Pin<&mut Self>`,
// which requires `ReceiverStream<T>` (and therefore `Receiver<T>`) to be `Unpin`.
// This impl makes `Receiver<T>` `Unpin` for every `T`.
impl<T> Unpin for Receiver<T> {}
141
/// A newtype for an `mpsc::Permit` which allows us to inject gauge accounting
/// in the case the permit is dropped without sending.
pub struct Permit<'a, T> {
    /// The reserved tokio permit; `Some` until it is consumed by `Permit::send`.
    permit: Option<mpsc::Permit<'a, T>>,
    /// Gauge to decrement if the permit is dropped while still unused.
    gauge_ref: &'a IntGauge,
}
148
149impl<'a, T> Permit<'a, T> {
150    pub fn new(permit: mpsc::Permit<'a, T>, gauge_ref: &'a IntGauge) -> Permit<'a, T> {
151        Permit {
152            permit: Some(permit),
153            gauge_ref,
154        }
155    }
156
157    pub fn send(mut self, value: T) {
158        let sender = self.permit.take().expect("Permit invariant violated!");
159        sender.send(value);
160        // skip the drop logic, see https://github.com/tokio-rs/tokio/blob/a66884a2fb80d1180451706f3c3e006a3fdcb036/tokio/src/sync/mpsc/bounded.rs#L1155-L1163
161        std::mem::forget(self);
162    }
163}
164
165impl<T> Drop for Permit<'_, T> {
166    fn drop(&mut self) {
167        // in the case the permit is dropped without sending, we still want to decrease the occupancy of the channel
168        if self.permit.is_some() {
169            self.gauge_ref.dec();
170        }
171    }
172}
173
174impl<T> Sender<T> {
175    /// Sends a value, waiting until there is capacity.
176    /// Increments the gauge in case of a successful `send`.
177    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
178        self.inner
179            .send(value)
180            .inspect_ok(|_| self.gauge.inc())
181            .await
182    }
183
184    /// Completes when the receiver has dropped.
185    pub async fn closed(&self) {
186        self.inner.closed().await
187    }
188
189    /// Attempts to immediately send a message on this `Sender`
190    /// Increments the gauge in case of a successful `try_send`.
191    pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
192        self.inner
193            .try_send(message)
194            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
195            .inspect(|_| {
196                self.gauge.inc();
197            })
198    }
199
200    // TODO: facade [`send_timeout`](tokio::mpsc::Sender::send_timeout) under the tokio feature flag "time"
201    // TODO: facade [`blocking_send`](tokio::mpsc::Sender::blocking_send) under the tokio feature flag "sync"
202
203    /// Checks if the channel has been closed. This happens when the
204    /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is
205    /// called.
206    pub fn is_closed(&self) -> bool {
207        self.inner.is_closed()
208    }
209
210    /// Waits for channel capacity. Once capacity to send one message is
211    /// available, it is reserved for the caller.
212    /// Increments the gauge in case of a successful `reserve`.
213    pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
214        self.inner
215            .reserve()
216            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
217            .map(|val| {
218                val.map(|permit| {
219                    self.gauge.inc();
220                    Permit::new(permit, &self.gauge)
221                })
222            })
223            .await
224    }
225
226    /// Tries to acquire a slot in the channel without waiting for the slot to become
227    /// available.
228    /// Increments the gauge in case of a successful `try_reserve`.
229    pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
230        self.inner.try_reserve().map(|val| {
231            // remove this unsightly hack once https://github.com/rust-lang/rust/issues/91345 is resolved
232            self.gauge.inc();
233            Permit::new(val, &self.gauge)
234        })
235    }
236
237    // TODO: consider exposing the _owned methods
238
239    // Note: not exposing `same_channel`, as it is hard to implement with callers able to
240    // break the coupling between channel and gauge using `gauge`.
241
242    /// Returns the current capacity of the channel.
243    pub fn capacity(&self) -> usize {
244        self.inner.capacity()
245    }
246
247    // We're voluntarily not putting WeakSender under a facade.
248
249    /// Returns a reference to the underlying gauge.
250    pub fn gauge(&self) -> &IntGauge {
251        &self.gauge
252    }
253}
254
255////////////////////////////////
256// Stream API Wrappers!
257////////////////////////////////
258
/// A wrapper around [`crate::metered_channel::Receiver`] that implements [`Stream`].
///
/// Items are pulled through `Receiver::poll_recv`, so everything the stream
/// yields is reflected in the receiver's gauge (and total counter, if any).
#[derive(Debug)]
pub struct ReceiverStream<T> {
    /// The wrapped metered receiver.
    inner: Receiver<T>,
}
264
265impl<T> ReceiverStream<T> {
266    /// Create a new `ReceiverStream`.
267    pub fn new(recv: Receiver<T>) -> Self {
268        Self { inner: recv }
269    }
270
271    /// Get back the inner `Receiver`.
272    pub fn into_inner(self) -> Receiver<T> {
273        self.inner
274    }
275
276    /// Closes the receiving half of a channel without dropping it.
277    pub fn close(&mut self) {
278        self.inner.close()
279    }
280}
281
282impl<T> Stream for ReceiverStream<T> {
283    type Item = T;
284
285    fn poll_next(
286        mut self: std::pin::Pin<&mut Self>,
287        cx: &mut Context<'_>,
288    ) -> Poll<Option<Self::Item>> {
289        self.inner.poll_recv(cx)
290    }
291}
292
293impl<T> AsRef<Receiver<T>> for ReceiverStream<T> {
294    fn as_ref(&self) -> &Receiver<T> {
295        &self.inner
296    }
297}
298
299impl<T> AsMut<Receiver<T>> for ReceiverStream<T> {
300    fn as_mut(&mut self) -> &mut Receiver<T> {
301        &mut self.inner
302    }
303}
304
305impl<T> From<Receiver<T>> for ReceiverStream<T> {
306    fn from(recv: Receiver<T>) -> Self {
307        Self::new(recv)
308    }
309}
310
311// TODO: facade PollSender
312// TODO: add prom metrics reporting for gauge and migrate all existing use cases.
313
314////////////////////////////////////////////////////////////////
315// Constructor
316////////////////////////////////////////////////////////////////
317
318/// Similar to `mpsc::channel`, `channel` creates a pair of `Sender` and `Receiver`
319/// Deprecated: use `monitored_mpsc::channel` instead.
320#[track_caller]
321pub fn channel<T>(size: usize, gauge: &IntGauge) -> (Sender<T>, Receiver<T>) {
322    gauge.set(0);
323    let (sender, receiver) = mpsc::channel(size);
324    (
325        Sender {
326            inner: sender,
327            gauge: gauge.clone(),
328        },
329        Receiver {
330            inner: receiver,
331            gauge: gauge.clone(),
332            total: None,
333        },
334    )
335}
336
337/// Deprecated: use `monitored_mpsc::channel` instead.
338#[track_caller]
339pub fn channel_with_total<T>(
340    size: usize,
341    gauge: &IntGauge,
342    total_gauge: &IntCounter,
343) -> (Sender<T>, Receiver<T>) {
344    gauge.set(0);
345    let (sender, receiver) = mpsc::channel(size);
346    (
347        Sender {
348            inner: sender,
349            gauge: gauge.clone(),
350        },
351        Receiver {
352            inner: receiver,
353            gauge: gauge.clone(),
354            total: Some(total_gauge.clone()),
355        },
356    )
357}
358
/// Pairs a reserved channel slot ([`Permit`]) with the output of a future.
///
/// The [`Sender`] implementation waits for a permit first (returning `None`
/// if `reserve` fails) and only then awaits `f`.
#[async_trait]
pub trait WithPermit<T> {
    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>;
}
363
364#[async_trait]
365impl<T: Send> WithPermit<T> for Sender<T> {
366    async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
367        let permit = self.reserve().await.ok()?;
368        Some((permit, f.await))
369    }
370}