sui_aws_orchestrator/
settings.rs1use std::{
5 env,
6 fmt::Display,
7 fs::{self},
8 path::{Path, PathBuf},
9};
10
11use reqwest::Url;
12use serde::{Deserialize, Deserializer, de::Error};
13
14use crate::{
15 client::Instance,
16 error::{SettingsError, SettingsResult},
17};
18
19#[derive(Deserialize, Clone)]
21pub struct Repository {
22 #[serde(deserialize_with = "parse_url")]
24 pub url: Url,
25 pub commit: String,
27}
28
29fn parse_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
30where
31 D: Deserializer<'de>,
32{
33 let s: &str = Deserialize::deserialize(deserializer)?;
34 let url = Url::parse(s).map_err(D::Error::custom)?;
35
36 match url.path_segments().map(|x| x.count() >= 2) {
37 None | Some(false) => Err(D::Error::custom(SettingsError::MalformedRepositoryUrl(url))),
38 _ => Ok(url),
39 }
40}
41
42#[derive(Deserialize, Clone)]
44pub enum CloudProvider {
45 #[serde(alias = "aws")]
46 Aws,
47}
48
49#[derive(Deserialize, Clone)]
51pub struct Settings {
52 pub testbed_id: String,
55 pub cloud_provider: CloudProvider,
57 pub token_file: PathBuf,
59 pub ssh_private_key_file: PathBuf,
61 pub ssh_public_key_file: Option<PathBuf>,
64 pub regions: Vec<String>,
66 pub specs: String,
69 pub repository: Repository,
71 #[serde(default = "default_working_dir")]
73 pub working_dir: PathBuf,
74 #[serde(default = "default_results_dir")]
76 pub results_dir: PathBuf,
77 #[serde(default = "default_logs_dir")]
79 pub logs_dir: PathBuf,
80}
81
/// Serde default for `Settings::working_dir`: the literal path `~/working_dir`
/// (the `~` is not expanded here).
fn default_working_dir() -> PathBuf {
    Path::new("~/").join("working_dir")
}
85
/// Serde default for `Settings::results_dir`: `./results`, relative to the
/// current working directory.
fn default_results_dir() -> PathBuf {
    let base = Path::new("./");
    base.join("results")
}
89
/// Serde default for `Settings::logs_dir`: `./logs`, relative to the current
/// working directory.
fn default_logs_dir() -> PathBuf {
    let base = Path::new("./");
    base.join("logs")
}
93
94impl Settings {
95 pub fn load<P>(path: P) -> SettingsResult<Self>
97 where
98 P: AsRef<Path> + Display + Clone,
99 {
100 let reader = || -> Result<Self, std::io::Error> {
101 let data = fs::read(path.clone())?;
102 let data = resolve_env(std::str::from_utf8(&data).unwrap());
103 let settings: Settings = serde_json::from_slice(data.as_bytes())?;
104
105 fs::create_dir_all(&settings.results_dir)?;
106 fs::create_dir_all(&settings.logs_dir)?;
107
108 Ok(settings)
109 };
110
111 reader().map_err(|e| SettingsError::InvalidSettings {
112 file: path.to_string(),
113 message: e.to_string(),
114 })
115 }
116
117 pub fn repository_name(&self) -> String {
119 self.repository
120 .url
121 .path_segments()
122 .expect("Url should already be checked when loading settings")
123 .collect::<Vec<_>>()[1]
124 .split('.')
125 .next()
126 .unwrap()
127 .to_string()
128 }
129
130 pub fn load_token(&self) -> SettingsResult<String> {
132 match fs::read_to_string(&self.token_file) {
133 Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
134 Err(e) => Err(SettingsError::InvalidTokenFile {
135 file: self.token_file.display().to_string(),
136 message: e.to_string(),
137 }),
138 }
139 }
140
141 pub fn load_ssh_public_key(&self) -> SettingsResult<String> {
143 let ssh_public_key_file = self.ssh_public_key_file.clone().unwrap_or_else(|| {
144 let mut private = self.ssh_private_key_file.clone();
145 private.set_extension("pub");
146 private
147 });
148 match fs::read_to_string(&ssh_public_key_file) {
149 Ok(token) => Ok(token.trim_end_matches('\n').to_string()),
150 Err(e) => Err(SettingsError::InvalidSshPublicKeyFile {
151 file: ssh_public_key_file.display().to_string(),
152 message: e.to_string(),
153 }),
154 }
155 }
156
157 pub fn filter_instances(&self, instance: &Instance) -> bool {
159 self.regions.contains(&instance.region)
160 && instance.specs.to_lowercase().replace('.', "")
161 == self.specs.to_lowercase().replace('.', "")
162 }
163
164 #[cfg(test)]
166 pub fn number_of_regions(&self) -> usize {
167 self.regions.len()
168 }
169
170 #[cfg(test)]
172 pub fn new_for_test() -> Self {
173 let mut path = tempfile::tempdir().unwrap().keep();
175 path.push("test_public_key.pub");
176 let public_key = "This is a fake public key for tests";
177 fs::write(&path, public_key).unwrap();
178
179 Self {
181 testbed_id: "testbed".into(),
182 cloud_provider: CloudProvider::Aws,
183 token_file: "/path/to/token/file".into(),
184 ssh_private_key_file: "/path/to/private/key/file".into(),
185 ssh_public_key_file: Some(path),
186 regions: vec!["London".into(), "New York".into()],
187 specs: "small".into(),
188 repository: Repository {
189 url: Url::parse("https://example.net/author/repo").unwrap(),
190 commit: "main".into(),
191 },
192 working_dir: "/path/to/working_dir".into(),
193 results_dir: "results".into(),
194 logs_dir: "logs".into(),
195 }
196 }
197}
198
/// Replaces every `${NAME}` placeholder in `s` with the value of the environment
/// variable `NAME`.
///
/// The input is scanned once, and each placeholder is looked up individually with
/// `env::var`. Consequences:
/// - substituted values are inserted verbatim (never expanded again);
/// - unrelated environment variables that are not valid Unicode cannot cause a
///   panic, which iterating `env::vars()` would.
///
/// # Panics
///
/// Panics if any `${` cannot be resolved. That covers: the variable is unset or not
/// valid Unicode, the name is empty or contains `=`/NUL, or the closing `}` is
/// missing. Only the unresolved names are printed, not the resolved file, which may
/// contain secrets taken from the environment.
fn resolve_env(s: &str) -> String {
    let mut resolved = String::with_capacity(s.len());
    let mut unresolved: Vec<String> = Vec::new();
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        resolved.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];

        let Some(end) = after_open.find('}') else {
            // Unterminated placeholder: keep the tail verbatim and report it.
            unresolved.push(after_open.to_string());
            resolved.push_str(&rest[start..]);
            rest = "";
            break;
        };

        let name = &after_open[..end];
        let placeholder = &rest[start..start + 2 + end + 1];
        // Names that are empty or contain `=`/NUL can never be set; skip the lookup
        // so `env::var` is only called with names it accepts.
        let value = if name.is_empty() || name.contains(['=', '\0']) {
            None
        } else {
            env::var(name).ok()
        };
        match value {
            Some(value) => resolved.push_str(&value),
            None => {
                unresolved.push(name.to_string());
                resolved.push_str(placeholder);
            }
        }
        rest = &after_open[end + 1..];
    }
    resolved.push_str(rest);

    if !unresolved.is_empty() {
        eprintln!("Unresolved env variables in settings.json: {unresolved:?}");
        panic!("Unresolved env variables in the settings.json");
    }
    resolved
}
210}
211
#[cfg(test)]
mod test {
    use reqwest::Url;

    use crate::settings::{Settings, resolve_env};

    /// The repository name is the second path segment of the URL.
    #[test]
    fn repository_name() {
        let mut settings = Settings::new_for_test();
        settings.repository.url = Url::parse("https://example.com/author/name").unwrap();
        assert_eq!(settings.repository_name(), "name");
    }

    /// Anything from the first `.` (e.g. a `.git` suffix) is dropped.
    #[test]
    fn repository_name_strips_git_suffix() {
        let mut settings = Settings::new_for_test();
        settings.repository.url = Url::parse("https://example.com/author/name.git").unwrap();
        assert_eq!(settings.repository_name(), "name");
    }

    /// Text without `${...}` placeholders is returned unchanged.
    #[test]
    fn resolve_env_without_placeholders_is_identity() {
        let input = r#"{"testbed_id": "testbed"}"#;
        assert_eq!(resolve_env(input), input);
    }

    /// A placeholder naming an unset variable aborts loading.
    #[test]
    #[should_panic(expected = "Unresolved env variables in the settings.json")]
    fn resolve_env_panics_on_unresolved_placeholder() {
        resolve_env("${SUI_AWS_ORCHESTRATOR_SURELY_UNSET_VARIABLE}");
    }
}