sui_aws_orchestrator/
testbed.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::time::Duration;
5
6use futures::future::try_join_all;
7use prettytable::{Table, row};
8use tokio::time::{self, Instant};
9
10use crate::{
11    client::ServerProviderClient,
12    display,
13    error::{TestbedError, TestbedResult},
14    settings::Settings,
15    ssh::SshConnection,
16};
17
18use super::client::Instance;
19
20/// Represents a testbed running on a cloud provider.
21pub struct Testbed<C> {
22    /// The testbed's settings.
23    settings: Settings,
24    /// The client interfacing with the cloud provider.
25    client: C,
26    /// The state of the testbed (reflecting accurately the state of the machines).
27    instances: Vec<Instance>,
28}
29
30impl<C: ServerProviderClient> Testbed<C> {
31    /// Create a new testbed instance with the specified settings and client.
32    pub async fn new(settings: Settings, client: C) -> TestbedResult<Self> {
33        let public_key = settings.load_ssh_public_key()?;
34        client.register_ssh_public_key(public_key).await?;
35        let instances = client.list_instances().await?;
36
37        Ok(Self {
38            settings,
39            client,
40            instances,
41        })
42    }
43
44    /// Return the username to connect to the instances through ssh.
45    pub fn username(&self) -> &'static str {
46        C::USERNAME
47    }
48
49    /// Return the list of instances of the testbed.
50    pub fn instances(&self) -> Vec<Instance> {
51        self.instances
52            .iter()
53            .filter(|x| self.settings.filter_instances(x))
54            .cloned()
55            .collect()
56    }
57
58    /// Return the list of provider-specific instance setup commands.
59    pub async fn setup_commands(&self) -> TestbedResult<Vec<String>> {
60        self.client
61            .instance_setup_commands()
62            .await
63            .map_err(TestbedError::from)
64    }
65
66    /// Print the current status of the testbed.
67    pub fn status(&self) {
68        let filtered = self
69            .instances
70            .iter()
71            .filter(|instance| self.settings.filter_instances(instance));
72        let sorted: Vec<(_, Vec<_>)> = self
73            .settings
74            .regions
75            .iter()
76            .map(|region| {
77                (
78                    region,
79                    filtered
80                        .clone()
81                        .filter(|instance| &instance.region == region)
82                        .collect(),
83                )
84            })
85            .collect();
86
87        let mut table = Table::new();
88        table.set_format(display::default_table_format());
89
90        let active = filtered.filter(|x| x.is_active()).count();
91        table.set_titles(row![bH2->format!("Instances ({active})")]);
92        for (i, (region, instances)) in sorted.iter().enumerate() {
93            table.add_row(row![bH2->region.to_uppercase()]);
94            let mut j = 0;
95            for instance in instances {
96                if j % 5 == 0 {
97                    table.add_row(row![]);
98                }
99                let private_key_file = self.settings.ssh_private_key_file.display();
100                let username = C::USERNAME;
101                let ip = instance.main_ip;
102                let connect = format!("ssh -i {private_key_file} {username}@{ip}");
103                if !instance.is_terminated() {
104                    if instance.is_active() {
105                        table.add_row(row![bFg->format!("{j}"), connect]);
106                    } else {
107                        table.add_row(row![bFr->format!("{j}"), connect]);
108                    }
109                    j += 1;
110                }
111            }
112            if i != sorted.len() - 1 {
113                table.add_row(row![]);
114            }
115        }
116
117        display::newline();
118        display::config("Client", &self.client);
119        let repo = &self.settings.repository;
120        display::config("Repo", format!("{} ({})", repo.url, repo.commit));
121        display::newline();
122        table.printstd();
123        display::newline();
124    }
125
126    /// Populate the testbed by creating the specified amount of instances per region. The total
127    /// number of instances created is thus the specified amount x the number of regions.
128    pub async fn deploy(&mut self, quantity: usize, region: Option<String>) -> TestbedResult<()> {
129        display::action(format!("Deploying instances ({quantity} per region)"));
130
131        let instances = match region {
132            Some(x) => {
133                try_join_all((0..quantity).map(|_| self.client.create_instance(x.clone()))).await?
134            }
135            None => {
136                try_join_all(self.settings.regions.iter().flat_map(|region| {
137                    (0..quantity).map(|_| self.client.create_instance(region.clone()))
138                }))
139                .await?
140            }
141        };
142
143        // Wait until the instances are booted.
144        if cfg!(not(test)) {
145            self.wait_until_reachable(instances.iter()).await?;
146        }
147        self.instances = self.client.list_instances().await?;
148
149        display::done();
150        Ok(())
151    }
152
153    /// Destroy all instances of the testbed.
154    pub async fn destroy(&mut self) -> TestbedResult<()> {
155        display::action("Destroying testbed");
156
157        try_join_all(
158            self.instances
159                .drain(..)
160                .map(|instance| self.client.delete_instance(instance)),
161        )
162        .await?;
163
164        display::done();
165        Ok(())
166    }
167
168    /// Start the specified number of instances in each region. Returns an error if there are not
169    /// enough available instances.
170    pub async fn start(&mut self, quantity: usize) -> TestbedResult<()> {
171        display::action("Booting instances");
172
173        // Gather available instances.
174        let mut available = Vec::new();
175        for region in &self.settings.regions {
176            available.extend(
177                self.instances
178                    .iter()
179                    .filter(|x| {
180                        x.is_inactive() && &x.region == region && self.settings.filter_instances(x)
181                    })
182                    .take(quantity)
183                    .cloned()
184                    .collect::<Vec<_>>(),
185            );
186        }
187
188        // Start instances.
189        self.client.start_instances(available.iter()).await?;
190
191        // Wait until the instances are started.
192        if cfg!(not(test)) {
193            self.wait_until_reachable(available.iter()).await?;
194        }
195        self.instances = self.client.list_instances().await?;
196
197        display::done();
198        Ok(())
199    }
200
201    /// Stop all instances of the testbed.
202    pub async fn stop(&mut self) -> TestbedResult<()> {
203        display::action("Stopping instances");
204
205        // Stop all instances.
206        self.client
207            .stop_instances(self.instances.iter().filter(|i| i.is_active()))
208            .await?;
209
210        // Wait until the instances are stopped.
211        loop {
212            let instances = self.client.list_instances().await?;
213            if instances.iter().all(|x| x.is_inactive()) {
214                self.instances = instances;
215                break;
216            }
217        }
218
219        display::done();
220        Ok(())
221    }
222
223    /// Wait until all specified instances are ready to accept ssh connections.
224    async fn wait_until_reachable<'a, I>(&self, instances: I) -> TestbedResult<()>
225    where
226        I: Iterator<Item = &'a Instance> + Clone,
227    {
228        let instances_ids: Vec<_> = instances.map(|x| x.id.clone()).collect();
229
230        let mut interval = time::interval(Duration::from_secs(5));
231        interval.tick().await; // The first tick returns immediately.
232
233        let start = Instant::now();
234        loop {
235            let now = interval.tick().await;
236            let elapsed = now.duration_since(start).as_secs_f64().ceil() as u64;
237            display::status(format!("{elapsed}s"));
238
239            let instances = self.client.list_instances().await?;
240            let futures = instances
241                .iter()
242                .filter(|x| instances_ids.contains(&x.id))
243                .map(|instance| {
244                    let private_key_file = self.settings.ssh_private_key_file.clone();
245                    SshConnection::new(
246                        instance.ssh_address(),
247                        C::USERNAME,
248                        private_key_file,
249                        None,
250                        None,
251                    )
252                });
253            if try_join_all(futures).await.is_ok() {
254                break;
255            }
256        }
257        Ok(())
258    }
259}
260
#[cfg(test)]
mod test {
    use crate::{client::test_client::TestClient, settings::Settings, testbed::Testbed};

    /// Build a testbed from the test settings and a test client sharing those settings.
    async fn new_testbed() -> Testbed<TestClient> {
        let settings = Settings::new_for_test();
        let client = TestClient::new(settings.clone());
        Testbed::new(settings, client).await.unwrap()
    }

    /// Deploying five instances per region yields sequentially numbered instances.
    #[tokio::test]
    async fn deploy() {
        let mut testbed = new_testbed().await;

        testbed.deploy(5, None).await.unwrap();

        let expected = 5 * testbed.settings.number_of_regions();
        assert_eq!(testbed.instances.len(), expected);
        for (index, instance) in testbed.instances.iter().enumerate() {
            assert_eq!(index.to_string(), instance.id);
        }
    }

    /// Destroying the testbed leaves no tracked instances.
    #[tokio::test]
    async fn destroy() {
        let mut testbed = new_testbed().await;

        testbed.destroy().await.unwrap();

        assert_eq!(testbed.instances.len(), 0);
    }

    /// Starting two instances per region leaves the remaining three inactive.
    #[tokio::test]
    async fn start() {
        let mut testbed = new_testbed().await;
        testbed.deploy(5, None).await.unwrap();
        testbed.stop().await.unwrap();

        let result = testbed.start(2).await;

        assert!(result.is_ok());
        for region in &testbed.settings.regions {
            let (mut active, mut inactive) = (0, 0);
            for instance in testbed.instances.iter().filter(|x| &x.region == region) {
                if instance.is_active() {
                    active += 1;
                }
                if instance.is_inactive() {
                    inactive += 1;
                }
            }
            assert_eq!(active, 2);
            assert_eq!(inactive, 3);
        }
    }

    /// Stopping the testbed leaves every instance inactive.
    #[tokio::test]
    async fn stop() {
        let mut testbed = new_testbed().await;
        testbed.deploy(5, None).await.unwrap();
        testbed.start(2).await.unwrap();

        testbed.stop().await.unwrap();

        let all_inactive = testbed.instances.iter().all(|x| x.is_inactive());
        assert!(all_inactive);
    }
}