sui_aws_orchestrator/client/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    fmt::Display,
6    net::{Ipv4Addr, SocketAddr},
7};
8
9use serde::{Deserialize, Serialize};
10
11use super::error::CloudProviderResult;
12
13pub mod aws;
14
15/// Represents a cloud provider instance.
16#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
17pub struct Instance {
18    /// The unique identifier of the instance.
19    pub id: String,
20    /// The region where the instance runs.
21    pub region: String,
22    /// The public ip address of the instance (accessible from anywhere).
23    pub main_ip: Ipv4Addr,
24    /// The list of tags associated with the instance.
25    pub tags: Vec<String>,
26    /// The specs of the instance.
27    pub specs: String,
28    /// The current status of the instance.
29    pub status: String,
30}
31
32impl Instance {
33    /// Return whether the instance is active and running.
34    pub fn is_active(&self) -> bool {
35        self.status.to_lowercase() == "running"
36    }
37
38    /// Return whether the instance is inactive and not ready for use.
39    pub fn is_inactive(&self) -> bool {
40        !self.is_active()
41    }
42
43    /// Return whether the instance is terminated and in the process of being deleted.
44    pub fn is_terminated(&self) -> bool {
45        self.status.to_lowercase() == "terminated"
46    }
47
48    /// Return the ssh address to connect to the instance.
49    pub fn ssh_address(&self) -> SocketAddr {
50        format!("{}:22", self.main_ip).parse().unwrap()
51    }
52
53    #[cfg(test)]
54    pub fn new_for_test(id: String) -> Self {
55        Self {
56            id,
57            region: Default::default(),
58            main_ip: Ipv4Addr::new(127, 0, 0, 1),
59            tags: Default::default(),
60            specs: Default::default(),
61            status: Default::default(),
62        }
63    }
64}
65
66#[async_trait::async_trait]
67pub trait ServerProviderClient: Display {
68    /// The username used to connect to the instances.
69    const USERNAME: &'static str;
70
71    /// List all existing instances (regardless of their status).
72    async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>>;
73
74    /// Start the specified instances.
75    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
76    where
77        I: Iterator<Item = &'a Instance> + Send;
78
79    /// Halt/Stop the specified instances. We may still be billed for stopped instances.
80    async fn stop_instances<'a, I>(&self, instance_ids: I) -> CloudProviderResult<()>
81    where
82        I: Iterator<Item = &'a Instance> + Send;
83
84    /// Create an instance in a specific region.
85    async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
86    where
87        S: Into<String> + Serialize + Send;
88
89    /// Delete a specific instance. Calling this function ensures we are no longer billed for
90    /// the specified instance.
91    async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()>;
92
93    /// Authorize the provided ssh public key to access machines.
94    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()>;
95
96    /// Return provider-specific commands to setup the instance.
97    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>>;
98}
99
#[cfg(test)]
pub mod test_client {
    use std::{collections::HashSet, fmt::Display, net::Ipv4Addr, sync::Mutex};

    use serde::Serialize;

    use crate::{error::CloudProviderResult, settings::Settings};

    use super::{Instance, ServerProviderClient};

    /// In-memory [`ServerProviderClient`] for tests: instances only exist in a local list
    /// and no call reaches an actual cloud provider.
    pub struct TestClient {
        /// Settings from which the specs of newly created instances are copied.
        settings: Settings,
        /// The instances currently known to this client.
        instances: Mutex<Vec<Instance>>,
    }

    impl TestClient {
        /// Create a test client that holds no instances.
        pub fn new(settings: Settings) -> Self {
            Self {
                settings,
                instances: Mutex::new(Vec::new()),
            }
        }

        /// Set `status` on every stored instance whose id matches one of `targets`.
        /// Targets unknown to this client are ignored.
        fn set_status<'a, I>(&self, targets: I, status: &str)
        where
            I: Iterator<Item = &'a Instance>,
        {
            // Borrow the ids into a set: no per-id clone and constant-time membership tests.
            let target_ids: HashSet<&str> = targets.map(|x| x.id.as_str()).collect();
            let mut guard = self.instances.lock().unwrap();
            for instance in guard
                .iter_mut()
                .filter(|x| target_ids.contains(x.id.as_str()))
            {
                instance.status = status.into();
            }
        }
    }

    impl Display for TestClient {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestClient")
        }
    }

    #[async_trait::async_trait]
    impl ServerProviderClient for TestClient {
        const USERNAME: &'static str = "root";

        /// Return a copy of every instance currently stored, whatever its status.
        async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            Ok(guard.clone())
        }

        /// Mark the specified instances as `"running"`.
        async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "running");
            Ok(())
        }

        /// Mark the specified instances as `"stopped"`.
        async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "stopped");
            Ok(())
        }

        /// Create a running instance in `region`, store it, and return a copy of it.
        ///
        /// The instance id is a decimal number and its ip address is that same number
        /// interpreted as an IPv4 address (id `5` gets `0.0.0.5`).
        async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
        where
            S: Into<String> + Serialize + Send,
        {
            let mut guard = self.instances.lock().unwrap();

            // Allocate one past the largest id still in use. Without deletions this equals
            // the number of instances; unlike the plain count, it never duplicates the id
            // of a stored instance after an earlier one has been deleted.
            let id = guard
                .iter()
                .filter_map(|x| x.id.parse::<usize>().ok())
                .max()
                .map_or(0, |largest| largest + 1);

            // Derive the address from the numeric id so that ids above 255 remain valid.
            let main_ip = Ipv4Addr::from(
                u32::try_from(id).expect("test instance id does not fit in an IPv4 address"),
            );

            let instance = Instance {
                id: id.to_string(),
                region: region.into(),
                main_ip,
                tags: Vec::new(),
                specs: self.settings.specs.clone(),
                status: "running".into(),
            };
            guard.push(instance.clone());
            Ok(instance)
        }

        /// Remove every stored instance carrying the id of `instance`.
        async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
            let mut guard = self.instances.lock().unwrap();
            guard.retain(|x| x.id != instance.id);
            Ok(())
        }

        /// No-op: the test client does not manage ssh keys.
        async fn register_ssh_public_key(&self, _public_key: String) -> CloudProviderResult<()> {
            Ok(())
        }

        /// The test client requires no setup commands.
        async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
            Ok(Vec::new())
        }
    }
}