sui_aws_orchestrator/
settings.rsuse std::{
env,
fmt::Display,
fs::{self},
path::{Path, PathBuf},
};
use reqwest::Url;
use serde::{de::Error, Deserialize, Deserializer};
use crate::{
client::Instance,
error::{SettingsError, SettingsResult},
};
/// The `repository` section of the settings file.
#[derive(Clone, Deserialize)]
pub struct Repository {
    /// Repository URL.
    ///
    /// Deserialized through `parse_url`, which rejects URLs whose path has fewer than two
    /// segments (an owner and a repository name, e.g. `https://host/<owner>/<name>`).
    #[serde(deserialize_with = "parse_url")]
    pub url: Url,
    /// Revision to use (the test fixture uses `"main"`); presumably any git ref or commit
    /// hash — confirm against callers.
    pub commit: String,
}
fn parse_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
where
D: Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let url = Url::parse(s).map_err(D::Error::custom)?;
match url.path_segments().map(|x| x.count() >= 2) {
None | Some(false) => Err(D::Error::custom(SettingsError::MalformedRepositoryUrl(url))),
_ => Ok(url),
}
}
/// Cloud provider hosting the testbed; `Aws` is currently the only variant.
///
/// Deserializes from the variant name `"Aws"` or from the alias `"aws"`.
#[derive(Clone, Deserialize)]
pub enum CloudProvider {
    #[serde(alias = "aws")]
    Aws,
}
/// Orchestrator settings, deserialized from a JSON file by [`Settings::load`].
#[derive(Clone, Deserialize)]
pub struct Settings {
    /// Identifier of this testbed (not read elsewhere in this file; presumably used to tag
    /// cloud resources — confirm against callers).
    pub testbed_id: String,
    /// Cloud provider hosting the instances.
    pub cloud_provider: CloudProvider,
    /// File whose contents are returned by [`Settings::load_token`].
    pub token_file: PathBuf,
    /// SSH private key file.
    pub ssh_private_key_file: PathBuf,
    /// SSH public key file. When omitted, [`Settings::load_ssh_public_key`] reads
    /// `ssh_private_key_file` with its extension replaced by `pub`.
    pub ssh_public_key_file: Option<PathBuf>,
    /// Regions whose instances are accepted by [`Settings::filter_instances`].
    pub regions: Vec<String>,
    /// Instance specs accepted by [`Settings::filter_instances`] (compared ignoring case
    /// and `.` characters).
    pub specs: String,
    /// Source repository; see [`Repository`].
    pub repository: Repository,
    /// Working directory, `~/working_dir` when omitted. Not created by `load` (the `~`
    /// suggests a path on the remote hosts — confirm against callers).
    #[serde(default = "default_working_dir")]
    pub working_dir: PathBuf,
    /// Local results directory, `./results` when omitted; created by `load`.
    #[serde(default = "default_results_dir")]
    pub results_dir: PathBuf,
    /// Local logs directory, `./logs` when omitted; created by `load`.
    #[serde(default = "default_logs_dir")]
    pub logs_dir: PathBuf,
}
/// Default for `Settings::working_dir`: the literal path `~/working_dir`.
///
/// The `~` is not expanded here; it stays part of the path.
fn default_working_dir() -> PathBuf {
    PathBuf::from("~/").join("working_dir")
}
/// Default for `Settings::results_dir`: `./results`, relative to the current directory.
fn default_results_dir() -> PathBuf {
    let mut dir = PathBuf::from("./");
    dir.push("results");
    dir
}
/// Default for `Settings::logs_dir`: `./logs`, relative to the current directory.
fn default_logs_dir() -> PathBuf {
    Path::new("./").join("logs")
}
impl Settings {
    /// Loads the settings from the JSON file at `path`.
    ///
    /// Every `${NAME}` placeholder in the file is replaced with the value of the
    /// environment variable `NAME` (see `resolve_env`) before the JSON is parsed. The
    /// configured `results_dir` and `logs_dir` are then created if they do not exist yet.
    ///
    /// # Errors
    ///
    /// Returns [`SettingsError::InvalidSettings`] in any of these cases:
    /// - the file cannot be read or is not valid UTF-8;
    /// - its contents do not deserialize into [`Settings`];
    /// - one of the output directories cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if a placeholder cannot be resolved from the environment (raised by
    /// `resolve_env`).
    pub fn load<P>(path: P) -> SettingsResult<Self>
    where
        P: AsRef<Path> + Display + Clone,
    {
        let reader = || -> Result<Self, std::io::Error> {
            // `read_to_string` reports non-UTF-8 content as an `InvalidData` I/O error, so a
            // malformed file surfaces as `InvalidSettings` instead of a panic.
            let data = fs::read_to_string(path.as_ref())?;
            let data = resolve_env(&data);
            // `serde_json::Error` converts into `std::io::Error`, which lets `?` work here.
            let settings: Settings = serde_json::from_str(&data)?;
            fs::create_dir_all(&settings.results_dir)?;
            fs::create_dir_all(&settings.logs_dir)?;
            Ok(settings)
        };
        reader().map_err(|e| SettingsError::InvalidSettings {
            file: path.to_string(),
            message: e.to_string(),
        })
    }

    /// Returns the repository name: the second path segment of the repository URL, cut at
    /// its first `.` (so `https://host/author/repo.git` yields `repo`).
    ///
    /// # Panics
    ///
    /// Panics if the URL has fewer than two path segments. `parse_url` rejects such URLs
    /// when the settings are deserialized.
    pub fn repository_name(&self) -> String {
        self.repository
            .url
            .path_segments()
            .and_then(|mut segments| segments.nth(1))
            .expect("Url should already be checked when loading settings")
            .split('.')
            .next()
            // `split` always yields at least one item, so this default is never used.
            .unwrap_or_default()
            .to_string()
    }

    /// Reads `file` to a string and strips trailing line terminators (`\n` and `\r`).
    ///
    /// Stripping `\r` too handles files saved with CRLF line endings.
    fn read_trimmed(file: &Path) -> std::io::Result<String> {
        fs::read_to_string(file)
            .map(|contents| contents.trim_end_matches(|c| c == '\n' || c == '\r').to_string())
    }

    /// Returns the contents of `token_file` without trailing line terminators.
    ///
    /// # Errors
    ///
    /// Returns [`SettingsError::InvalidTokenFile`] if the file cannot be read.
    pub fn load_token(&self) -> SettingsResult<String> {
        Self::read_trimmed(&self.token_file).map_err(|e| SettingsError::InvalidTokenFile {
            file: self.token_file.display().to_string(),
            message: e.to_string(),
        })
    }

    /// Returns the SSH public key without trailing line terminators.
    ///
    /// The key is read from `ssh_public_key_file` when set. Otherwise it is read from
    /// `ssh_private_key_file` with its extension replaced by `pub`, e.g. `id_ed25519` ->
    /// `id_ed25519.pub` or `key.pem` -> `key.pub`.
    ///
    /// # Errors
    ///
    /// Returns [`SettingsError::InvalidSshPublicKeyFile`] if the file cannot be read.
    pub fn load_ssh_public_key(&self) -> SettingsResult<String> {
        let ssh_public_key_file = self
            .ssh_public_key_file
            .clone()
            .unwrap_or_else(|| self.ssh_private_key_file.with_extension("pub"));
        Self::read_trimmed(&ssh_public_key_file).map_err(|e| {
            SettingsError::InvalidSshPublicKeyFile {
                file: ssh_public_key_file.display().to_string(),
                message: e.to_string(),
            }
        })
    }

    /// Returns `true` if `instance` meets both of these conditions:
    /// - it lies in one of the configured `regions`;
    /// - its specs equal `specs`, compared case-insensitively and ignoring `.` characters.
    pub fn filter_instances(&self, instance: &Instance) -> bool {
        self.regions.contains(&instance.region)
            && instance.specs.to_lowercase().replace('.', "")
                == self.specs.to_lowercase().replace('.', "")
    }

    /// Number of configured regions (test helper).
    #[cfg(test)]
    pub fn number_of_regions(&self) -> usize {
        self.regions.len()
    }

    /// Builds settings with placeholder values for unit tests.
    ///
    /// Writes a fake public key into a new temporary directory that is kept after the call,
    /// and points `ssh_public_key_file` at it.
    #[cfg(test)]
    pub fn new_for_test() -> Self {
        let mut path = tempfile::tempdir().unwrap().keep();
        path.push("test_public_key.pub");
        let public_key = "This is a fake public key for tests";
        fs::write(&path, public_key).unwrap();
        Self {
            testbed_id: "testbed".into(),
            cloud_provider: CloudProvider::Aws,
            token_file: "/path/to/token/file".into(),
            ssh_private_key_file: "/path/to/private/key/file".into(),
            ssh_public_key_file: Some(path),
            regions: vec!["London".into(), "New York".into()],
            specs: "small".into(),
            repository: Repository {
                url: Url::parse("https://example.net/author/repo").unwrap(),
                commit: "main".into(),
            },
            working_dir: "/path/to/working_dir".into(),
            results_dir: "results".into(),
            logs_dir: "logs".into(),
        }
    }
}
/// Replaces every `${NAME}` placeholder in `s` with the value of the environment variable
/// `NAME`.
///
/// Placeholders are resolved in a single left-to-right pass. Text inserted from a variable
/// is never re-scanned, so the result does not depend on environment iteration order, and
/// a value that happens to contain `${` is kept verbatim. Only the referenced variables
/// are read. Iterating `env::vars()` instead would panic whenever any variable in the
/// process environment is not valid Unicode.
///
/// # Panics
///
/// Panics in either of these cases:
/// - a placeholder names a variable that is unset, not valid Unicode, or not a valid
///   variable name;
/// - a `${` is never closed by `}`.
///
/// The message lists the offending placeholders rather than dumping the whole file, so
/// secrets that were already substituted are not echoed to the logs.
fn resolve_env(s: &str) -> String {
    let mut resolved = String::with_capacity(s.len());
    let mut unresolved: Vec<&str> = Vec::new();
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        resolved.push_str(&rest[..start]);
        // `placeholder` starts with `${`, so a closing `}` can only be found at index >= 2.
        let placeholder = &rest[start..];
        match placeholder.find('}') {
            Some(close) => {
                let name = &placeholder[2..close];
                // `env::var` may reject empty names or names containing `=`/NUL; treat those
                // as unresolved without asking the OS.
                let value = if name.is_empty() || name.contains(|c| c == '=' || c == '\0') {
                    None
                } else {
                    env::var(name).ok()
                };
                match value {
                    Some(value) => resolved.push_str(&value),
                    None => unresolved.push(&placeholder[..=close]),
                }
                rest = &placeholder[close + 1..];
            }
            None => {
                // Unterminated `${`: report the rest of its line (raw file text, no
                // substituted values) and stop scanning.
                unresolved.push(placeholder.lines().next().unwrap_or(placeholder));
                rest = "";
            }
        }
    }
    resolved.push_str(rest);

    if !unresolved.is_empty() {
        panic!(
            "Unresolved env variables in the settings.json: {}",
            unresolved.join(", ")
        );
    }
    resolved
}
#[cfg(test)]
mod test {
    use reqwest::Url;

    use crate::settings::{Repository, Settings};

    /// Returns test settings whose repository URL is `url`.
    fn settings_with_repository_url(url: &str) -> Settings {
        let mut settings = Settings::new_for_test();
        settings.repository.url = Url::parse(url).unwrap();
        settings
    }

    #[test]
    fn repository_name() {
        let settings = settings_with_repository_url("https://example.com/author/name");
        assert_eq!(settings.repository_name(), "name");
    }

    #[test]
    fn repository_name_drops_git_suffix() {
        // Everything from the first `.` of the name segment is dropped.
        let settings = settings_with_repository_url("https://example.com/author/name.git");
        assert_eq!(settings.repository_name(), "name");
    }

    #[test]
    fn repository_url_needs_two_path_segments() {
        let json = r#"{"url": "https://example.com/name", "commit": "main"}"#;
        assert!(serde_json::from_str::<Repository>(json).is_err());
    }
}