1use async_trait::async_trait;
5use std::io::Write;
6use std::sync::Arc;
7use std::{
8 net::SocketAddr,
9 path::{Path, PathBuf},
10 time::Duration,
11};
12
13use futures::future::try_join_all;
14use russh::client::Msg;
15use russh::{client, Channel};
16use russh_keys::key;
17use tokio::task::JoinHandle;
18use tokio::time::sleep;
19
20use crate::{
21 client::Instance,
22 ensure,
23 error::{SshError, SshResult},
24};
25
/// The status of a command identified by an id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// The command id was found in the inspected text.
    Running,
    /// The command id was not found in the inspected text.
    Terminated,
}

impl CommandStatus {
    /// Derives the status of the command `command_id` from `text`.
    ///
    /// The command is reported as [`CommandStatus::Running`] when `text` contains
    /// `command_id` anywhere (plain substring search), and as
    /// [`CommandStatus::Terminated`] otherwise. Because this is a substring search, an id
    /// that is a fragment of another id present in `text` is also reported as running.
    pub fn status(command_id: &str, text: &str) -> Self {
        if text.contains(command_id) {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
44
/// Options describing how a shell command should be wrapped before it is run.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// When set, the command is started detached in a tmux session with this name.
    pub background: Option<String>,
    /// When set, the command is run after changing into this directory.
    pub path: Option<PathBuf>,
    /// When set, the command output is piped through `tee` into this file.
    pub log_file: Option<PathBuf>,
}

impl CommandContext {
    /// Creates a context with no option set; [`CommandContext::apply`] then returns the
    /// command unchanged.
    pub fn new() -> Self {
        Self::default()
    }

    /// Runs the command in a detached tmux session named `id`.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Runs the command from the directory `path`.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Copies the command output (stdout and stderr, via `|& tee`) into the file `path`.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Wraps `base_command` according to the options of this context and returns the
    /// resulting shell command.
    ///
    /// The wrappers are applied from the inside out: logging first, then the tmux
    /// session, then the working directory. No escaping is performed on the command, the
    /// session name or the paths.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let command = base_command.into();

        // Innermost layer: duplicate the output into the log file.
        let command = match &self.log_file {
            Some(log_file) => format!("{command} |& tee {}", log_file.display()),
            None => command,
        };

        // Middle layer: start everything built so far inside a detached tmux session.
        let command = match &self.background {
            Some(id) => format!("tmux new -d -s \"{id}\" \"{command}\""),
            None => command,
        };

        // Outermost layer: change directory before running anything.
        match &self.path {
            Some(directory) => format!("(cd {} && {command})", directory.display()),
            None => command,
        }
    }
}
100
/// Holds the settings used to open SSH connections and run commands on remote instances.
#[derive(Clone, Debug)]
pub struct SshConnectionManager {
    /// The user name used to authenticate on the remote hosts.
    username: String,
    /// The file holding the private key used to authenticate.
    private_key_file: PathBuf,
    /// The timeout handed to every new connection (`None` lets the connection pick its
    /// own default).
    timeout: Option<Duration>,
    /// The number of additional attempts made after a first failure.
    retries: usize,
}
112
113impl SshConnectionManager {
114 const RETRY_DELAY: Duration = Duration::from_secs(5);
116
117 pub fn new(username: String, private_key_file: PathBuf) -> Self {
119 Self {
120 username,
121 private_key_file,
122 timeout: None,
123 retries: 0,
124 }
125 }
126
127 pub fn with_timeout(mut self, timeout: Duration) -> Self {
129 self.timeout = Some(timeout);
130 self
131 }
132
133 pub fn with_retries(mut self, retries: usize) -> Self {
135 self.retries = retries;
136 self
137 }
138
139 pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
141 let mut error = None;
142 for _ in 0..self.retries + 1 {
143 match SshConnection::new(
144 address,
145 &self.username,
146 self.private_key_file.clone(),
147 self.timeout,
148 Some(self.retries),
149 )
150 .await
151 {
152 Ok(x) => return Ok(x),
153 Err(e) => error = Some(e),
154 }
155 sleep(Self::RETRY_DELAY).await;
156 }
157 Err(error.unwrap())
158 }
159
160 pub async fn execute<I, S>(
162 &self,
163 instances: I,
164 command: S,
165 context: CommandContext,
166 ) -> SshResult<Vec<(String, String)>>
167 where
168 I: IntoIterator<Item = Instance>,
169 S: Into<String> + Clone + Send + 'static,
170 {
171 let targets = instances
172 .into_iter()
173 .map(|instance| (instance, command.clone()));
174 self.execute_per_instance(targets, context).await
175 }
176
177 pub async fn execute_per_instance<I, S>(
179 &self,
180 instances: I,
181 context: CommandContext,
182 ) -> SshResult<Vec<(String, String)>>
183 where
184 I: IntoIterator<Item = (Instance, S)>,
185 S: Into<String> + Send + 'static,
186 {
187 let handles = self.run_per_instance(instances, context).await;
188
189 try_join_all(handles)
190 .await
191 .unwrap()
192 .into_iter()
193 .collect::<SshResult<_>>()
194 }
195
196 async fn run_per_instance<I, S>(
197 &self,
198 instances: I,
199 context: CommandContext,
200 ) -> Vec<JoinHandle<SshResult<(String, String)>>>
201 where
202 I: IntoIterator<Item = (Instance, S)>,
203 S: Into<String> + Send + 'static,
204 {
205 instances
206 .into_iter()
207 .map(|(instance, command)| {
208 let ssh_manager = self.clone();
209 let context = context.clone();
210
211 tokio::spawn(async move {
212 let connection = ssh_manager.connect(instance.ssh_address()).await?;
213 connection.execute(context.apply(command)).await
215 })
216 })
217 .collect::<Vec<_>>()
218 }
219
220 pub async fn wait_for_command<I>(
222 &self,
223 instances: I,
224 command_id: &str,
225 status: CommandStatus,
226 ) -> SshResult<()>
227 where
228 I: IntoIterator<Item = Instance> + Clone,
229 {
230 loop {
231 sleep(Self::RETRY_DELAY).await;
232
233 let result = self
234 .execute(
235 instances.clone(),
236 "(tmux ls || true)",
237 CommandContext::default(),
238 )
239 .await?;
240 if result
241 .iter()
242 .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
243 {
244 break;
245 }
246 }
247 Ok(())
248 }
249
250 pub async fn wait_for_success<I, S>(&self, instances: I)
251 where
252 I: IntoIterator<Item = (Instance, S)> + Clone,
253 S: Into<String> + Send + 'static + Clone,
254 {
255 loop {
256 sleep(Self::RETRY_DELAY).await;
257
258 if self
259 .execute_per_instance(instances.clone(), CommandContext::default())
260 .await
261 .is_ok()
262 {
263 break;
264 }
265 }
266 }
267
268 pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
270 where
271 I: IntoIterator<Item = Instance>,
272 {
273 let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
274 let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
275 self.execute_per_instance(targets, CommandContext::default())
276 .await?;
277 Ok(())
278 }
279}
280
281struct Session {}
282
283#[async_trait]
284impl client::Handler for Session {
285 type Error = russh::Error;
286
287 async fn check_server_key(
288 self,
289 _server_public_key: &key::PublicKey,
290 ) -> Result<(Self, bool), Self::Error> {
291 Ok((self, true))
292 }
293}
294
295pub struct SshConnection {
297 session: client::Handle<Session>,
299 address: SocketAddr,
301 retries: usize,
303}
304
305impl SshConnection {
306 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
308
309 pub async fn new<P: AsRef<Path>>(
311 address: SocketAddr,
312 username: &str,
313 private_key_file: P,
314 inactivity_timeout: Option<Duration>,
315 retries: Option<usize>,
316 ) -> SshResult<Self> {
317 let key = russh_keys::load_secret_key(private_key_file, None)
318 .map_err(|error| SshError::PrivateKeyError { address, error })?;
319
320 let config = client::Config {
321 inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
322 ..<_>::default()
323 };
324
325 let mut session = client::connect(Arc::new(config), address, Session {})
326 .await
327 .map_err(|error| SshError::ConnectionError { address, error })?;
328
329 let _auth_res = session
330 .authenticate_publickey(username, Arc::new(key))
331 .await
332 .map_err(|error| SshError::SessionError { address, error })?;
333
334 Ok(Self {
335 session,
336 address,
337 retries: retries.unwrap_or_default(),
338 })
339 }
340
341 fn make_session_error(&self, error: russh::Error) -> SshError {
343 SshError::SessionError {
344 address: self.address,
345 error,
346 }
347 }
348
349 pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
351 let mut error = None;
352 for _ in 0..self.retries + 1 {
353 let channel = match self.session.channel_open_session().await {
354 Ok(x) => x,
355 Err(e) => {
356 error = Some(self.make_session_error(e));
357 continue;
358 }
359 };
360 match self.execute_impl(channel, command.clone()).await {
361 r @ Ok(..) => return r,
362 Err(e) => error = Some(e),
363 }
364 }
365 Err(error.unwrap())
366 }
367
368 async fn execute_impl(
370 &self,
371 mut channel: Channel<Msg>,
372 command: String,
373 ) -> SshResult<(String, String)> {
374 channel
375 .exec(true, command)
376 .await
377 .map_err(|e| self.make_session_error(e))?;
378
379 let mut output = Vec::new();
380 let mut exit_code = None;
381
382 while let Some(msg) = channel.wait().await {
383 match msg {
384 russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
385 russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
386 _ => {}
387 }
388 }
389
390 channel
391 .close()
392 .await
393 .map_err(|error| self.make_session_error(error))?;
394
395 let output_str: String = String::from_utf8_lossy(&output).into();
396
397 ensure!(
398 exit_code.is_some() && exit_code.unwrap() == 0,
399 SshError::NonZeroExitCode {
400 address: self.address,
401 code: exit_code.unwrap(),
402 message: output_str
403 }
404 );
405
406 Ok((output_str.clone(), output_str))
407 }
408
409 pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
412 let mut error = None;
413 for _ in 0..self.retries + 1 {
414 match self
415 .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
416 .await
417 {
418 Ok((file_data, _)) => return Ok(file_data),
419 Err(err) => error = Some(err),
420 }
421 }
422 Err(error.unwrap())
423 }
424}