mas_tasks/new_queue.rs

// Copyright 2024 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_storage::{
    Clock, RepositoryAccess, RepositoryError,
    queue::{InsertableJob, Job, JobMetadata, Worker},
};
use mas_storage_pg::{DatabaseError, PgRepository};
use opentelemetry::{
    KeyValue,
    metrics::{Counter, Histogram, UpDownCounter},
};
use rand::{Rng, RngCore, distributions::Uniform};
use rand_chacha::ChaChaRng;
use serde::de::DeserializeOwned;
use sqlx::{
    Acquire, Either,
    postgres::{PgAdvisoryLock, PgListener},
};
use thiserror::Error;
use tokio::{task::JoinSet, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;

use crate::{METER, State};

type JobPayload = serde_json::Value;

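/// Context passed to a job when it is run, identifying the job and carrying
/// the tracing metadata recorded when it was scheduled.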
#[derive(Clone)]
pub struct JobContext {
    pub id: Ulid,
    pub metadata: JobMetadata,
    pub queue_name: String,
    pub attempt: usize,
    pub start: Instant,

    #[expect(
        dead_code,
        reason = "we're not yet using this, but will be in the future"
    )]
    pub cancellation_token: CancellationToken,
}

impl JobContext {
    pub fn span(&self) -> Span {
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue.name = self.queue_name,
            job.attempt = self.attempt,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}

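/// What to do with a job that returned an error: retry it later or fail it
/// permanently.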
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JobErrorDecision {
    Retry,

    #[default]
    Fail,
}

impl std::fmt::Display for JobErrorDecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry => f.write_str("retry"),
            Self::Fail => f.write_str("fail"),
        }
    }
}

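/// An error returned by a job, wrapping the underlying error together with
/// the decision of whether the job should be retried.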
#[derive(Debug, Error)]
#[error("Job failed to run, will {decision}")]
pub struct JobError {
    decision: JobErrorDecision,
    #[source]
    error: anyhow::Error,
}

impl JobError {
    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Retry,
            error: error.into(),
        }
    }

    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Fail,
            error: error.into(),
        }
    }
}

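/// Deserialize a job payload into a concrete job type.
///
/// A blanket implementation is provided for any type implementing
/// [`DeserializeOwned`].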
pub trait FromJob {
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}

impl<T> FromJob for T
where
    T: DeserializeOwned,
{
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
        serde_json::from_value(payload).map_err(Into::into)
    }
}

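/// A job which can be picked up and run by the queue worker.
///
/// A minimal sketch of an implementation; the job type and the `cleanup`
/// helper are illustrative, not part of this crate:
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct CleanupJob;
///
/// #[async_trait]
/// impl RunnableJob for CleanupJob {
///     async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
///         // Map recoverable errors to a retry decision
///         cleanup(state).await.map_err(JobError::retry)
///     }
/// }
/// ```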
#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
}

fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
    Box::new(job)
}

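/// Errors that can be raised while setting up or running the queue worker.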
#[derive(Debug, Error)]
pub enum QueueRunnerError {
    #[error("Failed to setup listener")]
    SetupListener(#[source] sqlx::Error),

    #[error("Failed to start transaction")]
    StartTransaction(#[source] sqlx::Error),

    #[error("Failed to commit transaction")]
    CommitTransaction(#[source] sqlx::Error),

    #[error("Failed to acquire leader lock")]
    LeaderLock(#[source] sqlx::Error),

    #[error(transparent)]
    Repository(#[from] RepositoryError),

    #[error(transparent)]
    Database(#[from] DatabaseError),

    #[error("Invalid schedule expression")]
    InvalidSchedule(#[from] cron::error::Error),

    #[error("Worker is not the leader")]
    NotLeader,
}

// When the worker waits for a notification, we still want to wake it up every
// second. Because we don't want all the workers to wake up at the same time, we
// add a random jitter to the sleep duration, so they effectively sleep between
// 0.9 and 1.1 seconds.
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);

// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;

// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;

// How many times a job should be attempted before it is abandoned
const MAX_ATTEMPTS: usize = 10;

/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m40s,
/// 21m20s, 42m40s
fn retry_delay(attempt: usize) -> Duration {
    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}

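// A minimal sanity check of the backoff arithmetic documented above.
#[cfg(test)]
mod retry_delay_tests {
    use super::*;

    #[test]
    fn backoff_doubles_from_five_seconds() {
        assert_eq!(retry_delay(0), Duration::seconds(5));
        assert_eq!(retry_delay(1), Duration::seconds(10));
        assert_eq!(retry_delay(4), Duration::seconds(80));
        assert_eq!(retry_delay(9), Duration::seconds(2560));
    }
}
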
type JobResult = Result<(), JobError>;
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;

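/// A recurring schedule, describing which job to enqueue on which queue every
/// time the cron expression fires.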
struct ScheduleDefinition {
    schedule_name: &'static str,
    expression: Schedule,
    queue_name: &'static str,
    payload: serde_json::Value,
}

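/// The queue worker.
///
/// It registers itself in the database, listens for notifications, reserves
/// and runs jobs, and, when it holds the leader lease, performs maintenance
/// duties such as enqueueing scheduled jobs and shutting down dead workers.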
pub struct QueueWorker {
    rng: ChaChaRng,
    clock: Box<dyn Clock + Send>,
    listener: PgListener,
    registration: Worker,
    am_i_leader: bool,
    last_heartbeat: DateTime<Utc>,
    cancellation_token: CancellationToken,
    #[expect(dead_code, reason = "This is used on Drop")]
    cancellation_guard: tokio_util::sync::DropGuard,
    state: State,
    schedules: Vec<ScheduleDefinition>,
    tracker: JobTracker,
    wakeup_reason: Counter<u64>,
    tick_time: Histogram<u64>,
}

impl QueueWorker {
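    /// Create a new queue worker and register it in the database.
    ///
    /// The Postgres listener set up here is used both for notifications and
    /// as the connection on which all the worker's transactions run.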
    #[tracing::instrument(
        name = "worker.init",
        skip_all,
        fields(worker.id)
    )]
    pub async fn new(
        state: State,
        cancellation_token: CancellationToken,
    ) -> Result<Self, QueueRunnerError> {
        let mut rng = state.rng();
        let clock = state.clock();

        let mut listener = PgListener::connect_with(state.pool())
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications of leader stepping down on this channel
        listener
            .listen("queue_leader_stepdown")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications when a job is available on this channel
        listener
            .listen("queue_available")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        let txn = listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        let registration = repo.queue_worker().register(&mut rng, &clock).await?;
        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        tracing::info!("Registered worker");
        let now = clock.now();

        let wakeup_reason = METER
            .u64_counter("job.worker.wakeups")
            .with_description("Counts how many times the worker has been woken up, and for which reason")
            .build();

        // Pre-create the reasons on the counter
        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);

        let tick_time = METER
            .u64_histogram("job.worker.tick_duration")
            .with_description(
                "How much time the worker took to tick, including performing leader duties",
            )
            .build();

        // We put a cancellation drop guard in the structure, so that when it gets
        // dropped, we're sure to cancel the token
        let cancellation_guard = cancellation_token.clone().drop_guard();

        Ok(Self {
            rng,
            clock,
            listener,
            registration,
            am_i_leader: false,
            last_heartbeat: now,
            cancellation_token,
            cancellation_guard,
            state,
            schedules: Vec::new(),
            tracker: JobTracker::new(),
            wakeup_reason,
            tick_time,
        })
    }

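    /// Register a handler for a job type on its queue.
    ///
    /// Payloads reserved from `T::QUEUE_NAME` will be deserialized into `T`
    /// and run through its [`RunnableJob`] implementation.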
    pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
        // There is a potential panic here, which is fine as it's going to be caught
        // within the job task
        let factory = |payload: JobPayload| {
            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
        };

        self.tracker
            .factories
            .insert(T::QUEUE_NAME, Arc::new(factory));
        self
    }

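    /// Add a recurring schedule, which enqueues the given job on its queue
    /// according to the cron expression. The actual scheduling is done by
    /// whichever worker currently holds the leader lease.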
    pub fn add_schedule<T: InsertableJob>(
        &mut self,
        schedule_name: &'static str,
        expression: Schedule,
        job: T,
    ) -> &mut Self {
        let payload = serde_json::to_value(job).expect("failed to serialize job payload");

        self.schedules.push(ScheduleDefinition {
            schedule_name,
            expression,
            queue_name: T::QUEUE_NAME,
            payload,
        });

        self
    }

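    /// Run the worker until its cancellation token is cancelled, then shut it
    /// down gracefully. Errors are logged rather than returned.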
    pub async fn run(mut self) {
        if let Err(e) = self.run_inner().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    }

    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
        self.setup_schedules().await?;

        while !self.cancellation_token.is_cancelled() {
            self.run_loop().await?;
        }

        self.shutdown().await?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.setup_schedules", skip_all, err)]
    pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Set up the entries in the queue_schedules table
        repo.queue_schedule().setup(&schedules).await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.run_loop", skip_all, err)]
    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
        self.wait_until_wakeup().await?;

        if self.cancellation_token.is_cancelled() {
            return Ok(());
        }

        let start = Instant::now();
        self.tick().await?;

        if self.am_i_leader {
            self.perform_leader_duties().await?;
        }

        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        self.tick_time.record(elapsed, &[]);

        Ok(())
    }

    #[tracing::instrument(name = "worker.shutdown", skip_all, err)]
    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
        tracing::info!("Shutting down worker");

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Log any jobs still running
        match self.tracker.running_jobs() {
            0 => {}
            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
        }

        // TODO: we may want to introduce a timeout here, and abort the tasks if they
        // take too long. It's fine for now, as we don't have long-running
        // tasks, most of them are idempotent, and the only effect might be that
        // the worker would shut down 'dirtily', meaning that its tasks would be
        // considered lost and later retried by another worker

        // Wait for all the jobs to finish
        self.tracker
            .process_jobs(&mut self.rng, &self.clock, &mut repo, true)
            .await?;

        // Tell the other workers we're shutting down
        // This also releases the leader election lease
        repo.queue_worker()
            .shutdown(&self.clock, &self.registration)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all, err)]
    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
        // This is to make sure we wake up every second to do the maintenance tasks
        // We add a little bit of random jitter to the duration, so that we don't get
        // fully synced workers waking up at the same time after each notification
        let sleep_duration = self
            .rng
            .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
        let wakeup_sleep = tokio::time::sleep(sleep_duration);

        tokio::select! {
            () = self.cancellation_token.cancelled() => {
                tracing::debug!("Woke up from cancellation");
            },

            () = wakeup_sleep => {
                tracing::debug!("Woke up from sleep");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
            },

            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                tracing::debug!("Joined job task");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
            },

            notification = self.listener.recv() => {
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
                match notification {
                    Ok(notification) => {
                        tracing::debug!(
                            notification.channel = notification.channel(),
                            notification.payload = notification.payload(),
                            "Woke up from notification"
                        );
                    },
                    Err(e) => {
                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
                    },
                }
            },
        }

        Ok(())
    }

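    /// One iteration of the worker loop: send a heartbeat if needed, take
    /// part in the leader election, collect finished jobs and reserve new
    /// ones.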
    #[tracing::instrument(
        name = "worker.tick",
        skip_all,
        fields(worker.id = %self.registration.id),
        err,
    )]
    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
        tracing::debug!("Tick");
        let now = self.clock.now();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // We send a heartbeat every minute, to avoid writing too often to the
        // database, as this is a logged table
        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
            tracing::info!("Sending heartbeat");
            repo.queue_worker()
                .heartbeat(&self.clock, &self.registration)
                .await?;
            self.last_heartbeat = now;
        }

        // Remove any dead worker leader leases
        repo.queue_worker()
            .remove_leader_lease_if_expired(&self.clock)
            .await?;

        // Try to become (or stay) the leader
        let leader = repo
            .queue_worker()
            .try_get_leader_lease(&self.clock, &self.registration)
            .await?;

        // Process any job task which finished
        self.tracker
            .process_jobs(&mut self.rng, &self.clock, &mut repo, false)
            .await?;

        // Compute how many jobs we should fetch at most: the number of free
        // slots, capped at MAX_JOBS_TO_FETCH
        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
            .saturating_sub(self.tracker.running_jobs())
            .min(MAX_JOBS_TO_FETCH);

        if max_jobs_to_fetch == 0 {
            tracing::warn!("Internal job queue is full, not fetching any new jobs");
        } else {
            // Grab a few jobs in the queue
            let queues = self.tracker.queues();
            let jobs = repo
                .queue_job()
                .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
                .await?;

            for Job {
                id,
                queue_name,
                payload,
                metadata,
                attempt,
            } in jobs
            {
                let cancellation_token = self.cancellation_token.child_token();
                let start = Instant::now();
                let context = JobContext {
                    id,
                    metadata,
                    queue_name,
                    attempt,
                    start,
                    cancellation_token,
                };

                self.tracker.spawn_job(self.state.clone(), context, payload);
            }
        }

        // We are still holding the lock on the leader table from the lease attempt
        // above, so it's important that we commit as soon as possible to not block
        // the other workers for too long
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        // Save the new leader state to log any change
        if leader != self.am_i_leader {
            // If we flipped state, log it
            self.am_i_leader = leader;
            if self.am_i_leader {
                tracing::info!("I'm the leader now");
            } else {
                tracing::warn!("I am no longer the leader");
            }
        }

        Ok(())
    }

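    /// Perform the leader duties: enqueue jobs for due schedules, shut down
    /// dead workers and mark scheduled jobs as available.
    ///
    /// Must only be called when this worker holds the leader lease.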
    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all, err)]
    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
        // This should have been checked by the caller, but better safe than sorry
        if !self.am_i_leader {
            return Err(QueueRunnerError::NotLeader);
        }

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        // The thing with the leader election is that it locks the table during the
        // election, preventing other workers from going through the loop.
        //
        // Ideally, we would do the leader duties in the same transaction so that we
        // make sure only one worker is doing the leader duties, but that
        // would mean we would lock all the workers for the duration of the
        // duties, which is not ideal.
        //
        // So we do the duties in a separate transaction, in which we take an advisory
        // lock, so that in the very rare case where two workers think they are the
        // leader, we still don't have two workers doing the duties at the same time.
        let lock = PgAdvisoryLock::new("leader-duties");

        let locked = lock
            .try_acquire(txn)
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        let locked = match locked {
            Either::Left(locked) => locked,
            Either::Right(txn) => {
                tracing::error!("Another worker has the leader lock, aborting");
                txn.rollback()
                    .await
                    .map_err(QueueRunnerError::CommitTransaction)?;
                return Ok(());
            }
        };

        let mut repo = PgRepository::from_conn(locked);

        // Look at the state of schedules in the database
        let schedules_status = repo.queue_schedule().list().await?;

        let now = self.clock.now();
        for schedule in &self.schedules {
            // Find the schedule status from the database
            let Some(schedule_status) = schedules_status
                .iter()
                .find(|s| s.schedule_name == schedule.schedule_name)
            else {
                tracing::error!(
                    "Schedule {} was not found in the database",
                    schedule.schedule_name
                );
                continue;
            };

            // Figure out if we should schedule a new job
            if let Some(next_time) = schedule_status.last_scheduled_at {
                if next_time > now {
                    // We already have a job scheduled in the future, skip
                    continue;
                }

                if schedule_status.last_scheduled_job_completed == Some(false) {
                    // The last scheduled job has not completed yet, skip
                    continue;
                }
            }

            let next_tick = schedule.expression.after(&now).next().unwrap();

            tracing::info!(
                "Scheduling job for {}, next run at {}",
                schedule.schedule_name,
                next_tick
            );

            repo.queue_job()
                .schedule_later(
                    &mut self.rng,
                    &self.clock,
                    schedule.queue_name,
                    schedule.payload.clone(),
                    serde_json::json!({}),
                    next_tick,
                    Some(schedule.schedule_name),
                )
                .await?;
        }

        // We also shut down any dead workers that haven't checked in within the
        // last two minutes
        repo.queue_worker()
            .shutdown_dead_workers(&self.clock, Duration::minutes(2))
            .await?;

        // TODO: mark the tasks those workers had as lost

        // Mark all the scheduled jobs as available
        let scheduled = repo
            .queue_job()
            .schedule_available_jobs(&self.clock)
            .await?;
        match scheduled {
            0 => {}
            1 => tracing::info!("One scheduled job marked as available"),
            n => tracing::info!("{n} scheduled jobs marked as available"),
        }

        // Release the leader lock
        let txn = repo
            .into_inner()
            .release_now()
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        txn.commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }
}

/// Tracks running jobs
///
/// This is a separate structure to be able to borrow it mutably at the same
/// time as the connection to the database is borrowed
struct JobTracker {
    /// Stores a mapping from the job queue name to the job factory
    factories: HashMap<&'static str, JobFactory>,

    /// A join set of all the currently running jobs
    running_jobs: JoinSet<JobResult>,

    /// Stores a mapping from the Tokio task ID to the job context
    job_contexts: HashMap<tokio::task::Id, JobContext>,

    /// Stores the last `join_next_with_id` result for processing, in case we
    /// got woken up in `collect_next_job`
    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,

    /// A histogram which records the time it takes to process a job
    job_processing_time: Histogram<u64>,

    /// A counter which records the number of jobs currently in flight
    in_flight_jobs: UpDownCounter<i64>,
}

impl JobTracker {
    fn new() -> Self {
        let job_processing_time = METER
            .u64_histogram("job.process.duration")
            .with_description("The time it takes to process a job in milliseconds")
            .with_unit("ms")
            .build();

        let in_flight_jobs = METER
            .i64_up_down_counter("job.active_tasks")
            .with_description("The number of jobs currently in flight")
            .with_unit("{job}")
            .build();

        Self {
            factories: HashMap::new(),
            running_jobs: JoinSet::new(),
            job_contexts: HashMap::new(),
            last_join_result: None,
            job_processing_time,
            in_flight_jobs,
        }
    }

    /// Returns the queue names that are currently being tracked
    fn queues(&self) -> Vec<&'static str> {
        self.factories.keys().copied().collect()
    }

    /// Spawn a job on the job tracker
    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
        let factory = self.factories.get(context.queue_name.as_str()).cloned();
        let task = {
            let context = context.clone();
            let span = context.span();
            async move {
                // We should never panic here, but if we do, the panic happens inside
                // the task and doesn't crash the worker
                let job = factory.expect("unknown job factory")(payload);
                tracing::info!("Running job");
                job.run(&state, context).await
            }
            .instrument(span)
        };

        self.in_flight_jobs.add(
            1,
            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
        );

        let handle = self.running_jobs.spawn(task);
        self.job_contexts.insert(handle.id(), context);
    }

    /// Returns `true` if there are currently running jobs
    fn has_jobs(&self) -> bool {
        !self.running_jobs.is_empty()
    }

    /// Returns the number of currently running jobs
    ///
    /// This also includes the job result which may be stored for processing
    fn running_jobs(&self) -> usize {
        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
    }

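    /// Wait for the next job task to finish and stash its result so that
    /// `process_jobs` can handle it.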
    async fn collect_next_job(&mut self) {
        // Double-check that we don't have a job result stored
        if self.last_join_result.is_some() {
            tracing::error!(
                "Job tracker already had a job result stored, this should never happen!"
            );
            return;
        }

        self.last_join_result = self.running_jobs.join_next_with_id().await;
    }

    /// Process all the jobs which are currently running
    ///
    /// If `blocking` is `true`, this function will block until all the jobs
    /// are finished. Otherwise, it will return as soon as it processed the
    /// already finished jobs.
    #[allow(clippy::too_many_lines)]
    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        repo: &mut dyn RepositoryAccess<Error = E>,
        blocking: bool,
    ) -> Result<(), E> {
        if self.last_join_result.is_none() {
            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        // XXX: the time measurement isn't accurate, as it would include the
        // time spent between the task finishing, and us processing the result.
        // It's fine for now, as it at least gives us an idea of how many tasks
        // we run, and what their status is

        while let Some(result) = self.last_join_result.take() {
            match result {
                // The job succeeded
                Ok((id, Ok(()))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);
                    tracing::info!(
                        job.id = %context.id,
                        job.queue.name = %context.queue_name,
                        job.attempt = %context.attempt,
                        job.elapsed = format!("{elapsed}ms"),
                        "Job completed"
                    );

                    self.job_processing_time.record(
                        elapsed,
                        &[
                            KeyValue::new("job.queue.name", context.queue_name),
                            KeyValue::new("job.result", "success"),
                        ],
                    );

                    repo.queue_job()
                        .mark_as_completed(clock, context.id)
                        .await?;
                }

                // The job failed
                Ok((id, Err(e))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let reason = format!("{:?}", e.error);
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    match e.decision {
                        JobErrorDecision::Fail => {
                            tracing::error!(
                                error = &e as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                job.elapsed = format!("{elapsed}ms"),
                                "Job failed, not retrying"
                            );

                            self.job_processing_time.record(
                                elapsed,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "fail"),
                                ],
                            );
                        }

                        JobErrorDecision::Retry => {
                            if context.attempt < MAX_ATTEMPTS {
                                let delay = retry_delay(context.attempt);
                                tracing::warn!(
                                    error = &e as &dyn std::error::Error,
                                    job.id = %context.id,
                                    job.queue.name = %context.queue_name,
                                    job.attempt = %context.attempt,
                                    job.elapsed = format!("{elapsed}ms"),
                                    "Job failed, will retry in {}s",
                                    delay.num_seconds()
                                );

                                self.job_processing_time.record(
                                    elapsed,
                                    &[
                                        KeyValue::new("job.queue.name", context.queue_name),
                                        KeyValue::new("job.result", "failed"),
                                        KeyValue::new("job.decision", "retry"),
                                    ],
                                );

                                repo.queue_job()
                                    .retry(&mut *rng, clock, context.id, delay)
                                    .await?;
                            } else {
                                tracing::error!(
                                    error = &e as &dyn std::error::Error,
                                    job.id = %context.id,
                                    job.queue.name = %context.queue_name,
                                    job.attempt = %context.attempt,
                                    job.elapsed = format!("{elapsed}ms"),
                                    "Job failed too many times, abandoning"
                                );

                                self.job_processing_time.record(
                                    elapsed,
                                    &[
                                        KeyValue::new("job.queue.name", context.queue_name),
                                        KeyValue::new("job.result", "failed"),
                                        KeyValue::new("job.decision", "abandon"),
                                    ],
                                );
                            }
                        }
                    }
                }

                // The job crashed (or was aborted)
                Err(e) => {
                    let id = e.id();
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    let reason = e.to_string();
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    if context.attempt < MAX_ATTEMPTS {
                        let delay = retry_delay(context.attempt);
                        tracing::warn!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed, will retry in {}s",
                            delay.num_seconds()
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "retry"),
                            ],
                        );

                        repo.queue_job()
                            .retry(&mut *rng, clock, context.id, delay)
                            .await?;
                    } else {
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed too many times, abandoning"
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "abandon"),
                            ],
                        );
                    }
                }
            }

            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        Ok(())
    }
}