1use std::{
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use tokio::task::{JoinError, JoinHandle};
10
/// Owns a tokio [`JoinHandle`] and aborts the task it refers to when the guard
/// is dropped.
///
/// Awaiting the guard resolves to the task's `Result<T, JoinError>`, exactly as
/// awaiting the wrapped handle would. Dropping the guard calls
/// [`JoinHandle::abort`]. This includes `let _ = TaskGuard::new(...)`, which
/// drops it on the spot, hence the `#[must_use]`.
#[must_use = "Dropping the handle aborts the task immediately"]
#[derive(Debug)]
pub struct TaskGuard<T>(JoinHandle<T>);
17
18impl<T> TaskGuard<T> {
19 pub fn new(handle: JoinHandle<T>) -> Self {
20 Self(handle)
21 }
22}
23
24impl<T> Future for TaskGuard<T> {
25 type Output = Result<T, JoinError>;
26
27 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
28 Pin::new(&mut self.0).poll(cx)
29 }
30}
31
32impl<T> AsRef<JoinHandle<T>> for TaskGuard<T> {
33 fn as_ref(&self) -> &JoinHandle<T> {
34 &self.0
35 }
36}
37
38impl<T> Drop for TaskGuard<T> {
39 fn drop(&mut self) {
40 self.0.abort();
41 }
42}
43
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::oneshot;

    use super::*;

    /// Upper bound when waiting for something that should happen promptly.
    /// It is only reached when a test is about to fail, so it can be generous
    /// without slowing down passing runs.
    const WAIT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_abort_on_drop() {
        let (mut tx, rx) = oneshot::channel::<()>();

        let guard = TaskGuard::new(tokio::spawn(async move {
            let _ = rx.await;
        }));

        drop(guard);

        // Aborting the task drops its future, and with it `rx`, which closes
        // the channel and lets `closed()` resolve.
        tokio::time::timeout(WAIT, tx.closed())
            .await
            .expect("task was not aborted after its guard was dropped");
    }

    #[tokio::test]
    async fn test_await_yields_task_output() {
        let guard = TaskGuard::new(tokio::spawn(async { 42 }));

        let output = tokio::time::timeout(WAIT, guard)
            .await
            .expect("task did not complete in time")
            .expect("task failed");

        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn test_abort_through_handle_reports_cancellation() {
        // Keep the sender alive so the task can only finish by being aborted.
        let (_tx, rx) = oneshot::channel::<()>();

        let guard = TaskGuard::new(tokio::spawn(async move {
            let _ = rx.await;
        }));

        let handle: &JoinHandle<()> = guard.as_ref();
        handle.abort();

        let err = tokio::time::timeout(WAIT, guard)
            .await
            .expect("aborted task did not resolve in time")
            .expect_err("aborted task unexpectedly completed");

        assert!(err.is_cancelled());
    }
}