1use std::fmt;
5use std::panic;
6use std::time::Duration;
7
8use futures::future;
9use futures::future::BoxFuture;
10use futures::future::FutureExt;
11use tap::TapFallible;
12use tokio::signal;
13use tokio::task::JoinSet;
14use tokio::time::timeout;
15use tracing::error;
16use tracing::info;
17
/// Default time allowed for a service to shut down gracefully after shutdown begins, before its
/// remaining tasks are aborted (used by [`Service::main`]).
pub const GRACE: Duration = Duration::from_millis(30_000);
23
/// A group of tasks that run together and are shut down together.
///
/// Tasks are either *primary* (added with [`Service::spawn`] / [`Service::spawn_aborting`]) or
/// *secondary* (added with [`Service::attach`], or inherited from another service's secondary
/// tasks through [`Service::merge`]). The service is finished once all primary tasks complete, or
/// as soon as any task fails. Shutting down then awaits every registered shutdown signal before
/// waiting for the remaining tasks.
#[must_use = "Dropping the service aborts all its tasks immediately"]
#[derive(Default)]
pub struct Service {
    // Futures awaited (concurrently) when shutdown begins, e.g. to tell tasks to stop.
    exits: Vec<BoxFuture<'static, ()>>,

    // Primary tasks: all of them finishing is what ends the service's run.
    fsts: JoinSet<anyhow::Result<()>>,

    // Secondary tasks: a failure ends the service's run, but finishing on their own does not.
    snds: JoinSet<anyhow::Result<()>>,
}
56
/// Reasons a service stopped other than all of its primary tasks completing successfully.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Termination was requested (e.g. by [`terminate`] in [`Service::main`]), and the
    /// subsequent graceful shutdown completed.
    #[error("Service has been terminated gracefully")]
    Terminated,

    /// Graceful shutdown did not complete cleanly: the grace period elapsed, termination was
    /// requested again while shutting down, or a task failed during shutdown.
    #[error("Service has been aborted due to ungraceful shutdown")]
    Aborted,

    /// A task failed. From [`Service::main`] this means the failure happened before shutdown
    /// began, and the shutdown that followed completed cleanly. From [`Service::shutdown`] it is
    /// the failure observed while draining tasks.
    #[error(transparent)]
    Task(anyhow::Error),
}
68
69impl Service {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn spawn(
81 mut self,
82 task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
83 ) -> Self {
84 self.fsts.spawn(task);
85 self
86 }
87
88 pub fn spawn_aborting(
92 mut self,
93 task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
94 ) -> Self {
95 let h = self.fsts.spawn(task);
96 self.with_shutdown_signal(async move { h.abort() })
97 }
98
99 pub fn with_shutdown_signal(mut self, exit: impl Future<Output = ()> + Send + 'static) -> Self {
107 self.exits.push(exit.boxed());
108 self
109 }
110
111 pub fn merge(mut self, mut other: Service) -> Self {
113 self.exits.extend(other.exits);
114
115 if !other.fsts.is_empty() {
116 self.fsts.spawn(async move { run(&mut other.fsts).await });
117 }
118
119 if !other.snds.is_empty() {
120 self.snds.spawn(async move { run(&mut other.snds).await });
121 }
122
123 self
124 }
125
126 pub fn attach(mut self, mut other: Service) -> Self {
130 self.exits.extend(other.exits);
131
132 if !other.fsts.is_empty() {
133 self.snds.spawn(async move { run(&mut other.fsts).await });
134 }
135
136 if !other.snds.is_empty() {
137 self.snds.spawn(async move { run(&mut other.snds).await });
138 }
139
140 self
141 }
142
143 pub async fn main(self) -> Result<(), Error> {
146 self.wait_for_shutdown(GRACE, terminate).await
147 }
148
149 async fn wait_for_shutdown<T: Future<Output = ()>>(
160 mut self,
161 grace: Duration,
162 mut terminate: impl FnMut() -> T,
163 ) -> Result<(), Error> {
164 let exec = tokio::select! {
165 res = self.join() => {
166 res.map_err(Error::Task)
167 }
168
169 _ = terminate() => {
170 info!("Termination received");
171 Err(Error::Terminated)
172 }
173 };
174
175 info!("Shutting down gracefully...");
176 tokio::select! {
177 res = timeout(grace, self.shutdown()) => {
178 match res {
179 Ok(Ok(())) => {},
180 Ok(Err(_)) => return Err(Error::Aborted),
181 Err(_) => {
182 error!("Grace period elapsed, aborting...");
183 return Err(Error::Aborted);
184 }
185 }
186 }
187
188 _ = terminate() => {
189 error!("Termination received during shutdown, aborting...");
190 return Err(Error::Aborted);
191 },
192 }
193
194 exec
195 }
196
197 pub async fn join(&mut self) -> anyhow::Result<()> {
203 tokio::select! {
204 res = run(&mut self.fsts) => {
205 res.tap_err(|e| error!("Primary task failure: {e:#}"))
206 },
207
208 res = run_secondary(&mut self.snds) => {
209 res.tap_err(|e| error!("Secondary task failure: {e:#}"))
210 }
211 }
212 }
213
214 pub async fn shutdown(mut self) -> Result<(), Error> {
219 let _ = future::join_all(self.exits).await;
220 if let Err(e) = future::try_join(run(&mut self.fsts), run(&mut self.snds)).await {
221 error!("Task failure during shutdown: {e:#}");
222 return Err(Error::Task(e));
223 }
224
225 Ok(())
226 }
227}
228
// SAFETY: `Service` is not automatically `Sync` because the boxed futures in `exits` are only
// `Send`. Within this module those futures are never reached through a shared reference:
// - The only `&self` access to `exits` is `Vec::len` in the `Debug` impl.
// - The futures are pushed, moved, or polled only through an owned `self`
//   (`with_shutdown_signal`, `merge`, `attach`, `shutdown`).
// - `join` takes `&mut self` and does not touch `exits`.
// Sharing `&Service` across threads therefore never exposes a `!Sync` value.
// NOTE(review): this also relies on tokio's `JoinSet<anyhow::Result<()>>` being `Sync` on its
// own (confirm against the tokio version in use). Any new `&self` method that reads or polls
// `exits` would invalidate this impl.
unsafe impl Sync for Service {}
233
234impl fmt::Debug for Service {
235 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236 f.debug_struct("Service")
237 .field("exits", &self.exits.len())
238 .field("fsts", &self.fsts)
239 .field("snds", &self.snds)
240 .finish()
241 }
242}
243
244async fn run(tasks: &mut JoinSet<anyhow::Result<()>>) -> anyhow::Result<()> {
247 while let Some(res) = tasks.join_next().await {
248 match res {
249 Ok(Ok(())) => continue,
250 Ok(Err(e)) => return Err(e),
251
252 Err(e) => {
253 if e.is_panic() {
254 panic::resume_unwind(e.into_panic());
255 }
256 }
257 }
258 }
259
260 Ok(())
261}
262
263async fn run_secondary(tasks: &mut JoinSet<anyhow::Result<()>>) -> anyhow::Result<()> {
268 run(tasks).await?;
269 std::future::pending().await
270}
271
272pub async fn terminate() {
277 tokio::select! {
278 _ = signal::ctrl_c() => {},
279
280 _ = async {
281 #[cfg(unix)]
282 let _ = signal::unix::signal(signal::unix::SignalKind::terminate()).unwrap().recv().await;
283
284 #[cfg(not(unix))]
285 future::pending::<()>().await;
286 } => {}
287 }
288}
289
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::bail;
    use tokio::sync::Notify;
    use tokio::sync::oneshot;

    use super::*;

    // A service with no tasks finishes at once and shuts down cleanly.
    #[tokio::test]
    async fn test_empty() {
        Service::new()
            .wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
    }

    // Attaching and merging empty services still yields a service that finishes cleanly.
    #[tokio::test]
    async fn test_empty_attach_merge() {
        Service::new()
            .attach(Service::new())
            .merge(Service::new())
            .wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
    }

    // A primary task completing ends the run successfully. `brx` is held by the task, so `btx`
    // closing afterwards shows the task was actually dropped.
    #[tokio::test]
    async fn test_completion() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let svc = Service::new().spawn(async move {
            let _brx = brx;
            Ok(arx.await?)
        });

        assert!(!btx.is_closed());

        atx.send(()).unwrap();
        svc.wait_for_shutdown(GRACE, std::future::pending)
            .await
            .unwrap();
        assert!(btx.is_closed());
    }

    // A primary task error is reported as `Error::Task` once shutdown completes cleanly.
    #[tokio::test]
    async fn test_failure() {
        let svc = Service::new().spawn(async move { bail!("boom") });
        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Task(_))));
    }

    // A panic in a primary task is propagated to the caller.
    #[tokio::test]
    #[should_panic]
    async fn test_panic() {
        let svc = Service::new().spawn(async move { panic!("boom") });
        let _ = svc.wait_for_shutdown(GRACE, std::future::pending).await;
    }

    // On termination, the shutdown signal unblocks the task, which then finishes on its own. The
    // result is `Error::Terminated`.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                btx.send(()).unwrap();
                Ok(())
            });

        // Each call to the `terminate` closure waits for the next `notify_one` on `stx`.
        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        stx.notify_one();
        brx.await.unwrap();

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
    }

    // The run only ends once every primary task has completed.
    #[tokio::test]
    async fn test_multiple_tasks() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .spawn(async move { Ok(brx.await?) })
            .spawn(async move { Ok(crx.await?) });

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        atx.send(()).unwrap();
        tokio::task::yield_now().await;

        btx.send(()).unwrap();
        tokio::task::yield_now().await;

        ctx.send(()).unwrap();
        handle.await.unwrap().unwrap();
    }

    // One failing primary task ends the run. The other, spawned with `spawn_aborting`, is
    // cancelled during shutdown, which closes its channel.
    #[tokio::test]
    async fn test_multiple_task_failure() {
        let (atx, arx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .spawn(async move { bail!("boom") });

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));
        let res = handle.await.unwrap();

        assert!(matches!(res, Err(Error::Task(_))));
        assert!(atx.is_closed());
    }

    // An attached (secondary) task that never finishes does not block shutdown once the primary
    // task completes, because its `spawn_aborting` signal cancels it.
    #[tokio::test]
    async fn test_secondary_stuck() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let snd = Service::new().spawn_aborting(async move { Ok(brx.await?) });
        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .attach(snd);

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        atx.send(()).unwrap();
        handle.await.unwrap().unwrap();
        assert!(btx.is_closed());
    }

    // A secondary task completing does not end the run. It only ends after the primary task
    // completes.
    #[tokio::test]
    async fn test_secondary_complete() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (mut ctx, crx) = oneshot::channel::<()>();

        let snd = Service::new().spawn(async move {
            let _crx = crx;
            brx.await?;
            Ok(())
        });

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .attach(snd);

        let handle = tokio::spawn(svc.wait_for_shutdown(GRACE, std::future::pending));

        btx.send(()).unwrap();
        // `ctx` closes once the secondary task has finished and dropped `crx`.
        ctx.closed().await;
        tokio::task::yield_now().await;

        atx.send(()).unwrap();
        handle.await.unwrap().unwrap();
    }

    // A secondary task failure ends the run with `Error::Task`. The primary task is then
    // cancelled by its abort signal during shutdown.
    #[tokio::test]
    async fn test_secondary_failure() {
        let (atx, arx) = oneshot::channel::<()>();

        let snd = Service::new().spawn(async move { bail!("boom") });
        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .attach(snd);

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Task(_))));
        assert!(atx.is_closed());
    }

    // A panic in a secondary task is propagated to the caller.
    #[tokio::test]
    #[should_panic]
    async fn test_secondary_panic() {
        let (_atx, arx) = oneshot::channel::<()>();

        let snd = Service::new().spawn(async move { panic!("boom") });
        let svc = Service::new()
            .spawn_aborting(async move { Ok(arx.await?) })
            .attach(snd);

        let _ = svc.wait_for_shutdown(GRACE, std::future::pending).await;
    }

    // On termination, shutdown signals from both the outer and the attached service fire. Both
    // the aborted primary task and the signalled secondary task finish.
    #[tokio::test]
    async fn test_secondary_graceful_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let snd = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                let _crx = crx;
                Ok(arx.await?)
            });

        let svc = Service::new()
            .spawn_aborting(async move { Ok(brx.await?) })
            .attach(snd);

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        assert!(!btx.is_closed());
        assert!(!ctx.is_closed());

        stx.notify_one();
        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
        assert!(btx.is_closed());
        assert!(ctx.is_closed());
    }

    // Merged services keep their own primary tasks and shutdown signals. On termination, each
    // merged service's signal unblocks its remaining task.
    #[tokio::test]
    async fn test_merge() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();
        let (ctx, crx) = oneshot::channel::<()>();
        let (dtx, drx) = oneshot::channel::<()>();
        let (etx, erx) = oneshot::channel::<()>();
        let (ftx, frx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let a = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .with_shutdown_signal(async move { ctx.send(()).unwrap() })
            .spawn(async move {
                crx.await?;
                dtx.send(()).unwrap();
                Ok(())
            });

        let b = Service::new()
            .spawn(async move { Ok(brx.await?) })
            .with_shutdown_signal(async move { etx.send(()).unwrap() })
            .spawn(async move {
                erx.await?;
                ftx.send(()).unwrap();
                Ok(())
            });

        let svc = Service::new().merge(a).merge(b);
        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // Completing one task from each merged service must not end the run on its own.
        atx.send(()).unwrap();
        tokio::task::yield_now().await;

        btx.send(()).unwrap();
        tokio::task::yield_now().await;

        stx.notify_one();
        drx.await.unwrap();
        frx.await.unwrap();

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Terminated)));
    }

    // Dropping a service aborts all of its tasks, whether or not they were spawned as aborting.
    #[tokio::test]
    async fn test_drop_abort() {
        let (mut atx, arx) = oneshot::channel::<()>();
        let (mut btx, brx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { Ok(arx.await?) })
            .spawn_aborting(async move { Ok(brx.await?) });

        assert!(!atx.is_closed());
        assert!(!btx.is_closed());

        drop(svc);
        atx.closed().await;
        btx.closed().await;
    }

    // Calling `shutdown` directly fires the signals and waits for every task to finish.
    #[tokio::test]
    async fn test_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();
        let (btx, brx) = oneshot::channel::<()>();

        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move { Ok(arx.await?) })
            .spawn_aborting(async move { Ok(brx.await?) });

        svc.shutdown().await.unwrap();
        assert!(btx.is_closed());
    }

    // A first failure triggers shutdown. The shutdown signal makes a second task fail too,
    // which turns the result into `Error::Aborted`.
    #[tokio::test]
    async fn test_error_cascade() {
        let (atx, arx) = oneshot::channel::<()>();

        let svc = Service::new()
            .spawn(async move { bail!("boom") })
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                bail!("boom, again")
            });

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Aborted)));
    }

    // When two tasks fail, the second failure is observed during shutdown, so the result is
    // `Error::Aborted`.
    #[tokio::test]
    async fn test_multiple_errors() {
        let svc = Service::new()
            .spawn(async move { bail!("boom") })
            .spawn(async move { bail!("boom, again") });

        let res = svc.wait_for_shutdown(GRACE, std::future::pending).await;
        assert!(matches!(res, Err(Error::Aborted)));
    }

    // The first notification terminates the run. The second arrives while shutdown is waiting
    // on a task that never finishes, and aborts it.
    #[tokio::test]
    async fn test_termination_cascade() {
        let svc = Service::new().spawn(std::future::pending());

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        // NOTE(review): the yields let the service reach its next `terminate()` wait before the
        // following notification is sent.
        stx.notify_one();
        tokio::task::yield_now().await;

        stx.notify_one();
        tokio::task::yield_now().await;

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Aborted)));
    }

    // A panic raised by a task while shutdown is draining it still propagates: here it
    // resurfaces through the spawned handle's `unwrap`.
    #[tokio::test]
    #[should_panic]
    async fn test_panic_during_shutdown() {
        let (atx, arx) = oneshot::channel::<()>();

        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let svc = Service::new()
            .with_shutdown_signal(async move { atx.send(()).unwrap() })
            .spawn(async move {
                arx.await?;
                panic!("boom")
            });

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        stx.notify_one();
        let _ = handle.await.unwrap();
    }

    // With time paused, a task that never finishes keeps shutdown from completing. Advancing
    // past the grace period yields `Error::Aborted`.
    #[tokio::test(start_paused = true)]
    async fn test_graceful_shutdown_timeout() {
        let srx = Arc::new(Notify::new());
        let stx = srx.clone();

        let svc = Service::new().spawn(std::future::pending());

        let handle =
            tokio::spawn(svc.wait_for_shutdown(GRACE, move || srx.clone().notified_owned()));

        stx.notify_one();
        tokio::time::advance(GRACE * 2).await;

        let res = handle.await.unwrap();
        assert!(matches!(res, Err(Error::Aborted)));
    }
}