sui_aws_orchestrator/
settings.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    env,
6    fmt::Display,
7    fs::{self},
8    path::{Path, PathBuf},
9};
10
11use reqwest::Url;
12use serde::{Deserialize, Deserializer, de::Error};
13
14use crate::{
15    client::Instance,
16    error::{SettingsError, SettingsResult},
17};
18
19/// The git repository holding the codebase.
20#[derive(Deserialize, Clone)]
21pub struct Repository {
22    /// The url of the repository.
23    #[serde(deserialize_with = "parse_url")]
24    pub url: Url,
25    /// The commit (or branch name) to deploy.
26    pub commit: String,
27}
28
29fn parse_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
30where
31    D: Deserializer<'de>,
32{
33    let s: &str = Deserialize::deserialize(deserializer)?;
34    let url = Url::parse(s).map_err(D::Error::custom)?;
35
36    match url.path_segments().map(|x| x.count() >= 2) {
37        None | Some(false) => Err(D::Error::custom(SettingsError::MalformedRepositoryUrl(url))),
38        _ => Ok(url),
39    }
40}
41
42/// The list of supported cloud providers.
43#[derive(Deserialize, Clone)]
44pub enum CloudProvider {
45    #[serde(alias = "aws")]
46    Aws,
47}
48
49/// The testbed settings. Those are topically specified in a file.
50#[derive(Deserialize, Clone)]
51pub struct Settings {
52    /// The testbed unique id. This allows multiple users to run concurrent testbeds on the
53    /// same cloud provider's account without interference with each others.
54    pub testbed_id: String,
55    /// The cloud provider hosting the testbed.
56    pub cloud_provider: CloudProvider,
57    /// The path to the secret token for authentication with the cloud provider.
58    pub token_file: PathBuf,
59    /// The ssh private key to access the instances.
60    pub ssh_private_key_file: PathBuf,
61    /// The corresponding ssh public key registered on the instances. If not specified. the
62    /// public key defaults the same path as the private key with an added extension 'pub'.
63    pub ssh_public_key_file: Option<PathBuf>,
64    /// The list of cloud provider regions to deploy the testbed.
65    pub regions: Vec<String>,
66    /// The specs of the instances to deploy. Those are dependent on the cloud provider, e.g.,
67    /// specifying 't3.medium' creates instances with 2 vCPU and 4GBo of ram on AWS.
68    pub specs: String,
69    /// The details of the git reposit to deploy.
70    pub repository: Repository,
71    /// The working directory on the remote instance (containing all configuration files).
72    #[serde(default = "default_working_dir")]
73    pub working_dir: PathBuf,
74    /// The directory (on the local machine) where to save benchmarks measurements.
75    #[serde(default = "default_results_dir")]
76    pub results_dir: PathBuf,
77    /// The directory (on the local machine) where to download logs files from the instances.
78    #[serde(default = "default_logs_dir")]
79    pub logs_dir: PathBuf,
80}
81
/// Default remote working directory, `~/working_dir`.
///
/// The `~` is kept literally in the path; nothing here expands it.
fn default_working_dir() -> PathBuf {
    PathBuf::from("~/").join("working_dir")
}
85
/// Default local directory for benchmark measurements, `./results`.
fn default_results_dir() -> PathBuf {
    PathBuf::from("./").join("results")
}
89
/// Default local directory for downloaded instance logs, `./logs`.
fn default_logs_dir() -> PathBuf {
    PathBuf::from("./").join("logs")
}
93
94impl Settings {
95    /// Load the settings from a json file.
96    pub fn load<P>(path: P) -> SettingsResult<Self>
97    where
98        P: AsRef<Path> + Display + Clone,
99    {
100        let reader = || -> Result<Self, std::io::Error> {
101            let data = fs::read(path.clone())?;
102            let data = resolve_env(std::str::from_utf8(&data).unwrap());
103            let settings: Settings = serde_json::from_slice(data.as_bytes())?;
104
105            fs::create_dir_all(&settings.results_dir)?;
106            fs::create_dir_all(&settings.logs_dir)?;
107
108            Ok(settings)
109        };
110
111        reader().map_err(|e| SettingsError::InvalidSettings {
112            file: path.to_string(),
113            message: e.to_string(),
114        })
115    }
116
117    /// Get the name of the repository (from its url).
118    pub fn repository_name(&self) -> String {
119        self.repository
120            .url
121            .path_segments()
122            .expect("Url should already be checked when loading settings")
123            .collect::<Vec<_>>()[1]
124            .split('.')
125            .next()
126            .unwrap()
127            .to_string()
128    }
129
130    /// Load the secret token to authenticate with the cloud provider.
131    pub fn load_token(&self) -> SettingsResult<String> {
132        match fs::read_to_string(&self.token_file) {
133            Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
134            Err(e) => Err(SettingsError::InvalidTokenFile {
135                file: self.token_file.display().to_string(),
136                message: e.to_string(),
137            }),
138        }
139    }
140
141    /// Load the ssh public key from file.
142    pub fn load_ssh_public_key(&self) -> SettingsResult<String> {
143        let ssh_public_key_file = self.ssh_public_key_file.clone().unwrap_or_else(|| {
144            let mut private = self.ssh_private_key_file.clone();
145            private.set_extension("pub");
146            private
147        });
148        match fs::read_to_string(&ssh_public_key_file) {
149            Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
150            Err(e) => Err(SettingsError::InvalidSshPublicKeyFile {
151                file: ssh_public_key_file.display().to_string(),
152                message: e.to_string(),
153            }),
154        }
155    }
156
157    /// Check whether the input instance matches the criteria described in the settings.
158    pub fn filter_instances(&self, instance: &Instance) -> bool {
159        self.regions.contains(&instance.region)
160            && instance.specs.to_lowercase().replace('.', "")
161                == self.specs.to_lowercase().replace('.', "")
162    }
163
164    /// The number of regions specified in the settings.
165    #[cfg(test)]
166    pub fn number_of_regions(&self) -> usize {
167        self.regions.len()
168    }
169
170    /// Test settings for unit tests.
171    #[cfg(test)]
172    pub fn new_for_test() -> Self {
173        // Create a temporary public key file.
174        let mut path = tempfile::tempdir().unwrap().keep();
175        path.push("test_public_key.pub");
176        let public_key = "This is a fake public key for tests";
177        fs::write(&path, public_key).unwrap();
178
179        // Return set settings.
180        Self {
181            testbed_id: "testbed".into(),
182            cloud_provider: CloudProvider::Aws,
183            token_file: "/path/to/token/file".into(),
184            ssh_private_key_file: "/path/to/private/key/file".into(),
185            ssh_public_key_file: Some(path),
186            regions: vec!["London".into(), "New York".into()],
187            specs: "small".into(),
188            repository: Repository {
189                url: Url::parse("https://example.net/author/repo").unwrap(),
190                commit: "main".into(),
191            },
192            working_dir: "/path/to/working_dir".into(),
193            results_dir: "results".into(),
194            logs_dir: "logs".into(),
195        }
196    }
197}
198
/// Resolves every `${NAME}` placeholder in `s` into the value of the environment variable `NAME`.
///
/// Only the variables actually referenced by `s` are looked up. Unrelated environment entries,
/// including ones that are not valid unicode, therefore cannot make this function fail.
/// Substituted values are inserted verbatim and are not scanned again for placeholders.
///
/// # Panics
///
/// Panics if a placeholder references a variable that is unset or not valid unicode, or if a
/// `${` is never closed by a `}`. Only the offending placeholders are reported. The resolved
/// settings are not printed because they may contain secret values taken from the environment.
fn resolve_env(s: &str) -> String {
    let mut resolved = String::with_capacity(s.len());
    let mut unresolved: Vec<String> = Vec::new();
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        resolved.push_str(&rest[..start]);
        // Skip the two ASCII bytes of `${`; this is always a valid char boundary.
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(len) => {
                let name = &after_open[..len];
                match env::var(name) {
                    Ok(value) => resolved.push_str(&value),
                    Err(_) => unresolved.push(format!("${{{}}}", name)),
                }
                rest = &after_open[len + 1..];
            }
            None => {
                unresolved.push(String::from("`${` without a closing `}`"));
                break;
            }
        }
    }
    resolved.push_str(rest);

    if !unresolved.is_empty() {
        panic!(
            "Unresolved env variables in the settings.json: {}",
            unresolved.join(", ")
        );
    }
    resolved
}
211
#[cfg(test)]
mod test {
    use reqwest::Url;

    use crate::settings::Settings;

    #[test]
    fn repository_name() {
        let mut settings = Settings::new_for_test();
        settings.repository.url = Url::parse("https://example.com/author/name").unwrap();
        assert_eq!(settings.repository_name(), "name");
    }

    /// The `.git` suffix commonly found on clone urls is not part of the repository name.
    #[test]
    fn repository_name_strips_git_suffix() {
        let mut settings = Settings::new_for_test();
        settings.repository.url = Url::parse("https://example.com/author/name.git").unwrap();
        assert_eq!(settings.repository_name(), "name");
    }
}