1use std::fmt;
5use std::panic;
6use std::time::Duration;
7
8use futures::future;
9use futures::future::BoxFuture;
10use futures::future::FutureExt;
11use tap::TapFallible;
12use tokio::signal;
13use tokio::task::JoinSet;
14use tokio::time::timeout;
15use tracing::error;
16use tracing::info;
17
/// Grace period that [`Service::main`] allows for shutdown to complete once it has started.
/// If it elapses before all tasks have finished, the service is aborted.
pub const GRACE: Duration = Duration::from_secs(30);
23
/// A set of asynchronous tasks that run together and are shut down together.
///
/// Tasks come in two kinds:
///
/// - *Primary* tasks, added with [`Service::spawn`] / [`Service::spawn_aborting`] (or brought in
///   via [`Service::merge`]). The service is considered finished once all of them have completed
///   successfully, and it stops as soon as one of them returns an error.
/// - *Secondary* tasks, brought in via [`Service::attach`] (or [`Service::merge`]). A failure
///   among them also stops the service, but their successful completion does not.
///
/// Shutdown signals registered with [`Service::with_shutdown_signal`] are awaited when the service
/// shuts down. After that, the service waits for all remaining tasks to finish.
#[must_use = "Dropping the service aborts all its tasks immediately"]
#[derive(Default)]
pub struct Service {
    // Futures awaited one after another, in registration order, at the start of shutdown.
    exits: Vec<BoxFuture<'static, ()>>,

    // Primary tasks: their failures are logged as "Primary task failure".
    fsts: JoinSet<anyhow::Result<()>>,

    // Secondary tasks: their failures are logged as "Secondary task failure".
    snds: JoinSet<anyhow::Result<()>>,
}
56
/// Reasons a [`Service`] stopped, as reported by [`Service::main`] and [`Service::shutdown`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A termination request was received, and the service then shut down cleanly within its
    /// grace period.
    #[error("Service has been terminated gracefully")]
    Terminated,

    /// Shutdown did not complete cleanly. This happens when:
    /// - a task failed while shutting down,
    /// - the grace period elapsed, or
    /// - a second termination request arrived during shutdown.
    #[error("Service has been aborted due to ungraceful shutdown")]
    Aborted,

    /// A task returned this error. It is reported either:
    /// - while the service was running (the shutdown that followed was clean), or
    /// - directly by [`Service::shutdown`].
    #[error(transparent)]
    Task(anyhow::Error),
}
68
69impl Service {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn spawn(
81 mut self,
82 task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
83 ) -> Self {
84 self.fsts.spawn(task);
85 self
86 }
87
88 pub fn spawn_aborting(
92 mut self,
93 task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
94 ) -> Self {
95 let h = self.fsts.spawn(task);
96 self.with_shutdown_signal(async move { h.abort() })
97 }
98
99 pub fn with_shutdown_signal(mut self, exit: impl Future<Output = ()> + Send + 'static) -> Self {
107 self.exits.push(exit.boxed());
108 self
109 }
110
111 pub fn merge(mut self, mut other: Service) -> Self {
113 self.exits.extend(other.exits);
114
115 if !other.fsts.is_empty() {
116 self.fsts.spawn(async move { run(&mut other.fsts).await });
117 }
118
119 if !other.snds.is_empty() {
120 self.snds.spawn(async move { run(&mut other.snds).await });
121 }
122
123 self
124 }
125
126 pub fn attach(mut self, mut other: Service) -> Self {
130 self.exits.extend(other.exits);
131
132 if !other.fsts.is_empty() {
133 self.snds.spawn(async move { run(&mut other.fsts).await });
134 }
135
136 if !other.snds.is_empty() {
137 self.snds.spawn(async move { run(&mut other.snds).await });
138 }
139
140 self
141 }
142
143 pub async fn main(self) -> Result<(), Error> {
146 self.wait_for_shutdown(GRACE, terminate).await
147 }
148
149 async fn wait_for_shutdown<T: Future<Output = ()>>(
160 mut self,
161 grace: Duration,
162 mut terminate: impl FnMut() -> T,
163 ) -> Result<(), Error> {
164 let exec = tokio::select! {
165 res = self.join() => {
166 res.map_err(Error::Task)
167 }
168
169 _ = terminate() => {
170 info!("Termination received");
171 Err(Error::Terminated)
172 }
173 };
174
175 info!("Shutting down gracefully...");
176 tokio::select! {
177 res = timeout(grace, self.shutdown()) => {
178 match res {
179 Ok(Ok(())) => {},
180 Ok(Err(_)) => return Err(Error::Aborted),
181 Err(_) => {
182 error!("Grace period elapsed, aborting...");
183 return Err(Error::Aborted);
184 }
185 }
186 }
187
188 _ = terminate() => {
189 error!("Termination received during shutdown, aborting...");
190 return Err(Error::Aborted);
191 },
192 }
193
194 exec
195 }
196
197 pub async fn join(&mut self) -> anyhow::Result<()> {
203 tokio::select! {
204 res = run(&mut self.fsts) => {
205 res.tap_err(|e| error!("Primary task failure: {e:#}"))
206 },
207
208 res = run_secondary(&mut self.snds) => {
209 res.tap_err(|e| error!("Secondary task failure: {e:#}"))
210 }
211 }
212 }
213
214 pub async fn shutdown(mut self) -> Result<(), Error> {
219 for exit in self.exits {
220 exit.await;
221 }
222 if let Err(e) = future::try_join(run(&mut self.fsts), run(&mut self.snds)).await {
223 error!("Task failure during shutdown: {e:#}");
224 return Err(Error::Task(e));
225 }
226
227 Ok(())
228 }
229}
230
// SAFETY: `exits` is a `Vec` of boxed futures that are `Send` but not `Sync`, so the compiler
// cannot derive `Sync` for `Service`. The field is private. Through a shared `&Service`, this
// file only reads `exits.len()` (in the `Debug` impl). The futures themselves are only awaited
// through owned access (`shutdown` consumes `self`). A `&dyn Future` exposes no methods that can
// touch the future's state, so sharing `&Service` across threads cannot race on those futures.
// NOTE(review): this also relies on `JoinSet<anyhow::Result<()>>` being `Sync` on its own —
// confirm against the tokio version in use.
unsafe impl Sync for Service {}
235
236impl fmt::Debug for Service {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 f.debug_struct("Service")
239 .field("exits", &self.exits.len())
240 .field("fsts", &self.fsts)
241 .field("snds", &self.snds)
242 .finish()
243 }
244}
245
246async fn run(tasks: &mut JoinSet<anyhow::Result<()>>) -> anyhow::Result<()> {
249 while let Some(res) = tasks.join_next().await {
250 match res {
251 Ok(Ok(())) => continue,
252 Ok(Err(e)) => return Err(e),
253
254 Err(e) => {
255 if e.is_panic() {
256 panic::resume_unwind(e.into_panic());
257 }
258 }
259 }
260 }
261
262 Ok(())
263}
264
265async fn run_secondary(tasks: &mut JoinSet<anyhow::Result<()>>) -> anyhow::Result<()> {
270 run(tasks).await?;
271 std::future::pending().await
272}
273
274pub async fn terminate() {
279 tokio::select! {
280 _ = signal::ctrl_c() => {},
281
282 _ = async {
283 #[cfg(unix)]
284 let _ = signal::unix::signal(signal::unix::SignalKind::terminate()).unwrap().recv().await;
285
286 #[cfg(not(unix))]
287 future::pending::<()>().await;
288 } => {}
289 }
290}
291
/// Tests for `Service`.
///
/// Most tests drive `wait_for_shutdown` directly, with a controllable termination source:
/// - `std::future::pending` means "never terminate";
/// - a `Notify` lets the test decide when termination fires.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::bail;
    use tokio::sync::Notify;
    use tokio::sync::oneshot;

    use super::*;

    /// A service with no tasks has nothing to wait for, so it stops and shuts down cleanly.
    #[tokio::test]
    async fn test_empty() {
        Service::new()
            .wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
    }

    /// Attaching or merging empty services adds no tasks, so the result is still empty.
    #[tokio::test]
    async fn test_empty_attach_merge() {
        Service::new()
            .attach(Service::new())
            .merge(Service::new())
            .wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
    }

    /// The service finishes once its only primary task completes successfully.
    #[tokio::test]
    async fn test_completion() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        // The task holds `brx`, so `btx` stays open until the task has finished and been dropped.
        let svc = Service::new().spawn(async move {
            let _brx = brx;
            Ok(arx.await?)
        });

        assert!(!btx.is_closed());

        // Let the task finish; the service should then stop with `Ok`.
        atx.send(()).unwrap();
        svc.wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
        assert!(btx.is_closed());
    }

    /// A primary task error stops the service and is reported as `Error::Task`.
    #[tokio::test]
    async fn test_failure() {
        let svc = Service::new().spawn(async move { bail!("boom") });
        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Task(_))));
    }

    /// A panic in a primary task is resumed in the caller.
    #[tokio::test]
    #[should_panic]
    async fn test_panic() {
        let svc = Service::new().spawn(async move { panic!("boom") });
        let _ = svc.wait_for_shutdown(GRACE, std::future::pending).await;
    }

    /// On termination, shutdown signals run and the tasks they release can finish cleanly.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        // The task only completes after the shutdown signal has fired.
        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                btx.send(()).unwrap();
                Ok(())
            });

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // Request termination, then observe the task finishing as a result of the shutdown signal.
        stx.notify_one();
        brx.await.unwrap();

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
    }

    /// The service keeps running until all of its primary tasks have completed.
    #[tokio::test]
    async fn test_multiple_tasks() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .spawn(async move { Ok(brx.await?) })
            .spawn(async move { Ok(crx.await?) });

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        // Release the tasks one at a time, yielding so the service observes each completion.
        atx.send(()).unwrap();
        tokio::task::yield_now().await;

        btx.send(()).unwrap();
        tokio::task::yield_now().await;

        ctx.send(()).unwrap();
        handle.await.unwrap().unwrap();
    }

    /// One failing primary task stops the service, and the aborting task is cancelled during
    /// shutdown, which drops `arx`.
    #[tokio::test]
    async fn test_multiple_task_failure() {
        let (atx, arx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .spawn(async move { bail!("boom") });

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));
        let res = handle.await.unwrap();

        assert!(matches!(res, Err(Error::Task(_))));
        // The aborted task dropped its receiver.
        assert!(atx.is_closed());
    }

    /// An attached secondary task that never finishes does not keep the service alive. When the
    /// primary task completes, shutdown aborts the secondary task.
    #[tokio::test]
    async fn test_secondary_stuck() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let snd = Service::new().spawn_aborting(async move { Ok(brx.await?) });
        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .attach(snd);

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        // Only the primary task is released; the secondary one is left to be aborted.
        atx.send(()).unwrap();
        handle.await.unwrap().unwrap();
        assert!(btx.is_closed());
    }

    /// A secondary task completing successfully does not stop the service; the primary task does.
    #[tokio::test]
    async fn test_secondary_complete() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (mut ctx, crx) = oneshot::channel::<()>();

        // Holding `crx` lets the test observe when this secondary task has finished.
        let snd = Service::new().spawn(async move {
            let _crx = crx;
            brx.await?;
            Ok(())
        });

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .attach(snd);

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        // Finish the secondary task first and wait until it is gone; the service must keep running.
        btx.send(()).unwrap();
        ctx.closed().await;
        tokio::task::yield_now().await;

        // Only completing the primary task should stop the service.
        atx.send(()).unwrap();
        handle.await.unwrap().unwrap();
    }

    /// A failing secondary task stops the service, and primary aborting tasks are cancelled.
    #[tokio::test]
    async fn test_secondary_failure() {
        let (atx, arx) = oneshot::channel::<()>();

        let snd = Service::new().spawn(async move { bail!("boom") });
        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .attach(snd);

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Task(_))));
        assert!(atx.is_closed());
    }

    /// A panic in a secondary task is resumed in the caller.
    #[tokio::test]
    #[should_panic]
    async fn test_secondary_panic() {
        let (_atx, arx) = oneshot::channel::<()>();

        let snd = Service::new().spawn(async move { panic!("boom") });
        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .attach(snd);

        let _ = svc.wait_for_shutdown(GRACE, std::future::pending).await;
    }

    /// On termination, shutdown signals registered on an attached service also run, so its tasks
    /// can finish alongside the parent's aborted tasks.
    #[tokio::test]
    async fn test_secondary_graceful_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        // The secondary task holds `crx` and completes once its shutdown signal fires.
        let snd = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                let _crx = crx;
                Ok(arx.await?)
            });

        let svc = Service::new()
            .spawn_aborting(async move { Ok(brx.await?) })
            .attach(snd);

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // Both tasks are still running before termination is requested.
        assert!(!btx.is_closed());
        assert!(!ctx.is_closed());

        stx.notify_one();
        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
        assert!(btx.is_closed());
        assert!(ctx.is_closed());
    }

    /// Merged services contribute their primary tasks and shutdown signals to the combined service.
    #[tokio::test]
    async fn test_merge() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();
        let (dtx, drx) = oneshot::channel::<()>();
        let (etx, erx) = oneshot::channel::<()>();
        let (ftx, frx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        // Each sub-service has one task released directly by the test, and one released by its
        // own shutdown signal.
        let a = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .with_shutdown_signal(async move { ctx.send(()).unwrap() })
            .spawn(async move {
                crx.await?;
                dtx.send(()).unwrap();
                Ok(())
            });

        let b = Service::new()
            .spawn(async move { Ok(brx.await?) })
            .with_shutdown_signal(async move { etx.send(()).unwrap() })
            .spawn(async move {
                erx.await?;
                ftx.send(()).unwrap();
                Ok(())
            });

        let svc = Service::new().merge(a).merge(b);
        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // Completing the directly-released tasks is not enough to stop the merged service.
        atx.send(()).unwrap();
        tokio::task::yield_now().await;

        btx.send(()).unwrap();
        tokio::task::yield_now().await;

        // Termination runs both sub-services' shutdown signals, releasing the remaining tasks.
        stx.notify_one();
        drx.await.unwrap();
        frx.await.unwrap();

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
    }

    /// Dropping a service aborts all of its tasks, aborting or not.
    #[tokio::test]
    async fn test_drop_abort() {
        let (mut atx, arx) = oneshot::channel::<()>();
        let (mut btx, brx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .spawn_aborting(async move { Ok(brx.await?) });

        assert!(!atx.is_closed());
        assert!(!btx.is_closed());

        // Both receivers are dropped once their tasks are aborted.
        drop(svc);
        atx.closed().await;
        btx.closed().await;
    }

    /// Calling `shutdown` directly runs shutdown signals and waits for (or aborts) tasks.
    #[tokio::test]
    async fn test_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move { Ok(arx.await?) })
            .spawn_aborting(async move { Ok(brx.await?) });

        svc.shutdown().await.unwrap();
        assert!(btx.is_closed());
    }

    /// A task error during normal running triggers shutdown. A second error raised during
    /// shutdown makes the outcome `Aborted`.
    #[tokio::test]
    async fn test_error_cascade() {
        let (atx, arx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { bail!("boom") })
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                bail!("boom, again")
            });

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Aborted)));
    }

    /// With two failing tasks, the error that is not consumed while running surfaces during
    /// shutdown, so the outcome is `Aborted`.
    #[tokio::test]
    async fn test_multiple_errors() {
        let svc = Service::new()
            .spawn(async move { bail!("boom") })
            .spawn(async move { bail!("boom, again") });

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Aborted)));
    }

    /// A second termination request during shutdown aborts the service.
    #[tokio::test]
    async fn test_termination_cascade() {
        // This task never finishes, so shutdown cannot complete on its own.
        let svc = Service::new().spawn(std::future::pending());

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // First request: start shutting down.
        stx.notify_one();
        tokio::task::yield_now().await;

        // Second request: abort the stuck shutdown.
        stx.notify_one();
        tokio::task::yield_now().await;

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Aborted)));
    }

    /// A panic raised by a task during shutdown is propagated. It panics the spawned task, so
    /// unwrapping its join handle panics here.
    #[tokio::test]
    #[should_panic]
    async fn test_panic_during_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                panic!("boom")
            });

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        stx.notify_one();
        let _ = handle.await.unwrap();
    }

    /// If tasks do not finish within the grace period, the service is aborted.
    #[tokio::test(start_paused = true)]
    async fn test_graceful_shutdown_timeout() {
        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        // This task never finishes, so only the grace-period timeout can end shutdown.
        let svc = Service::new().spawn(std::future::pending());

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // Start shutdown, then jump the paused clock past the grace period.
        stx.notify_one();
        tokio::time::advance(GRACE * 2).await;

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Aborted)));
    }
}