sui_aws_orchestrator/
ssh.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use async_trait::async_trait;
5use std::io::Write;
6use std::sync::Arc;
7use std::{
8    net::SocketAddr,
9    path::{Path, PathBuf},
10    time::Duration,
11};
12
13use futures::future::try_join_all;
14use russh::client::Msg;
15use russh::{client, Channel};
16use russh_keys::key;
17use tokio::task::JoinHandle;
18use tokio::time::sleep;
19
20use crate::{
21    client::Instance,
22    ensure,
23    error::{SshError, SshResult},
24};
25
/// The status of an ssh command running in the background.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandStatus {
    /// A session named after the command id is present in the listing.
    Running,
    /// No session named after the command id is present in the listing.
    Terminated,
}

impl CommandStatus {
    /// Return whether a background command is still running. Returns `Terminated` if the
    /// command is not running in the background.
    ///
    /// `text` is expected to be a session listing with one session per line, each line
    /// starting with `<session name>:` (the default output of `tmux ls`). The command is
    /// considered running only if a line names exactly `command_id`; a plain substring
    /// search would also match longer names (`node` vs `node-10`) or unrelated text such
    /// as error messages.
    pub fn status(command_id: &str, text: &str) -> Self {
        let is_listed = text.lines().any(|line| {
            line.strip_prefix(command_id)
                // The session name ends at the first ':' of the line.
                .map_or(false, |rest| rest.starts_with(':'))
        });

        if is_listed {
            Self::Running
        } else {
            Self::Terminated
        }
    }
}
44
/// Execution options wrapped around a base command before it is sent to the remote
/// machines.
#[derive(Clone, Default)]
pub struct CommandContext {
    /// When set, the command runs detached in a tmux session named after this id, and
    /// the call returns immediately.
    pub background: Option<String>,
    /// The directory to `cd` into before running the command.
    pub path: Option<PathBuf>,
    /// The file receiving a copy of the command's stdout and stderr.
    pub log_file: Option<PathBuf>,
}

impl CommandContext {
    /// Create a context with no option set; applying it leaves a command untouched.
    pub fn new() -> Self {
        Self::default()
    }

    /// Run the command in the background, identified by `id`.
    pub fn run_background(self, id: String) -> Self {
        Self {
            background: Some(id),
            ..self
        }
    }

    /// Run the command from the directory `path`.
    pub fn with_execute_from_path(self, path: PathBuf) -> Self {
        Self {
            path: Some(path),
            ..self
        }
    }

    /// Copy the command's stdout and stderr to the file `path`.
    pub fn with_log_file(self, path: PathBuf) -> Self {
        Self {
            log_file: Some(path),
            ..self
        }
    }

    /// Wrap `base_command` with every option that is set and return the resulting shell
    /// command line.
    ///
    /// The wrappers nest from the inside out: log redirection, then the tmux session,
    /// then the change of directory. Note that the command is embedded between double
    /// quotes for tmux without any escaping, and paths are inserted unquoted.
    pub fn apply<S: Into<String>>(&self, base_command: S) -> String {
        let command = base_command.into();

        let logged = match &self.log_file {
            Some(log_file) => format!("{command} |& tee {}", log_file.display()),
            None => command,
        };

        let detached = match &self.background {
            Some(id) => format!("tmux new -d -s \"{id}\" \"{logged}\""),
            None => logged,
        };

        match &self.path {
            Some(exec_path) => format!("(cd {} && {detached})", exec_path.display()),
            None => detached,
        }
    }
}
100
101#[derive(Clone)]
102pub struct SshConnectionManager {
103    /// The ssh username.
104    username: String,
105    /// The ssh primate key to connect to the instances.
106    private_key_file: PathBuf,
107    /// The timeout value of the connection.
108    timeout: Option<Duration>,
109    /// The number of retries before giving up to execute the command.
110    retries: usize,
111}
112
113impl SshConnectionManager {
114    /// Delay before re-attempting an ssh execution.
115    const RETRY_DELAY: Duration = Duration::from_secs(5);
116
117    /// Create a new ssh manager from the instances username and private keys.
118    pub fn new(username: String, private_key_file: PathBuf) -> Self {
119        Self {
120            username,
121            private_key_file,
122            timeout: None,
123            retries: 0,
124        }
125    }
126
127    /// Set a timeout duration for the connections.
128    pub fn with_timeout(mut self, timeout: Duration) -> Self {
129        self.timeout = Some(timeout);
130        self
131    }
132
133    /// Set the maximum number of times to retries to establish a connection and execute commands.
134    pub fn with_retries(mut self, retries: usize) -> Self {
135        self.retries = retries;
136        self
137    }
138
139    /// Create a new ssh connection with the provided host.
140    pub async fn connect(&self, address: SocketAddr) -> SshResult<SshConnection> {
141        let mut error = None;
142        for _ in 0..self.retries + 1 {
143            match SshConnection::new(
144                address,
145                &self.username,
146                self.private_key_file.clone(),
147                self.timeout,
148                Some(self.retries),
149            )
150            .await
151            {
152                Ok(x) => return Ok(x),
153                Err(e) => error = Some(e),
154            }
155            sleep(Self::RETRY_DELAY).await;
156        }
157        Err(error.unwrap())
158    }
159
160    /// Execute the specified ssh command on all provided instances.
161    pub async fn execute<I, S>(
162        &self,
163        instances: I,
164        command: S,
165        context: CommandContext,
166    ) -> SshResult<Vec<(String, String)>>
167    where
168        I: IntoIterator<Item = Instance>,
169        S: Into<String> + Clone + Send + 'static,
170    {
171        let targets = instances
172            .into_iter()
173            .map(|instance| (instance, command.clone()));
174        self.execute_per_instance(targets, context).await
175    }
176
177    /// Execute the ssh command associated with each instance.
178    pub async fn execute_per_instance<I, S>(
179        &self,
180        instances: I,
181        context: CommandContext,
182    ) -> SshResult<Vec<(String, String)>>
183    where
184        I: IntoIterator<Item = (Instance, S)>,
185        S: Into<String> + Send + 'static,
186    {
187        let handles = self.run_per_instance(instances, context).await;
188
189        try_join_all(handles)
190            .await
191            .unwrap()
192            .into_iter()
193            .collect::<SshResult<_>>()
194    }
195
196    async fn run_per_instance<I, S>(
197        &self,
198        instances: I,
199        context: CommandContext,
200    ) -> Vec<JoinHandle<SshResult<(String, String)>>>
201    where
202        I: IntoIterator<Item = (Instance, S)>,
203        S: Into<String> + Send + 'static,
204    {
205        instances
206            .into_iter()
207            .map(|(instance, command)| {
208                let ssh_manager = self.clone();
209                let context = context.clone();
210
211                tokio::spawn(async move {
212                    let connection = ssh_manager.connect(instance.ssh_address()).await?;
213                    // SshConnection::execute is a blocking call, needs to go to blocking pool
214                    connection.execute(context.apply(command)).await
215                })
216            })
217            .collect::<Vec<_>>()
218    }
219
220    /// Wait until a command running in the background returns or started.
221    pub async fn wait_for_command<I>(
222        &self,
223        instances: I,
224        command_id: &str,
225        status: CommandStatus,
226    ) -> SshResult<()>
227    where
228        I: IntoIterator<Item = Instance> + Clone,
229    {
230        loop {
231            sleep(Self::RETRY_DELAY).await;
232
233            let result = self
234                .execute(
235                    instances.clone(),
236                    "(tmux ls || true)",
237                    CommandContext::default(),
238                )
239                .await?;
240            if result
241                .iter()
242                .all(|(stdout, _)| CommandStatus::status(command_id, stdout) == status)
243            {
244                break;
245            }
246        }
247        Ok(())
248    }
249
250    pub async fn wait_for_success<I, S>(&self, instances: I)
251    where
252        I: IntoIterator<Item = (Instance, S)> + Clone,
253        S: Into<String> + Send + 'static + Clone,
254    {
255        loop {
256            sleep(Self::RETRY_DELAY).await;
257
258            if self
259                .execute_per_instance(instances.clone(), CommandContext::default())
260                .await
261                .is_ok()
262            {
263                break;
264            }
265        }
266    }
267
268    /// Kill a command running in the background of the specified instances.
269    pub async fn kill<I>(&self, instances: I, command_id: &str) -> SshResult<()>
270    where
271        I: IntoIterator<Item = Instance>,
272    {
273        let ssh_command = format!("(tmux kill-session -t {command_id} || true)");
274        let targets = instances.into_iter().map(|x| (x, ssh_command.clone()));
275        self.execute_per_instance(targets, CommandContext::default())
276            .await?;
277        Ok(())
278    }
279}
280
281struct Session {}
282
283#[async_trait]
284impl client::Handler for Session {
285    type Error = russh::Error;
286
287    async fn check_server_key(
288        self,
289        _server_public_key: &key::PublicKey,
290    ) -> Result<(Self, bool), Self::Error> {
291        Ok((self, true))
292    }
293}
294
295/// Representation of an ssh connection.
296pub struct SshConnection {
297    /// The ssh session.
298    session: client::Handle<Session>,
299    /// The host address.
300    address: SocketAddr,
301    /// The number of retries before giving up to execute the command.
302    retries: usize,
303}
304
305impl SshConnection {
306    /// Default duration before timing out the ssh connection.
307    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
308
309    /// Create a new ssh connection with a specific host.
310    pub async fn new<P: AsRef<Path>>(
311        address: SocketAddr,
312        username: &str,
313        private_key_file: P,
314        inactivity_timeout: Option<Duration>,
315        retries: Option<usize>,
316    ) -> SshResult<Self> {
317        let key = russh_keys::load_secret_key(private_key_file, None)
318            .map_err(|error| SshError::PrivateKeyError { address, error })?;
319
320        let config = client::Config {
321            inactivity_timeout: inactivity_timeout.or(Some(Self::DEFAULT_TIMEOUT)),
322            ..<_>::default()
323        };
324
325        let mut session = client::connect(Arc::new(config), address, Session {})
326            .await
327            .map_err(|error| SshError::ConnectionError { address, error })?;
328
329        let _auth_res = session
330            .authenticate_publickey(username, Arc::new(key))
331            .await
332            .map_err(|error| SshError::SessionError { address, error })?;
333
334        Ok(Self {
335            session,
336            address,
337            retries: retries.unwrap_or_default(),
338        })
339    }
340
341    /// Make a useful session error from the lower level error message.
342    fn make_session_error(&self, error: russh::Error) -> SshError {
343        SshError::SessionError {
344            address: self.address,
345            error,
346        }
347    }
348
349    /// Execute an ssh command on the remote machine.
350    pub async fn execute(&self, command: String) -> SshResult<(String, String)> {
351        let mut error = None;
352        for _ in 0..self.retries + 1 {
353            let channel = match self.session.channel_open_session().await {
354                Ok(x) => x,
355                Err(e) => {
356                    error = Some(self.make_session_error(e));
357                    continue;
358                }
359            };
360            match self.execute_impl(channel, command.clone()).await {
361                r @ Ok(..) => return r,
362                Err(e) => error = Some(e),
363            }
364        }
365        Err(error.unwrap())
366    }
367
368    /// Execute an ssh command on the remote machine and return both stdout and stderr.
369    async fn execute_impl(
370        &self,
371        mut channel: Channel<Msg>,
372        command: String,
373    ) -> SshResult<(String, String)> {
374        channel
375            .exec(true, command)
376            .await
377            .map_err(|e| self.make_session_error(e))?;
378
379        let mut output = Vec::new();
380        let mut exit_code = None;
381
382        while let Some(msg) = channel.wait().await {
383            match msg {
384                russh::ChannelMsg::Data { ref data } => output.write_all(data).unwrap(),
385                russh::ChannelMsg::ExitStatus { exit_status } => exit_code = Some(exit_status),
386                _ => {}
387            }
388        }
389
390        channel
391            .close()
392            .await
393            .map_err(|error| self.make_session_error(error))?;
394
395        let output_str: String = String::from_utf8_lossy(&output).into();
396
397        ensure!(
398            exit_code.is_some() && exit_code.unwrap() == 0,
399            SshError::NonZeroExitCode {
400                address: self.address,
401                code: exit_code.unwrap(),
402                message: output_str
403            }
404        );
405
406        Ok((output_str.clone(), output_str))
407    }
408
409    /// Download a file from the remote machines by doing a silly cat on the file.
410    /// TODO: if the files get too big then we should leverage a simple S3 bucket instead.
411    pub async fn download<P: AsRef<Path>>(&self, path: P) -> SshResult<String> {
412        let mut error = None;
413        for _ in 0..self.retries + 1 {
414            match self
415                .execute(format!("cat {}", path.as_ref().to_str().unwrap()))
416                .await
417            {
418                Ok((file_data, _)) => return Ok(file_data),
419                Err(err) => error = Some(err),
420            }
421        }
422        Err(error.unwrap())
423    }
424}