1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use anyhow::anyhow;
use async_trait::async_trait;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot::{
    channel as oneshotChannel, Receiver as oneshotReceiver, Sender as oneshotSender,
};

/// Error type a supervised component sends to its supervisor to report a failure it
/// cannot recover from.
pub type IrrecoverableError = anyhow::Error;
/// Handle of the tokio task spawned for a supervised component; the task yields no value.
type JoinHandle = tokio::task::JoinHandle<()>;

/// Buffer size passed to `channel` when creating the irrecoverable-error channel.
static CHANNEL_SIZE: usize = 10;

/// A Supervisor is instantiated to supervise a task that should always be running.
/// A running supervisor will start a component task, and ensure that it is restarted
/// if it ever stops.
pub struct Supervisor<M: Manageable> {
    /// Receiving end of the channel on which the supervised component reports an error
    /// it cannot recover from. Receiving a message here triggers a restart of the component.
    irrecoverable_signal: Receiver<IrrecoverableError>,

    /// Sending end of the cancellation channel whose receiver is handed to the supervised
    /// component. The supervisor never sends on it explicitly: when the sender is replaced
    /// during a restart the old one is dropped, which closes the channel and so gives the
    /// previous task its "ack". `None` until the component has been started.
    cancellation_signal: Option<oneshotSender<()>>,

    /// The join handle of the tokio task that was spawned by the Manageable start method.
    /// `None` until the component has been started.
    join_handle: Option<JoinHandle>,

    /// The Manageable component, which provides `start` to launch the task and
    /// `handle_irrecoverable` to run before a restart.
    manageable: M,
}

/// In order to be Manageable, a user defines the following two functions:
///
/// 1. A `start` function that launches a tokio task. As input it takes:
/// - an irrecoverable error sender, on which the component sends the supervisor
///   information about an irrecoverable event that has occurred
/// - a cancellation handle, which the task listens to once an irrecoverable message
///   has been sent; it is used as an "ack" that the message has been received, so the
///   task can return
///
/// 2. A `handle_irrecoverable` function, which takes actions on a relaunch caused by an
/// irrecoverable error. It receives the error that was sent to the Supervisor via the
/// `tx_irrecoverable` passed into `start`; that error may contain a stack trace and other
/// information.
#[async_trait]
pub trait Manageable {
    /// Spawns the component's tokio task and returns its join handle.
    ///
    /// `tx_irrecoverable` is the channel on which the task reports irrecoverable errors;
    /// `rx_cancellation` is the handle the task listens to for the supervisor's "ack".
    async fn start(
        &self,
        tx_irrecoverable: Sender<anyhow::Error>,
        rx_cancellation: oneshotReceiver<()>,
    ) -> tokio::task::JoinHandle<()>; // Note the task is "silent" (returns nothing)

    /// Performs cleanup after the task has encountered an irrecoverable error (or has
    /// stopped unexpectedly). `irrecoverable` is the reported error.
    ///
    /// Returning an `Err` makes the supervisor stop supervising and return that error.
    fn handle_irrecoverable(
        &mut self,
        irrecoverable: IrrecoverableError,
    ) -> Result<(), anyhow::Error>;
}

impl<M: Manageable + Send> Supervisor<M> {
    /// Creates a new supervisor using a Manageable component.
    ///
    /// Nothing is started here: the component's task is only launched by [`Self::spawn`].
    pub fn new(component: M) -> Self {
        // Placeholder receiver whose sender is dropped right away; it is replaced with a
        // live channel as soon as the component is started.
        let (_, placeholder_rx) = channel(CHANNEL_SIZE);
        Supervisor {
            irrecoverable_signal: placeholder_rx,
            cancellation_signal: None,
            join_handle: None,
            manageable: component,
        }
    }

    /// Spawn calls the start function of the Manageable component and runs supervision.
    ///
    /// # Errors
    ///
    /// Supervision only ends when `handle_irrecoverable` fails; that error is returned.
    pub async fn spawn(mut self) -> Result<(), anyhow::Error> {
        // The first start is the same procedure as a restart: create fresh channels,
        // start the task and store the new end points and join handle.
        self.restart().await;
        self.run().await
    }

    /// Run watches continuously for irrecoverable errors or JoinHandle completion.
    ///
    /// Each iteration waits for one of:
    ///
    /// - an irrecoverable error sent by the component => that error is handled,
    /// - completion of the task (e.g. because of an un-caught panic) => a generic
    ///   error is handled, since there is no user-supplied message.
    ///
    /// `handle_irrecoverable` runs before the existing task gets cancelled by the
    /// restart, which makes resource cleanup possible.
    async fn run(mut self) -> Result<(), anyhow::Error> {
        loop {
            // Resolve the handle up front. `tokio::select!` evaluates every branch
            // expression even when its `if` precondition is false, so an `unwrap()`
            // inside the branch would not be protected by an `is_some()` guard.
            let task = self
                .join_handle
                .as_mut()
                .ok_or_else(|| anyhow!("The supervisor has no running task to watch."))?;

            let message = tokio::select! {
                // If the channel is closed the pattern does not match and only the
                // join handle is awaited.
                Some(reported) = self.irrecoverable_signal.recv() => reported,

                // Poll the JoinHandle: the task stopped without reporting anything.
                _result = task => {
                    anyhow!("An unexpected shutdown was observed in a component.")
                }
            };

            self.manageable.handle_irrecoverable(message)?;
            self.restart().await;
        }
    }

    /// Starts the component (again) with fresh channels and stores the new supervision
    /// end points and join handle.
    async fn restart(&mut self) {
        let (tx_irrecoverable, tr_irrecoverable) = channel(CHANNEL_SIZE);
        let (tx_cancellation, tr_cancellation) = oneshotChannel();

        // call the start method
        let wrapped_handle: JoinHandle = self
            .manageable
            .start(tx_irrecoverable, tr_cancellation)
            .await;

        // reset the supervision handles & channel end points
        // dropping the old cancellation_signal implicitly sends cancellation by closing the channel
        self.irrecoverable_signal = tr_irrecoverable;
        self.cancellation_signal = Some(tx_cancellation);

        self.join_handle = Some(wrapped_handle);
    }
}