sui_aws_orchestrator/client/
aws.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    collections::HashMap,
6    fmt::{Debug, Display},
7};
8
9use aws_runtime::env_config::file::{EnvConfigFileKind, EnvConfigFiles};
10use aws_sdk_ec2::primitives::Blob;
11use aws_sdk_ec2::{
12    config::Region,
13    types::{
14        BlockDeviceMapping, EbsBlockDevice, EphemeralNvmeSupport, Filter, ResourceType, Tag,
15        TagSpecification, VolumeType,
16    },
17};
18use aws_smithy_runtime_api::client::result::SdkError;
19use serde::Serialize;
20
21use crate::{
22    error::{CloudProviderError, CloudProviderResult},
23    settings::Settings,
24};
25
26use super::{Instance, ServerProviderClient};
27
28// Make a request error from an AWS error message.
29impl<T> From<SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>
30    for CloudProviderError
31where
32    T: Debug + std::error::Error + Send + Sync + 'static,
33{
34    fn from(e: SdkError<T, aws_smithy_runtime_api::client::orchestrator::HttpResponse>) -> Self {
35        Self::RequestError(format!("{:?}", e.into_source()))
36    }
37}
38
/// An AWS client.
pub struct AwsClient {
    /// The settings of the testbed.
    settings: Settings,
    /// The EC2 clients, one per AWS region, keyed by region name.
    clients: HashMap<String, aws_sdk_ec2::Client>,
}
46
47impl Display for AwsClient {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "AWS EC2 client v{}", aws_sdk_ec2::meta::PKG_VERSION)
50    }
51}
52
53impl AwsClient {
    /// Description of the OS image to run on the instances. `find_image_id` passes this exact
    /// string as the value of the `description` filter when querying the available images.
    const OS_IMAGE: &'static str =
        "Canonical, Ubuntu, 22.04 LTS, amd64 jammy image build on 2023-02-16";
56
57    /// Make a new AWS client.
58    pub async fn new(settings: Settings) -> Self {
59        let profile_files = EnvConfigFiles::builder()
60            .with_file(EnvConfigFileKind::Credentials, &settings.token_file)
61            .with_contents(EnvConfigFileKind::Config, "[default]\noutput=json")
62            .build();
63
64        let mut clients = HashMap::new();
65        for region in settings.regions.clone() {
66            let sdk_config = aws_config::from_env()
67                .region(Region::new(region.clone()))
68                .profile_files(profile_files.clone())
69                .load()
70                .await;
71            let client = aws_sdk_ec2::Client::new(&sdk_config);
72            clients.insert(region, client);
73        }
74
75        Self { settings, clients }
76    }
77
78    /// Parse an AWS response and ignore errors if they mean a request is a duplicate.
79    fn check_but_ignore_duplicates<T, E>(
80        response: Result<
81            T,
82            SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
83        >,
84    ) -> CloudProviderResult<()>
85    where
86        E: Debug + std::error::Error + Send + Sync + 'static,
87    {
88        if let Err(e) = response {
89            let error_message = format!("{e:?}");
90            if !error_message.to_lowercase().contains("duplicate") {
91                return Err(e.into());
92            }
93        }
94        Ok(())
95    }
96
97    /// Convert an AWS instance into an orchestrator instance (used in the rest of the codebase).
98    fn make_instance(
99        &self,
100        region: String,
101        aws_instance: &aws_sdk_ec2::types::Instance,
102    ) -> Instance {
103        Instance {
104            id: aws_instance
105                .instance_id()
106                .expect("AWS instance should have an id")
107                .into(),
108            region,
109            main_ip: aws_instance
110                .public_ip_address()
111                .unwrap_or("0.0.0.0") // Stopped instances do not have an ip address.
112                .parse()
113                .expect("AWS instance should have a valid ip"),
114            tags: vec![self.settings.testbed_id.clone()],
115            specs: format!(
116                "{:?}",
117                aws_instance
118                    .instance_type()
119                    .expect("AWS instance should have a type")
120            ),
121            status: format!(
122                "{:?}",
123                aws_instance
124                    .state()
125                    .expect("AWS instance should have a state")
126                    .name()
127                    .expect("AWS status should have a name")
128            ),
129        }
130    }
131
132    /// Query the image id determining the os of the instances.
133    /// NOTE: The image id changes depending on the region.
134    async fn find_image_id(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<String> {
135        // Query all images that match the description.
136        let request = client.describe_images().filters(
137            Filter::builder()
138                .name("description")
139                .values(Self::OS_IMAGE)
140                .build(),
141        );
142        let response = request.send().await?;
143
144        // Parse the response to select the first returned image id.
145        response
146            .images()
147            .first()
148            .ok_or_else(|| CloudProviderError::RequestError("Cannot find image id".into()))?
149            .image_id
150            .clone()
151            .ok_or_else(|| {
152                CloudProviderError::UnexpectedResponse(
153                    "Received image description without id".into(),
154                )
155            })
156    }
157
158    /// Create a new security group for the instance (if it doesn't already exist).
159    async fn create_security_group(&self, client: &aws_sdk_ec2::Client) -> CloudProviderResult<()> {
160        // Create a security group (if it doesn't already exist).
161        let request = client
162            .create_security_group()
163            .group_name(&self.settings.testbed_id)
164            .description("Allow all traffic (used for benchmarks).");
165
166        let response = request.send().await;
167        Self::check_but_ignore_duplicates(response)?;
168
169        // Authorize all traffic on the security group.
170        for protocol in ["tcp", "udp", "icmp", "icmpv6"] {
171            let mut request = client
172                .authorize_security_group_ingress()
173                .group_name(&self.settings.testbed_id)
174                .ip_protocol(protocol)
175                .cidr_ip("0.0.0.0/0"); // todo - allowing 0.0.0.0 seem a bit wild?
176            if protocol == "icmp" || protocol == "icmpv6" {
177                request = request.from_port(-1).to_port(-1);
178            } else {
179                request = request.from_port(0).to_port(65535);
180            }
181
182            let response = request.send().await;
183            Self::check_but_ignore_duplicates(response)?;
184        }
185        Ok(())
186    }
187
188    /// Return the command to mount the first (standard) NVMe drive.
189    fn nvme_mount_command(&self) -> Vec<String> {
190        const DRIVE: &str = "nvme1n1";
191        let directory = self.settings.working_dir.display();
192        vec![
193            format!("(sudo mkfs.ext4 -E nodiscard /dev/{DRIVE} || true)"),
194            format!("(sudo mount /dev/{DRIVE} {directory} || true)"),
195            format!("sudo chmod 777 -R {directory}"),
196        ]
197    }
198
199    /// Check whether the instance type specified in the settings supports NVMe drives.
200    async fn check_nvme_support(&self) -> CloudProviderResult<bool> {
201        // Get the client for the first region. A given instance type should either have NVMe support
202        // in all regions or in none.
203        let client = match self
204            .settings
205            .regions
206            .first()
207            .and_then(|x| self.clients.get(x))
208        {
209            Some(client) => client,
210            None => return Ok(false),
211        };
212
213        // Request storage details for the instance type specified in the settings.
214        let request = client
215            .describe_instance_types()
216            .instance_types(self.settings.specs.as_str().into());
217
218        // Send the request.
219        let response = request.send().await?;
220
221        // Return true if the response contains references to NVMe drives.
222        if let Some(info) = response.instance_types().first()
223            && let Some(info) = info.instance_storage_info()
224            && info.nvme_support() == Some(&EphemeralNvmeSupport::Required)
225        {
226            return Ok(true);
227        }
228        Ok(false)
229    }
230}
231
232#[async_trait::async_trait]
233impl ServerProviderClient for AwsClient {
    // NOTE(review): presumably the user name used to log into the instances (the OS image is
    // an Ubuntu image) — confirm against the `ServerProviderClient` trait definition.
    const USERNAME: &'static str = "ubuntu";
235
236    async fn list_instances(&self) -> CloudProviderResult<Vec<Instance>> {
237        let filter = Filter::builder()
238            .name("tag:Name")
239            .values(self.settings.testbed_id.clone())
240            .build();
241
242        let mut instances = Vec::new();
243        for (region, client) in &self.clients {
244            let request = client.describe_instances().filters(filter.clone());
245            for reservation in request.send().await?.reservations() {
246                for instance in reservation.instances() {
247                    instances.push(self.make_instance(region.clone(), instance));
248                }
249            }
250        }
251
252        Ok(instances)
253    }
254
255    async fn start_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
256    where
257        I: Iterator<Item = &'a Instance> + Send,
258    {
259        let mut instance_ids = HashMap::new();
260        for instance in instances {
261            instance_ids
262                .entry(&instance.region)
263                .or_insert_with(Vec::new)
264                .push(instance.id.clone());
265        }
266
267        for (region, client) in &self.clients {
268            let ids = instance_ids.remove(&region.to_string());
269            if ids.is_some() {
270                client
271                    .start_instances()
272                    .set_instance_ids(ids)
273                    .send()
274                    .await?;
275            }
276        }
277        Ok(())
278    }
279
280    async fn stop_instances<'a, I>(&self, instances: I) -> CloudProviderResult<()>
281    where
282        I: Iterator<Item = &'a Instance> + Send,
283    {
284        let mut instance_ids = HashMap::new();
285        for instance in instances {
286            instance_ids
287                .entry(&instance.region)
288                .or_insert_with(Vec::new)
289                .push(instance.id.clone());
290        }
291
292        for (region, client) in &self.clients {
293            let ids = instance_ids.remove(&region.to_string());
294            if ids.is_some() {
295                client.stop_instances().set_instance_ids(ids).send().await?;
296            }
297        }
298        Ok(())
299    }
300
301    async fn create_instance<S>(&self, region: S) -> CloudProviderResult<Instance>
302    where
303        S: Into<String> + Serialize + Send,
304    {
305        let region = region.into();
306        let testbed_id = &self.settings.testbed_id;
307
308        let client = self.clients.get(&region).ok_or_else(|| {
309            CloudProviderError::RequestError(format!("Undefined region {region:?}"))
310        })?;
311
312        // Create a security group (if needed).
313        self.create_security_group(client).await?;
314
315        // Query the image id.
316        let image_id = self.find_image_id(client).await?;
317
318        // Create a new instance.
319        let tags = TagSpecification::builder()
320            .resource_type(ResourceType::Instance)
321            .tags(Tag::builder().key("Name").value(testbed_id).build())
322            .build();
323
324        let storage = BlockDeviceMapping::builder()
325            .device_name("/dev/sda1")
326            .ebs(
327                EbsBlockDevice::builder()
328                    .delete_on_termination(true)
329                    .volume_size(500)
330                    .volume_type(VolumeType::Gp2)
331                    .build(),
332            )
333            .build();
334
335        let request = client
336            .run_instances()
337            .image_id(image_id)
338            .instance_type(self.settings.specs.as_str().into())
339            .key_name(testbed_id)
340            .min_count(1)
341            .max_count(1)
342            .security_groups(&self.settings.testbed_id)
343            .block_device_mappings(storage)
344            .tag_specifications(tags);
345
346        let response = request.send().await?;
347        let instance = response
348            .instances()
349            .first()
350            .expect("AWS instances list should contain instances");
351
352        Ok(self.make_instance(region, instance))
353    }
354
355    async fn delete_instance(&self, instance: Instance) -> CloudProviderResult<()> {
356        let client = self.clients.get(&instance.region).ok_or_else(|| {
357            CloudProviderError::RequestError(format!("Undefined region {:?}", instance.region))
358        })?;
359
360        client
361            .terminate_instances()
362            .set_instance_ids(Some(vec![instance.id.clone()]))
363            .send()
364            .await?;
365
366        Ok(())
367    }
368
369    async fn register_ssh_public_key(&self, public_key: String) -> CloudProviderResult<()> {
370        for client in self.clients.values() {
371            let request = client
372                .import_key_pair()
373                .key_name(&self.settings.testbed_id)
374                .public_key_material(Blob::new::<String>(public_key.clone()));
375
376            let response = request.send().await;
377            Self::check_but_ignore_duplicates(response)?;
378        }
379        Ok(())
380    }
381
382    async fn instance_setup_commands(&self) -> CloudProviderResult<Vec<String>> {
383        if self.check_nvme_support().await? {
384            Ok(self.nvme_mount_command())
385        } else {
386            Ok(Vec::new())
387        }
388    }
389}