sui_security_watchdog/
scheduler.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::SecurityWatchdogConfig;
5use crate::metrics::WatchdogMetrics;
6use crate::pagerduty::{Body, CreateIncident, Incident, Pagerduty, Service};
7use crate::query_runner::{QueryRunner, SnowflakeQueryRunner};
8use anyhow::anyhow;
9use chrono::{DateTime, Utc};
10use prometheus::{IntGauge, Registry};
11use serde::{Deserialize, Serialize};
12use std::any::Any;
13use std::collections::BTreeMap;
14use std::fs::File;
15use std::io::Read;
16use std::sync::Arc;
17use tokio_cron_scheduler::{Job, JobScheduler};
18use tracing::{error, info};
19use uuid::Uuid;
20
/// Number of MIST in one SUI; wallet balances and bounds are divided by this before being
/// written into incident details.
const MIST_PER_SUI: i128 = 1_000_000_000;
22
/// The kinds of monitoring jobs that can be scheduled.
///
/// Serialized as an internally tagged enum: the `type` field of each JSON object names the
/// variant, and the remaining fields belong to that variant's struct.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum MonitoringEntry {
    /// A job that runs a query and publishes its distance from configured limits as metrics.
    MetricPublishingEntry(MetricPublishingEntry),
    /// A job that runs a query and creates a PagerDuty incident for every returned row.
    WalletMonitoringEntry(WalletMonitoringEntry),
}
30
/// Configuration for a job which runs a SQL query on a cron schedule and publishes, for each of
/// the exact/upper/lower limits, the difference between the query result and the limit currently
/// in effect. Alerts can then be set on the metric dashboard in Grafana if needed.
#[derive(Clone, Serialize, Deserialize)]
pub struct MetricPublishingEntry {
    /// Job name, used in log messages.
    name: String,
    /// Cron expression handed to the job scheduler.
    cron_schedule: String,
    /// Query whose single result value is compared against the limits.
    sql_query: String,
    /// Name used to look up the exact/upper/lower gauges for this job.
    metric_name: String,
    // Each limit map is keyed by a timestamp; the value in effect is the one with the latest key
    // strictly before the current time. When no key qualifies, the gauge is set to 0.
    /// Time-keyed upper limits.
    timed_upper_limits: BTreeMap<DateTime<Utc>, f64>,
    /// Time-keyed lower limits.
    timed_lower_limits: BTreeMap<DateTime<Utc>, f64>,
    /// Time-keyed exact (expected) values.
    timed_exact_limits: BTreeMap<DateTime<Utc>, f64>,
}
44
/// Configuration of a job which monitors wallet balances.
///
/// The job runs `sql_query` on the given cron schedule and creates a PagerDuty incident for every
/// row the query returns.
#[derive(Clone, Serialize, Deserialize)]
pub struct WalletMonitoringEntry {
    /// Job name, used in log messages and in the incident title and details.
    name: String,
    /// Cron expression handed to the job scheduler.
    cron_schedule: String,
    /// Query whose rows must provide the `WALLET_ID`, `CURRENT_BALANCE` and `LOWER_BOUND`
    /// columns.
    sql_query: String,
}
53
/// Schedules the watchdog's monitoring jobs.
///
/// Holds the cron scheduler together with what the jobs need when they fire: the query runner,
/// the metrics handle and the PagerDuty client.
pub struct SchedulerService {
    /// Cron scheduler that jobs are added to and that `start` starts.
    scheduler: JobScheduler,
    /// Runs the SQL queries of the monitoring entries.
    query_runner: Arc<dyn QueryRunner>,
    /// Metrics updated by metric publishing jobs and incremented when a job fails.
    metrics: Arc<WatchdogMetrics>,
    /// Monitoring entries loaded from the JSON file referenced by the watchdog config.
    entries: Vec<MonitoringEntry>,
    /// Client used to create incidents for wallet monitoring jobs.
    pagerduty: Pagerduty,
    /// Id of the PagerDuty service that wallet monitoring incidents are created under.
    pd_wallet_monitoring_service_id: String,
}
62
63impl SchedulerService {
64    pub async fn new(
65        config: &SecurityWatchdogConfig,
66        registry: &Registry,
67        pd_api_key: String,
68        sf_password: String,
69    ) -> anyhow::Result<Self> {
70        let scheduler = JobScheduler::new().await?;
71        Ok(Self {
72            scheduler,
73            query_runner: Arc::new(SnowflakeQueryRunner::from_config(config, sf_password)?),
74            metrics: Arc::new(WatchdogMetrics::new(registry)),
75            entries: Self::from_config(config)?,
76            pagerduty: Pagerduty::new(pd_api_key.clone()),
77            pd_wallet_monitoring_service_id: config.pd_wallet_monitoring_service_id.clone(),
78        })
79    }
80
81    pub async fn schedule(&self) -> anyhow::Result<()> {
82        for monitoring_entry in &self.entries {
83            match monitoring_entry {
84                MonitoringEntry::MetricPublishingEntry(entry) => {
85                    Self::schedule_metric_publish_job(
86                        entry.clone(),
87                        self.scheduler.clone(),
88                        self.query_runner.clone(),
89                        self.metrics.clone(),
90                    )
91                    .await?;
92                }
93                MonitoringEntry::WalletMonitoringEntry(entry) => {
94                    self.schedule_wallet_monitoring_job(
95                        entry.clone(),
96                        self.scheduler.clone(),
97                        self.query_runner.clone(),
98                        self.pd_wallet_monitoring_service_id.clone(),
99                        self.metrics.clone(),
100                        self.pagerduty.clone(),
101                    )
102                    .await?;
103                }
104            }
105        }
106        Ok(())
107    }
108
109    pub async fn start(&self) -> anyhow::Result<()> {
110        self.scheduler.start().await?;
111        Ok(())
112    }
113
114    fn from_config(config: &SecurityWatchdogConfig) -> anyhow::Result<Vec<MonitoringEntry>> {
115        let mut file = File::open(&config.config)?;
116        let mut contents = String::new();
117        file.read_to_string(&mut contents)?;
118        let entries: Vec<MonitoringEntry> = serde_json::from_str(&contents)?;
119        Ok(entries)
120    }
121
122    async fn schedule_wallet_monitoring_job(
123        &self,
124        entry: WalletMonitoringEntry,
125        scheduler: JobScheduler,
126        query_runner: Arc<dyn QueryRunner>,
127        pd_service_id: String,
128        metrics: Arc<WatchdogMetrics>,
129        pagerduty: Pagerduty,
130    ) -> anyhow::Result<Uuid> {
131        let name = entry.name.clone();
132        let cron_schedule = entry.cron_schedule.clone();
133        let job = Job::new_async(cron_schedule.as_str(), move |_uuid, _lock| {
134            let entry = entry.clone();
135            let query_runner = query_runner.clone();
136            let pd_service_id = pd_service_id.to_string();
137            let pd = pagerduty.clone();
138            let metrics = metrics.clone();
139            Box::pin(async move {
140                info!("Running wallet monitoring job: {}", entry.name);
141                if let Err(err) =
142                    Self::run_wallet_monitoring_job(&pd, &pd_service_id, &query_runner, &entry)
143                        .await
144                {
145                    error!(
146                        "Failed to run wallet monitoring job: {} with err: {}",
147                        entry.name, err
148                    );
149                    metrics
150                        .get("wallet_monitoring_error")
151                        .await
152                        .iter()
153                        .for_each(|metric| metric.inc());
154                }
155            })
156        })?;
157        let job_id = scheduler.add(job).await?;
158        info!("Scheduled job: {}", name);
159        Ok(job_id)
160    }
161
162    async fn run_wallet_monitoring_job(
163        pagerduty: &Pagerduty,
164        service_id: &str,
165        query_runner: &Arc<dyn QueryRunner>,
166        entry: &WalletMonitoringEntry,
167    ) -> anyhow::Result<()> {
168        let WalletMonitoringEntry {
169            sql_query, name, ..
170        } = entry;
171        let rows = query_runner.run(sql_query).await?;
172        for row in rows {
173            let wallet_id = row
174                .get("WALLET_ID")
175                .ok_or_else(|| anyhow!("Missing wallet_id"))?
176                .downcast_ref::<String>()
177                .ok_or(anyhow!("Failed to downcast wallet_id"))?
178                .clone();
179            let current_balance = Self::extract_i128(
180                row.get("CURRENT_BALANCE")
181                    .ok_or_else(|| anyhow!("Missing current_balance"))?,
182            )
183            .ok_or(anyhow!("Failed to downcast current_balance"))?;
184            let lower_bound = Self::extract_i128(
185                row.get("LOWER_BOUND")
186                    .ok_or_else(|| anyhow!("Missing lower_bound"))?,
187            )
188            .ok_or(anyhow!("Failed to downcast lower_bound"))?;
189            Self::create_wallet_monitoring_incident(
190                pagerduty,
191                &wallet_id,
192                current_balance,
193                lower_bound,
194                service_id,
195                name,
196            )
197            .await?;
198        }
199        Ok(())
200    }
201
202    async fn create_wallet_monitoring_incident(
203        pagerduty: &Pagerduty,
204        wallet_id: &str,
205        current_balance: i128,
206        lower_bound: i128,
207        service_id: &str,
208        name: &str,
209    ) -> anyhow::Result<()> {
210        let service = Service {
211            id: service_id.to_string(),
212            ..Default::default()
213        };
214        let incident_body = Body {
215            details: format!(
216                "Current balance: {}, Lower bound: {}, for job: {}",
217                current_balance / MIST_PER_SUI,
218                lower_bound / MIST_PER_SUI,
219                name
220            ),
221            ..Default::default()
222        };
223        let incident = Incident {
224            title: format!(
225                "Wallet: {} is out of compliance, for job: {}",
226                wallet_id, name
227            ),
228            service,
229            incident_key: wallet_id.to_string(),
230            body: incident_body,
231            ..Default::default()
232        };
233        let create_incident = CreateIncident { incident };
234        pagerduty
235            .create_incident("sadhan@mystenlabs.com", create_incident)
236            .await?;
237        Ok(())
238    }
239
240    async fn schedule_metric_publish_job(
241        entry: MetricPublishingEntry,
242        scheduler: JobScheduler,
243        query_runner: Arc<dyn QueryRunner>,
244        metrics: Arc<WatchdogMetrics>,
245    ) -> anyhow::Result<Uuid> {
246        let name = entry.name.clone();
247        let cron_schedule = entry.cron_schedule.clone();
248        let job = Job::new_async(cron_schedule.as_str(), move |_uuid, _lock| {
249            let entry = entry.clone();
250            let query_runner = query_runner.clone();
251            let metrics = metrics.clone();
252            Box::pin(async move {
253                info!("Running metric publish job: {}", &entry.name);
254                if let Err(err) =
255                    Self::run_metric_publish_job(&query_runner, &metrics, &entry).await
256                {
257                    error!("Failed to run metric publish job with err: {}", err);
258                    metrics
259                        .get("metric_publishing_error")
260                        .await
261                        .iter()
262                        .for_each(|metric| metric.inc());
263                }
264            })
265        })?;
266        let job_id = scheduler.add(job).await?;
267        info!("Scheduled job: {}", name);
268        Ok(job_id)
269    }
270
271    async fn run_metric_publish_job(
272        query_runner: &Arc<dyn QueryRunner>,
273        metrics: &Arc<WatchdogMetrics>,
274        entry: &MetricPublishingEntry,
275    ) -> anyhow::Result<()> {
276        let MetricPublishingEntry {
277            sql_query,
278            timed_exact_limits,
279            timed_upper_limits,
280            timed_lower_limits,
281            metric_name,
282            ..
283        } = entry;
284        let res = query_runner.run_single_entry(sql_query).await?;
285        let update_metrics = |limits: &BTreeMap<DateTime<Utc>, f64>, metric: IntGauge| {
286            if let Some(value) = Self::get_current_limit(limits) {
287                metric.set((res - value) as i64);
288            } else {
289                metric.set(0);
290            }
291        };
292
293        update_metrics(timed_exact_limits, metrics.get_exact(metric_name).await?);
294        update_metrics(timed_upper_limits, metrics.get_upper(metric_name).await?);
295        update_metrics(timed_lower_limits, metrics.get_lower(metric_name).await?);
296        Ok(())
297    }
298
299    fn get_current_limit(limits: &BTreeMap<DateTime<Utc>, f64>) -> Option<f64> {
300        limits.range(..Utc::now()).next_back().map(|(_, val)| *val)
301    }
302
303    fn extract_i128(value: &Box<dyn Any + Send>) -> Option<i128> {
304        if let Some(value) = value.downcast_ref::<i128>() {
305            Some(*value)
306        } else if let Some(value) = value.downcast_ref::<u32>() {
307            Some(*value as i128)
308        } else if let Some(value) = value.downcast_ref::<u16>() {
309            Some(*value as i128)
310        } else if let Some(value) = value.downcast_ref::<u8>() {
311            Some(*value as i128)
312        } else if let Some(value) = value.downcast_ref::<i64>() {
313            Some(*value as i128)
314        } else if let Some(value) = value.downcast_ref::<i32>() {
315            Some(*value as i128)
316        } else if let Some(value) = value.downcast_ref::<i16>() {
317            Some(*value as i128)
318        } else {
319            value.downcast_ref::<i8>().map(|value| *value as i128)
320        }
321    }
322}