sui_aws_orchestrator/client/
aws.rs1use std::{
5 collections::HashMap,
6 fmt::{Debug, Display},
7};
8
9use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
10use aws_sdk_ec2::primitives::Blob;
11use aws_sdk_ec2::{
12 config::Region,
13 types::{
14 BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
15 TagSpecification, VolumeType,
16 },
17};
18use aws_smithy_runtime_api::client::result::SdkError;
19use serde::Serialize;
20
21use crate::{
22 error::{CloudProviderError, CloudProviderResult},
23 settings::Settings,
24};
25
26use super::{Instance, ServerProviderClient};
27
28impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30 for CloudProviderError
31where
32 T: Debug + std::error::Error + Send + Sync + 'static,
33{
34 fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35 Self::RequestError(format!("{:?}", e.into_source()))
36 }
37}
38
/// A client that drives AWS EC2 through one SDK client per configured region.
pub struct AwsClient {
    /// The orchestrator settings (testbed id, regions, instance specs, working directory, ...).
    settings: Settings,
    /// The EC2 SDK clients, keyed by the name of the region they talk to.
    clients: HashMap<String, aws_sdk_ec2::Client>,
}
46
47impl Display for AwsClient {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50 }
51}
52
53impl AwsClient {
54 const OS_IMAGE: &'static str =
55 "Canonical, Ubuntu, 22.04 LTS, amd64 jammy image build on 2023-02-16";
56
57 pub async fn new(settings: Settings) -> Self {
59 let profile_files = EnvConfigFiles::builder()
60 .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
61 .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
62 .build();
63
64 let mut clients = HashMap::new();
65 for region in settings.regions.clone() {
66 let sdk_config = aws_config::from_env()
67 .region(Region::new(region.clone()))
68 .profile_files(profile_files.clone())
69 .load()
70 .await;
71 let client = aws_sdk_ec2::Client::new(&sdk_config);
72 clients.insert(region, client);
73 }
74
75 Self { settings, clients }
76 }
77
78 fn check_but_ignore_duplicates<T, E>(
80 response: Result<
81 T,
82 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
83 >,
84 ) -> CloudProviderResult<()>
85 where
86 E: Debug + std::error::Error + Send + Sync + 'static,
87 {
88 if let Err(e) = response {
89 let error_message = format!("{e:?}");
90 if !error_message.to_lowercase().contains("duplicate") {
91 return Err(e.into());
92 }
93 }
94 Ok(())
95 }
96
97 fn make_instance(
99 &self,
100 region: String,
101 aws_instance: &aws_sdk_ec2::types::Instance,
102 ) -> Instance {
103 Instance {
104 id: aws_instance
105 .instance_id()
106 .expect("AWS instance should have an id")
107 .into(),
108 region,
109 main_ip: aws_instance
110 .public_ip_address()
111 .unwrap_or("0.0.0.0") .parse()
113 .expect("AWS instance should have a valid ip"),
114 tags: vec![self.settings.testbed_id.clone()],
115 specs: format!(
116 "{:?}",
117 aws_instance
118 .instance_type()
119 .expect("AWS instance should have a type")
120 ),
121 status: format!(
122 "{:?}",
123 aws_instance
124 .state()
125 .expect("AWS instance should have a state")
126 .name()
127 .expect("AWS status should have a name")
128 ),
129 }
130 }
131
132 async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
135 let request = client.describe_images().filters(
137 Filter::builder()
138 .name("description")
139 .values(Self::OS_IMAGE)
140 .build(),
141 );
142 let response = request.send().await?;
143
144 response
146 .images()
147 .first()
148 .ok_or_else(|| CloudProviderError::RequestError("Cannot find image id".into()))?
149 .image_id
150 .clone()
151 .ok_or_else(|| {
152 CloudProviderError::UnexpectedResponse(
153 "Received image description without id".into(),
154 )
155 })
156 }
157
158 async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
160 let request = client
162 .create_security_group()
163 .group_name(&self.settings.testbed_id)
164 .description("Allow all traffic (used for benchmarks).");
165
166 let response = request.send().await;
167 Self::check_but_ignore_duplicates(response)?;
168
169 for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
171 let mut request = client
172 .authorize_security_group_ingress()
173 .group_name(&self.settings.testbed_id)
174 .ip_protocol(protocol)
175 .cidr_ip("0.0.0.0/0"); if protocol == "icmp" || protocol == "icmpv6" {
177 request = request.from_port(-1).to_port(-1);
178 } else {
179 request = request.from_port(0).to_port(65535);
180 }
181
182 let response = request.send().await;
183 Self::check_but_ignore_duplicates(response)?;
184 }
185 Ok(())
186 }
187
188 fn nvme_mount_command(&self) -> Vec<String> {
190 const DRIVE: &str = "nvme1n1";
191 let directory = self.settings.working_dir.display();
192 vec![
193 format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
194 format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
195 format!("sudo chmod 777 -R {directory}"),
196 ]
197 }
198
199 async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
201 let client = match self
204 .settings
205 .regions
206 .first()
207 .and_then(|x| self.clients.get(x))
208 {
209 Some(client) => client,
210 None => return Ok(false),
211 };
212
213 let request = client
215 .describe_instance_types()
216 .instance_types(self.settings.specs.as_str().into());
217
218 let response = request.send().await?;
220
221 if let Some(info) = response.instance_types().first()
223 && let Some(info) = info.instance_storage_info()
224 && info.nvme_support() == Some(&EphemeralNvmeSupport::Required)
225 {
226 return Ok(true);
227 }
228 Ok(false)
229 }
230}
231
232#[async_trait::async_trait]
233impl ServerProviderClient for AwsClient {
234 const USERNAME: &'static str = "ubuntu";
235
236 async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
237 let filter = Filter::builder()
238 .name("tag:Name")
239 .values(self.settings.testbed_id.clone())
240 .build();
241
242 let mut instances = Vec::new();
243 for (region, client) in &self.clients {
244 let request = client.describe_instances().filters(filter.clone());
245 for reservation in request.send().await?.reservations() {
246 for instance in reservation.instances() {
247 instances.push(self.make_instance(region.clone(), instance));
248 }
249 }
250 }
251
252 Ok(instances)
253 }
254
255 async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
256 where
257 I: Iterator<Item = &'a Instance> + Send,
258 {
259 let mut instance_ids = HashMap::new();
260 for instance in instances {
261 instance_ids
262 .entry(&instance.region)
263 .or_insert_with(Vec::new)
264 .push(instance.id.clone());
265 }
266
267 for (region, client) in &self.clients {
268 let ids = instance_ids.remove(®ion.to_string());
269 if ids.is_some() {
270 client
271 .start_instances()
272 .set_instance_ids(ids)
273 .send()
274 .await?;
275 }
276 }
277 Ok(())
278 }
279
280 async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
281 where
282 I: Iterator<Item = &'a Instance> + Send,
283 {
284 let mut instance_ids = HashMap::new();
285 for instance in instances {
286 instance_ids
287 .entry(&instance.region)
288 .or_insert_with(Vec::new)
289 .push(instance.id.clone());
290 }
291
292 for (region, client) in &self.clients {
293 let ids = instance_ids.remove(®ion.to_string());
294 if ids.is_some() {
295 client.stop_instances().set_instance_ids(ids).send().await?;
296 }
297 }
298 Ok(())
299 }
300
301 async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
302 where
303 S: Into<String> + Serialize + Send,
304 {
305 let region = region.into();
306 let testbed_id = &self.settings.testbed_id;
307
308 let client = self.clients.get(®ion).ok_or_else(|| {
309 CloudProviderError::RequestError(format!("Undefined region {region:?}"))
310 })?;
311
312 self.create_security_group(client).await?;
314
315 let image_id = self.find_image_id(client).await?;
317
318 let tags = TagSpecification::builder()
320 .resource_type(ResourceType::Instance)
321 .tags(Tag::builder().key("Name").value(testbed_id).build())
322 .build();
323
324 let storage = BlockDeviceMapping::builder()
325 .device_name("/dev/sda1")
326 .ebs(
327 EbsBlockDevice::builder()
328 .delete_on_termination(true)
329 .volume_size(500)
330 .volume_type(VolumeType::Gp2)
331 .build(),
332 )
333 .build();
334
335 let request = client
336 .run_instances()
337 .image_id(image_id)
338 .instance_type(self.settings.specs.as_str().into())
339 .key_name(testbed_id)
340 .min_count(1)
341 .max_count(1)
342 .security_groups(&self.settings.testbed_id)
343 .block_device_mappings(storage)
344 .tag_specifications(tags);
345
346 let response = request.send().await?;
347 let instance = response
348 .instances()
349 .first()
350 .expect("AWS instances list should contain instances");
351
352 Ok(self.make_instance(region, instance))
353 }
354
355 async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
356 let client = self.clients.get(&instance.region).ok_or_else(|| {
357 CloudProviderError::RequestError(format!("Undefined region {:?}", instance.region))
358 })?;
359
360 client
361 .terminate_instances()
362 .set_instance_ids(Some(vec![instance.id.clone()]))
363 .send()
364 .await?;
365
366 Ok(())
367 }
368
369 async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
370 for client in self.clients.values() {
371 let request = client
372 .import_key_pair()
373 .key_name(&self.settings.testbed_id)
374 .public_key_material(Blob::new::<String>(public_key.clone()));
375
376 let response = request.send().await;
377 Self::check_but_ignore_duplicates(response)?;
378 }
379 Ok(())
380 }
381
382 async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
383 if self.check_nvme_support().await? {
384 Ok(self.nvme_mount_command())
385 } else {
386 Ok(Vec::new())
387 }
388 }
389}