1use std::task::{Context, Poll};
7
8use futures::{Future, TryFutureExt as _};
9use prometheus::IntGauge;
10use tap::Tap;
11use tokio::sync::mpsc::{
12 self,
13 error::{SendError, TryRecvError, TrySendError},
14};
15
16use crate::get_metrics;
17
18#[derive(Debug)]
20pub struct Sender<T> {
21 inner: mpsc::Sender<T>,
22 inflight: Option<IntGauge>,
23 sent: Option<IntGauge>,
24}
25
26impl<T> Sender<T> {
27 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
30 self.inner
31 .send(value)
32 .inspect_ok(|_| {
33 if let Some(inflight) = &self.inflight {
34 inflight.inc();
35 }
36 if let Some(sent) = &self.sent {
37 sent.inc();
38 }
39 })
40 .await
41 }
42
43 pub async fn closed(&self) {
45 self.inner.closed().await
46 }
47
48 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
51 self.inner
52 .try_send(message)
53 .map(|_| {
55 if let Some(inflight) = &self.inflight {
56 inflight.inc();
57 }
58 if let Some(sent) = &self.sent {
59 sent.inc();
60 }
61 })
62 }
63
64 pub fn is_closed(&self) -> bool {
71 self.inner.is_closed()
72 }
73
74 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
78 self.inner.reserve().await.map(|permit| {
79 if let Some(inflight) = &self.inflight {
80 inflight.inc();
81 }
82 Permit::new(permit, &self.inflight, &self.sent)
83 })
84 }
85
86 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
90 self.inner.try_reserve().map(|val| {
91 if let Some(inflight) = &self.inflight {
92 inflight.inc();
93 }
94 Permit::new(val, &self.inflight, &self.sent)
95 })
96 }
97
98 pub fn capacity(&self) -> usize {
102 self.inner.capacity()
103 }
104
105 pub fn downgrade(&self) -> WeakSender<T> {
106 let sender = self.inner.downgrade();
107 WeakSender {
108 inner: sender,
109 inflight: self.inflight.clone(),
110 sent: self.sent.clone(),
111 }
112 }
113
114 #[cfg(test)]
116 fn inflight(&self) -> &IntGauge {
117 self.inflight
118 .as_ref()
119 .expect("Metrics should have initialized")
120 }
121
122 #[cfg(test)]
124 fn sent(&self) -> &IntGauge {
125 self.sent.as_ref().expect("Metrics should have initialized")
126 }
127}
128
129impl<T> Clone for Sender<T> {
131 fn clone(&self) -> Self {
132 Self {
133 inner: self.inner.clone(),
134 inflight: self.inflight.clone(),
135 sent: self.sent.clone(),
136 }
137 }
138}
139
140pub struct Permit<'a, T> {
143 permit: Option<mpsc::Permit<'a, T>>,
144 inflight_ref: &'a Option<IntGauge>,
145 sent_ref: &'a Option<IntGauge>,
146}
147
148impl<'a, T> Permit<'a, T> {
149 pub fn new(
150 permit: mpsc::Permit<'a, T>,
151 inflight_ref: &'a Option<IntGauge>,
152 sent_ref: &'a Option<IntGauge>,
153 ) -> Permit<'a, T> {
154 Permit {
155 permit: Some(permit),
156 inflight_ref,
157 sent_ref,
158 }
159 }
160
161 pub fn send(mut self, value: T) {
162 let sender = self.permit.take().expect("Permit invariant violated!");
163 sender.send(value);
164 if let Some(sent_ref) = self.sent_ref {
165 sent_ref.inc();
166 }
167 std::mem::forget(self);
169 }
170}
171
172impl<T> Drop for Permit<'_, T> {
173 fn drop(&mut self) {
174 if self.permit.is_some()
177 && let Some(inflight_ref) = self.inflight_ref
178 {
179 inflight_ref.dec();
180 }
181 }
182}
183
184#[async_trait::async_trait]
185pub trait WithPermit<T> {
186 async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)>
187 where
188 T: 'static;
189}
190
191#[async_trait::async_trait]
192impl<T: Send> WithPermit<T> for Sender<T> {
193 async fn with_permit<F: Future + Send>(&self, f: F) -> Option<(Permit<T>, F::Output)> {
194 let permit = self.reserve().await.ok()?;
195 Some((permit, f.await))
196 }
197}
198
199#[derive(Debug)]
201pub struct WeakSender<T> {
202 inner: mpsc::WeakSender<T>,
203 inflight: Option<IntGauge>,
204 sent: Option<IntGauge>,
205}
206
207impl<T> WeakSender<T> {
208 pub fn upgrade(&self) -> Option<Sender<T>> {
209 self.inner.upgrade().map(|s| Sender {
210 inner: s,
211 inflight: self.inflight.clone(),
212 sent: self.sent.clone(),
213 })
214 }
215}
216
217impl<T> Clone for WeakSender<T> {
219 fn clone(&self) -> Self {
220 Self {
221 inner: self.inner.clone(),
222 inflight: self.inflight.clone(),
223 sent: self.sent.clone(),
224 }
225 }
226}
227
228#[derive(Debug)]
230pub struct Receiver<T> {
231 inner: mpsc::Receiver<T>,
232 inflight: Option<IntGauge>,
233 received: Option<IntGauge>,
234}
235
236impl<T> Receiver<T> {
237 pub async fn recv(&mut self) -> Option<T> {
240 self.inner.recv().await.tap(|opt| {
241 if opt.is_some() {
242 if let Some(inflight) = &self.inflight {
243 inflight.dec();
244 }
245 if let Some(received) = &self.received {
246 received.inc();
247 }
248 }
249 })
250 }
251
252 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
255 self.inner.try_recv().inspect(|_| {
256 if let Some(inflight) = &self.inflight {
257 inflight.dec();
258 }
259 if let Some(received) = &self.received {
260 received.inc();
261 }
262 })
263 }
264
265 pub fn blocking_recv(&mut self) -> Option<T> {
266 self.inner.blocking_recv().inspect(|_| {
267 if let Some(inflight) = &self.inflight {
268 inflight.dec();
269 }
270 if let Some(received) = &self.received {
271 received.inc();
272 }
273 })
274 }
275
276 pub fn close(&mut self) {
278 self.inner.close()
279 }
280
281 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
284 match self.inner.poll_recv(cx) {
285 res @ Poll::Ready(Some(_)) => {
286 if let Some(inflight) = &self.inflight {
287 inflight.dec();
288 }
289 if let Some(received) = &self.received {
290 received.inc();
291 }
292 res
293 }
294 s => s,
295 }
296 }
297
298 #[cfg(test)]
300 fn received(&self) -> &IntGauge {
301 self.received
302 .as_ref()
303 .expect("Metrics should have initialized")
304 }
305}
306
307impl<T> Unpin for Receiver<T> {}
308
309pub fn channel<T>(name: &str, size: usize) -> (Sender<T>, Receiver<T>) {
311 let metrics = get_metrics();
312 let (sender, receiver) = mpsc::channel(size);
313 (
314 Sender {
315 inner: sender,
316 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
317 sent: metrics.map(|m| m.channel_sent.with_label_values(&[name])),
318 },
319 Receiver {
320 inner: receiver,
321 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
322 received: metrics.map(|m| m.channel_received.with_label_values(&[name])),
323 },
324 )
325}
326
327#[derive(Debug)]
329pub struct UnboundedSender<T> {
330 inner: mpsc::UnboundedSender<T>,
331 inflight: Option<IntGauge>,
332 sent: Option<IntGauge>,
333}
334
335impl<T> UnboundedSender<T> {
336 pub fn send(&self, value: T) -> Result<(), SendError<T>> {
339 self.inner.send(value).map(|_| {
340 if let Some(inflight) = &self.inflight {
341 inflight.inc();
342 }
343 if let Some(sent) = &self.sent {
344 sent.inc();
345 }
346 })
347 }
348
349 pub async fn closed(&self) {
351 self.inner.closed().await
352 }
353
354 pub fn is_closed(&self) -> bool {
358 self.inner.is_closed()
359 }
360
361 pub fn downgrade(&self) -> WeakUnboundedSender<T> {
362 let sender = self.inner.downgrade();
363 WeakUnboundedSender {
364 inner: sender,
365 inflight: self.inflight.clone(),
366 sent: self.sent.clone(),
367 }
368 }
369
370 #[cfg(test)]
372 fn inflight(&self) -> &IntGauge {
373 self.inflight
374 .as_ref()
375 .expect("Metrics should have initialized")
376 }
377
378 #[cfg(test)]
380 fn sent(&self) -> &IntGauge {
381 self.sent.as_ref().expect("Metrics should have initialized")
382 }
383}
384
385impl<T> Clone for UnboundedSender<T> {
387 fn clone(&self) -> Self {
388 Self {
389 inner: self.inner.clone(),
390 inflight: self.inflight.clone(),
391 sent: self.sent.clone(),
392 }
393 }
394}
395
396#[derive(Debug)]
398pub struct WeakUnboundedSender<T> {
399 inner: mpsc::WeakUnboundedSender<T>,
400 inflight: Option<IntGauge>,
401 sent: Option<IntGauge>,
402}
403
404impl<T> WeakUnboundedSender<T> {
405 pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
406 self.inner.upgrade().map(|s| UnboundedSender {
407 inner: s,
408 inflight: self.inflight.clone(),
409 sent: self.sent.clone(),
410 })
411 }
412}
413
414impl<T> Clone for WeakUnboundedSender<T> {
416 fn clone(&self) -> Self {
417 Self {
418 inner: self.inner.clone(),
419 inflight: self.inflight.clone(),
420 sent: self.sent.clone(),
421 }
422 }
423}
424
425#[derive(Debug)]
427pub struct UnboundedReceiver<T> {
428 inner: mpsc::UnboundedReceiver<T>,
429 inflight: Option<IntGauge>,
430 received: Option<IntGauge>,
431}
432
433impl<T> UnboundedReceiver<T> {
434 pub async fn recv(&mut self) -> Option<T> {
437 self.inner.recv().await.tap(|opt| {
438 if opt.is_some() {
439 if let Some(inflight) = &self.inflight {
440 inflight.dec();
441 }
442 if let Some(received) = &self.received {
443 received.inc();
444 }
445 }
446 })
447 }
448
449 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
452 self.inner.try_recv().inspect(|_| {
453 if let Some(inflight) = &self.inflight {
454 inflight.dec();
455 }
456 if let Some(received) = &self.received {
457 received.inc();
458 }
459 })
460 }
461
462 pub fn blocking_recv(&mut self) -> Option<T> {
463 self.inner.blocking_recv().inspect(|_| {
464 if let Some(inflight) = &self.inflight {
465 inflight.dec();
466 }
467 if let Some(received) = &self.received {
468 received.inc();
469 }
470 })
471 }
472
473 pub fn close(&mut self) {
475 self.inner.close()
476 }
477
478 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
481 match self.inner.poll_recv(cx) {
482 res @ Poll::Ready(Some(_)) => {
483 if let Some(inflight) = &self.inflight {
484 inflight.dec();
485 }
486 if let Some(received) = &self.received {
487 received.inc();
488 }
489 res
490 }
491 s => s,
492 }
493 }
494
495 #[cfg(test)]
497 fn received(&self) -> &IntGauge {
498 self.received
499 .as_ref()
500 .expect("Metrics should have initialized")
501 }
502}
503
504impl<T> Unpin for UnboundedReceiver<T> {}
505
506pub fn unbounded_channel<T>(name: &str) -> (UnboundedSender<T>, UnboundedReceiver<T>) {
508 let metrics = get_metrics();
509 #[allow(clippy::disallowed_methods)]
510 let (sender, receiver) = mpsc::unbounded_channel();
511 (
512 UnboundedSender {
513 inner: sender,
514 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
515 sent: metrics.map(|m| m.channel_sent.with_label_values(&[name])),
516 },
517 UnboundedReceiver {
518 inner: receiver,
519 inflight: metrics.map(|m| m.channel_inflight.with_label_values(&[name])),
520 received: metrics.map(|m| m.channel_received.with_label_values(&[name])),
521 },
522 )
523}
524
#[cfg(test)]
mod test {
    use std::task::{Context, Poll};

    use futures::{FutureExt as _, task::noop_waker};
    use prometheus::Registry;
    use tokio::sync::mpsc::error::TrySendError;

    use crate::{
        init_metrics,
        monitored_mpsc::{channel, unbounded_channel},
    };

    // Every test uses a distinct channel name, so its gauges start at zero.

    #[tokio::test]
    async fn test_bounded_send_and_receive() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_bounded_send_and_receive", 8);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_try_send() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_try_send", 1);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 0);
        assert_eq!(received.get(), 0);

        let item = 42;
        tx.try_send(item).unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_try_send_full() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_try_send_full", 2);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);

        let item = 42;
        tx.try_send(item).unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        tx.try_send(item).unwrap();
        assert_eq!(inflight.get(), 2);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 0);

        // Capacity is 2, so a third send must be rejected without touching gauges.
        if let Err(e) = tx.try_send(item) {
            assert!(matches!(e, TrySendError::Full(_)));
        } else {
            panic!("Expect try_send return channel being full error");
        }
        assert_eq!(inflight.get(), 2);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 1);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 2);
    }

    #[tokio::test]
    async fn test_unbounded_send_and_receive() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = unbounded_channel("test_unbounded_send_and_receive");
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_empty_closed_channel() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_empty_closed_channel", 8);
        let inflight = tx.inflight();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);
        let item = 42;
        tx.send(item).await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);
        assert_eq!(inflight.get(), 0);
        assert_eq!(received.get(), 1);

        // Empty channel: try_recv fails and must not touch the gauges.
        let res = rx.try_recv();
        assert!(res.is_err());
        assert_eq!(inflight.get(), 0);
        assert_eq!(received.get(), 1);

        // Closed and drained: recv resolves immediately to None.
        rx.close();
        let res2 = rx.recv().now_or_never().unwrap();
        assert!(res2.is_none());
        assert_eq!(inflight.get(), 0);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_reserve() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_reserve", 8);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);

        // Reserving counts as inflight but not as sent.
        let permit = tx.reserve().await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 0);
        assert_eq!(received.get(), 0);

        let item = 42;
        permit.send(item);
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let permit_2 = tx.reserve().await.unwrap();
        assert_eq!(inflight.get(), 2);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        // Dropping an unused permit gives its inflight slot back.
        drop(permit_2);
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let received_item = rx.recv().await.unwrap();
        assert_eq!(received_item, item);

        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_reserve_and_drop() {
        init_metrics(&Registry::new());
        let (tx, _rx) = channel::<usize>("test_reserve_and_drop", 8);
        let inflight = tx.inflight();

        assert_eq!(inflight.get(), 0);

        let permit = tx.reserve().await.unwrap();
        assert_eq!(inflight.get(), 1);

        drop(permit);

        assert_eq!(inflight.get(), 0);
    }

    #[tokio::test]
    async fn test_try_reserve() {
        init_metrics(&Registry::new());
        let (tx, mut rx) = channel("test_try_reserve", 1);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        let permit = tx.try_reserve().unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 0);

        // The only slot is taken, so a second reservation fails and changes nothing.
        assert!(matches!(tx.try_reserve(), Err(TrySendError::Full(()))));
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 0);

        permit.send(7);
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        assert_eq!(rx.recv().await.unwrap(), 7);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_send_to_closed_channel() {
        init_metrics(&Registry::new());
        let (tx, rx) = channel::<usize>("test_send_to_closed_channel", 8);
        let inflight = tx.inflight();
        let sent = tx.sent();
        drop(rx);

        // Failed sends hand the value back and leave the gauges untouched.
        let err = tx.send(42).await.unwrap_err();
        assert_eq!(err.0, 42);
        assert!(matches!(tx.try_send(7), Err(TrySendError::Closed(7))));
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 0);
    }

    #[tokio::test]
    async fn test_unbounded_send_to_closed_channel() {
        init_metrics(&Registry::new());
        let (tx, rx) = unbounded_channel::<usize>("test_unbounded_send_to_closed_channel");
        let inflight = tx.inflight();
        let sent = tx.sent();
        drop(rx);

        let err = tx.send(42).unwrap_err();
        assert_eq!(err.0, 42);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 0);
    }

    #[tokio::test]
    async fn test_send_backpressure() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (tx, mut rx) = channel("test_send_backpressure", 1);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);

        tx.send(1).await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        // The channel is full: a pending send must not be counted yet.
        let mut task = Box::pin(tx.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let item = rx.recv().await.unwrap();
        assert_eq!(item, 1);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);

        assert!(task.now_or_never().is_some());
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_reserve_backpressure() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let (tx, mut rx) = channel("test_reserve_backpressure", 1);
        let inflight = tx.inflight();
        let sent = tx.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);

        let permit = tx.reserve().await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 0);
        assert_eq!(received.get(), 0);

        // The held permit occupies the only slot, so this send stays pending.
        let mut task = Box::pin(tx.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 0);
        assert_eq!(received.get(), 0);

        permit.send(1);
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let item = rx.recv().await.unwrap();
        assert_eq!(item, 1);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);

        assert!(task.now_or_never().is_some());
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 1);
    }

    #[tokio::test]
    async fn test_send_backpressure_multi_senders() {
        init_metrics(&Registry::new());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let (tx1, mut rx) = channel("test_send_backpressure_multi_senders", 1);
        let inflight = tx1.inflight();
        let sent = tx1.sent();
        let received = rx.received().clone();

        assert_eq!(inflight.get(), 0);

        tx1.send(1).await.unwrap();
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        // A cloned sender shares the same gauges.
        let tx2 = tx1.clone();
        let mut task = Box::pin(tx2.send(2));
        assert!(matches!(task.poll_unpin(&mut cx), Poll::Pending));
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 0);

        let item = rx.recv().await.unwrap();
        assert_eq!(item, 1);
        assert_eq!(inflight.get(), 0);
        assert_eq!(sent.get(), 1);
        assert_eq!(received.get(), 1);

        assert!(task.now_or_never().is_some());
        assert_eq!(inflight.get(), 1);
        assert_eq!(sent.get(), 2);
        assert_eq!(received.get(), 1);
    }
}