sui_futures/
task.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use tokio::task::{JoinError, JoinHandle};
10
11/// A wrapper around `JoinHandle` that aborts the task when dropped.
12///
13/// The abort on drop does not wait for the task to finish, it simply sends the abort signal.
14#[must_use = "Dropping the handle aborts the task immediately"]
15#[derive(Debug)]
16pub struct TaskGuard<T>(JoinHandle<T>);
17
18impl<T> TaskGuard<T> {
19    pub fn new(handle: JoinHandle<T>) -> Self {
20        Self(handle)
21    }
22}
23
24impl<T> Future for TaskGuard<T> {
25    type Output = Result<T, JoinError>;
26
27    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
28        Pin::new(&mut self.0).poll(cx)
29    }
30}
31
32impl<T> AsRef<JoinHandle<T>> for TaskGuard<T> {
33    fn as_ref(&self) -> &JoinHandle<T> {
34        &self.0
35    }
36}
37
38impl<T> Drop for TaskGuard<T> {
39    fn drop(&mut self) {
40        self.0.abort();
41    }
42}
43
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::oneshot;

    use super::*;

    #[tokio::test]
    async fn test_abort_on_drop() {
        let (mut tx, rx) = oneshot::channel::<()>();

        let guard = TaskGuard::new(tokio::spawn(async move {
            let _ = rx.await;
        }));

        // When the guard is dropped, the task should be aborted, cleaning up its future, which
        // will close the receiving side of the channel.
        drop(guard);
        tokio::time::timeout(Duration::from_millis(100), tx.closed())
            .await
            .expect("receiver was not dropped after the guard was dropped");
    }

    #[tokio::test]
    async fn test_await_output() {
        // Awaiting the guard (instead of dropping it) should drive the task to completion and
        // hand back its output through the `Future` impl.
        let guard = TaskGuard::new(tokio::spawn(async { 42 }));
        assert_eq!(guard.await.expect("task should complete successfully"), 42);
    }
}