sui_aws_orchestrator/client/
mod.rs1use std::{
5    fmt::Display,
6    net::{Ipv4Addr, SocketAddr},
7};
8
9use serde::{Deserialize, Serialize};
10
11use super::error::CloudProviderResult;
12
13pub mod aws;
14
15#[derive(Debug, Deserialize, Clone, Eq, PartialEq, Hash)]
17pub struct Instance {
18    pub id: String,
20    pub region: String,
22    pub main_ip: Ipv4Addr,
24    pub tags: Vec<String>,
26    pub specs: String,
28    pub status: String,
30}
31
32impl Instance {
33    pub fn is_active(&self) -> bool {
35        self.status.to_lowercase() == "running"
36    }
37
38    pub fn is_inactive(&self) -> bool {
40        !self.is_active()
41    }
42
43    pub fn is_terminated(&self) -> bool {
45        self.status.to_lowercase() == "terminated"
46    }
47
48    pub fn ssh_address(&self) -> SocketAddr {
50        format!("{}:22", self.main_ip).parse().unwrap()
51    }
52
53    #[cfg(test)]
54    pub fn new_for_test(id: String) -> Self {
55        Self {
56            id,
57            region: Default::default(),
58            main_ip: Ipv4Addr::new(127, 0, 0, 1),
59            tags: Default::default(),
60            specs: Default::default(),
61            status: Default::default(),
62        }
63    }
64}
65
66#[async_trait::async_trait]
67pub trait ServerProviderClient: Display {
68    const USERNAME: &'static str;
70
71    async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>>;
73
74    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
76    where
77        I: Iterator<Item = &'a Instance> + Send;
78
79    async fn stop_instances<'a, I>(&self, instance_ids: I) -> CloudProviderResult<()>
81    where
82        I: Iterator<Item = &'a Instance> + Send;
83
84    async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
86    where
87        S: Into<String> + Serialize + Send;
88
89    async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()>;
92
93    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()>;
95
96    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>>;
98}
99
#[cfg(test)]
pub mod test_client {
    //! An in-memory [`ServerProviderClient`] for unit tests: it makes no network
    //! calls, and every operation only edits a local list of instances.

    use std::{
        collections::HashSet,
        fmt::Display,
        net::Ipv4Addr,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Mutex,
        },
    };

    use serde::Serialize;

    use crate::{error::CloudProviderResult, settings::Settings};

    use super::{Instance, ServerProviderClient};

    /// Fake cloud provider backed by an in-memory instance list.
    pub struct TestClient {
        /// Testbed settings; `specs` is copied onto every created instance.
        settings: Settings,
        /// Instances currently held by this client.
        instances: Mutex<Vec<Instance>>,
        /// Id handed to the next created instance. A monotonic counter (rather
        /// than `instances.len()`) guarantees that deleting an instance can never
        /// make a later one reuse the id of an instance that still exists.
        next_id: AtomicUsize,
    }

    impl TestClient {
        /// Create a client that holds no instances.
        pub fn new(settings: Settings) -> Self {
            Self {
                settings,
                instances: Mutex::new(Vec::new()),
                next_id: AtomicUsize::new(0),
            }
        }

        /// Set `status` on every held instance whose id appears in `instances`.
        fn set_status<'a, I>(&self, instances: I, status: &str)
        where
            I: Iterator<Item = &'a Instance>,
        {
            let ids: HashSet<&str> = instances.map(|x| x.id.as_str()).collect();
            let mut guard = self.instances.lock().unwrap();
            for instance in guard.iter_mut().filter(|x| ids.contains(x.id.as_str())) {
                instance.status = status.into();
            }
        }
    }

    impl Display for TestClient {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestClient")
        }
    }

    #[async_trait::async_trait]
    impl ServerProviderClient for TestClient {
        const USERNAME: &'static str = "root";

        /// Return a snapshot (clone) of all held instances.
        async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
            let guard = self.instances.lock().unwrap();
            Ok(guard.clone())
        }

        /// Mark the matching held instances as `"running"`; unknown ids are ignored.
        async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "running");
            Ok(())
        }

        /// Mark the matching held instances as `"stopped"`; unknown ids are ignored.
        async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
        where
            I: Iterator<Item = &'a Instance> + Send,
        {
            self.set_status(instances, "stopped");
            Ok(())
        }

        /// Create a `"running"` instance in `region` with a fresh, never-reused id.
        async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
        where
            S: Into<String> + Serialize + Send,
        {
            let id = self.next_id.fetch_add(1, Ordering::SeqCst);
            // Spread the id over all four octets: ids 0..=255 still map to
            // `0.0.0.{id}`, but larger ids no longer produce an unparsable address.
            let main_ip = Ipv4Addr::from(
                u32::try_from(id).expect("TestClient created more than u32::MAX instances"),
            );
            let instance = Instance {
                id: id.to_string(),
                region: region.into(),
                main_ip,
                tags: Vec::new(),
                specs: self.settings.specs.clone(),
                status: "running".into(),
            };
            self.instances.lock().unwrap().push(instance.clone());
            Ok(instance)
        }

        /// Remove every held instance whose id matches `instance.id`.
        async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
            let mut guard = self.instances.lock().unwrap();
            guard.retain(|x| x.id != instance.id);
            Ok(())
        }

        /// No-op: the fake provider has nowhere to register keys.
        async fn register_ssh_public_key(&self, _public_key: String) -> CloudProviderResult<()> {
            Ok(())
        }

        /// The fake provider needs no setup commands.
        async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
            Ok(Vec::new())
        }
    }
}