mas_storage_pg/queue/job.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementation of the
//! [`QueueJobRepository`].
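//!
//! A rough usage sketch (not a doctest): the connection, clock and RNG below
//! are placeholders for whatever the surrounding application provides, and
//! `"send-email"` is a made-up queue name.
//!
//! ```ignore
//! let mut repo = PgQueueJobRepository::new(&mut conn);
//! repo.schedule(
//!     &mut rng,
//!     &clock,
//!     "send-email",                           // hypothetical queue name
//!     serde_json::json!({ "user": "alice" }), // job payload
//!     serde_json::json!({}),                  // job metadata
//! )
//! .await?;
//! ```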

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::Clock;
use mas_storage::queue::{Job, QueueJobRepository, Worker};
use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
use rand::RngCore;
use sqlx::PgConnection;
use tracing::Instrument;
use ulid::Ulid;
use uuid::Uuid;

use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};

/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
pub struct PgQueueJobRepository<'c> {
    conn: &'c mut PgConnection,
}

impl<'c> PgQueueJobRepository<'c> {
    /// Create a new [`PgQueueJobRepository`] from an active PostgreSQL
    /// connection.
    #[must_use]
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}

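/// Intermediate row type matching the columns returned by the reservation
/// query in `reserve`, later converted into a [`Job`].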
struct JobReservationResult {
    queue_job_id: Uuid,
    queue_name: String,
    payload: serde_json::Value,
    metadata: serde_json::Value,
    attempt: i32,
}

impl TryFrom<JobReservationResult> for Job {
    type Error = DatabaseInconsistencyError;

    fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
        let id = value.queue_job_id.into();
        let queue_name = value.queue_name;
        let payload = value.payload;

        let metadata = serde_json::from_value(value.metadata).map_err(|e| {
            DatabaseInconsistencyError::on("queue_jobs")
                .column("metadata")
                .row(id)
                .source(e)
        })?;

        let attempt = value.attempt.try_into().map_err(|e| {
            DatabaseInconsistencyError::on("queue_jobs")
                .column("attempt")
                .row(id)
                .source(e)
        })?;

        Ok(Self {
            id,
            queue_name,
            payload,
            metadata,
            attempt,
        })
    }
}

#[async_trait]
impl QueueJobRepository for PgQueueJobRepository<'_> {
    type Error = DatabaseError;

    #[tracing::instrument(
        name = "db.queue_job.schedule",
        fields(
            queue_job.id,
            queue_job.queue_name = queue_name,
            db.query.text,
        ),
        skip_all,
        err,
    )]
    async fn schedule(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        queue_name: &str,
        payload: serde_json::Value,
        metadata: serde_json::Value,
    ) -> Result<(), Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record("queue_job.id", tracing::field::display(id));

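        // The status column is left at its default here (expected to be
        // 'available'), so the job can be picked up by `reserve` right away;
        // `schedule_later` inserts the job as 'scheduled' instead.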
        sqlx::query!(
            r#"
                INSERT INTO queue_jobs
                    (queue_job_id, queue_name, payload, metadata, created_at)
                VALUES ($1, $2, $3, $4, $5)
            "#,
            Uuid::from(id),
            queue_name,
            payload,
            metadata,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_job.schedule_later",
        fields(
            queue_job.id,
            queue_job.queue_name = queue_name,
            queue_job.scheduled_at = %scheduled_at,
            db.query.text,
        ),
        skip_all,
        err,
    )]
    async fn schedule_later(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        queue_name: &str,
        payload: serde_json::Value,
        metadata: serde_json::Value,
        scheduled_at: DateTime<Utc>,
        schedule_name: Option<&str>,
    ) -> Result<(), Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record("queue_job.id", tracing::field::display(id));

        sqlx::query!(
            r#"
                INSERT INTO queue_jobs
                    (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, schedule_name, status)
                VALUES ($1, $2, $3, $4, $5, $6, $7, 'scheduled')
            "#,
            Uuid::from(id),
            queue_name,
            payload,
            metadata,
            created_at,
            scheduled_at,
            schedule_name,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // If there was a schedule name supplied, update the queue_schedules table
        if let Some(schedule_name) = schedule_name {
            let span = tracing::info_span!(
                "db.queue_job.schedule_later.update_schedules",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            let res = sqlx::query!(
                r#"
                    UPDATE queue_schedules
                    SET last_scheduled_at = $1,
                        last_scheduled_job_id = $2
                    WHERE schedule_name = $3
                "#,
                scheduled_at,
                Uuid::from(id),
                schedule_name,
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;

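            // The schedule row must already exist; updating an unknown
            // schedule name is treated as a database error below.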
            DatabaseError::ensure_affected_rows(&res, 1)?;
        }

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_job.reserve",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn reserve(
        &mut self,
        clock: &dyn Clock,
        worker: &Worker,
        queues: &[&str],
        count: usize,
    ) -> Result<Vec<Job>, Self::Error> {
        let now = clock.now();
        let max_count = i64::try_from(count).unwrap_or(i64::MAX);
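        // Collect the queue names into owned `String`s so they can be bound
        // as a single text-array parameter (`ANY($1)`) in the query below.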
        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
        let results = sqlx::query_as!(
            JobReservationResult,
            r#"
                -- We first grab a few jobs that are available,
                -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
                -- and we don't get multiple workers grabbing the same jobs
                WITH locked_jobs AS (
                    SELECT queue_job_id
                    FROM queue_jobs
                    WHERE
                        status = 'available'
                        AND queue_name = ANY($1)
                    ORDER BY queue_job_id ASC
                    LIMIT $2
                    FOR UPDATE
                    SKIP LOCKED
                )
                -- then we update the status of those jobs to 'running', returning the job details
                UPDATE queue_jobs
                SET status = 'running', started_at = $3, started_by = $4
                FROM locked_jobs
                WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
                RETURNING
                    queue_jobs.queue_job_id,
                    queue_jobs.queue_name,
                    queue_jobs.payload,
                    queue_jobs.metadata,
                    queue_jobs.attempt
            "#,
            &queues,
            max_count,
            now,
            Uuid::from(worker.id),
        )
        .traced()
        .fetch_all(&mut *self.conn)
        .await?;

        let jobs = results
            .into_iter()
            .map(TryFrom::try_from)
            .collect::<Result<Vec<_>, _>>()?;

        Ok(jobs)
    }

    #[tracing::instrument(
        name = "db.queue_job.mark_as_completed",
        skip_all,
        fields(
            db.query.text,
            job.id = %id,
        ),
        err,
    )]
    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
        let now = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE queue_jobs
                SET status = 'completed', completed_at = $1
                WHERE queue_job_id = $2 AND status = 'running'
            "#,
            now,
            Uuid::from(id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

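        // Exactly one row must have been updated; completing a job that isn't
        // currently in the 'running' state is an error.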
        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_job.mark_as_failed",
        skip_all,
        fields(
            db.query.text,
            job.id = %id,
        ),
        err
    )]
    async fn mark_as_failed(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        reason: &str,
    ) -> Result<(), Self::Error> {
        let now = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE queue_jobs
                SET
                    status = 'failed',
                    failed_at = $1,
                    failed_reason = $2
                WHERE
                    queue_job_id = $3
                    AND status = 'running'
            "#,
            now,
            reason,
            Uuid::from(id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_job.retry",
        skip_all,
        fields(
            db.query.text,
            job.id = %id,
        ),
        err
    )]
    async fn retry(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        id: Ulid,
        delay: Duration,
    ) -> Result<(), Self::Error> {
        let now = clock.now();
        let scheduled_at = now + delay;
        let new_id = Ulid::from_datetime_with_source(now.into(), rng);

        let span = tracing::info_span!(
            "db.queue_job.retry.insert_job",
            { DB_QUERY_TEXT } = tracing::field::Empty
        );
        // Create a new job with the same payload and metadata, but a new ID and
        // increment the attempt
        // We make sure we do this only for 'failed' jobs
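        // The affected-rows check below turns retrying a job that isn't in the
        // 'failed' state (or doesn't exist at all) into an error.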
        let res = sqlx::query!(
            r#"
                INSERT INTO queue_jobs
                    (queue_job_id, queue_name, payload, metadata, created_at,
                     attempt, scheduled_at, schedule_name, status)
                SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, schedule_name, 'scheduled'
                FROM queue_jobs
                WHERE queue_job_id = $4
                  AND status = 'failed'
            "#,
            Uuid::from(new_id),
            now,
            scheduled_at,
            Uuid::from(id),
        )
        .record(&span)
        .execute(&mut *self.conn)
        .instrument(span)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        // If that job was referenced by a schedule, update the schedule
        let span = tracing::info_span!(
            "db.queue_job.retry.update_schedule",
            { DB_QUERY_TEXT } = tracing::field::Empty
        );
        sqlx::query!(
            r#"
                UPDATE queue_schedules
                SET last_scheduled_at = $1,
                    last_scheduled_job_id = $2
                WHERE last_scheduled_job_id = $3
            "#,
            scheduled_at,
            Uuid::from(new_id),
            Uuid::from(id),
        )
        .record(&span)
        .execute(&mut *self.conn)
        .instrument(span)
        .await?;
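        // No affected-rows check here: the old job may not be referenced by
        // any schedule, in which case this update simply matches nothing.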

        // Update the old job to point to the new attempt
        let span = tracing::info_span!(
            "db.queue_job.retry.update_old_job",
            { DB_QUERY_TEXT } = tracing::field::Empty
        );
        let res = sqlx::query!(
            r#"
                UPDATE queue_jobs
                SET next_attempt_id = $1
                WHERE queue_job_id = $2
            "#,
            Uuid::from(new_id),
            Uuid::from(id),
        )
        .record(&span)
        .execute(&mut *self.conn)
        .instrument(span)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_job.schedule_available_jobs",
        skip_all,
        fields(
            db.query.text,
        ),
        err
    )]
    async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
        let now = clock.now();
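        // Flip every 'scheduled' job whose due time has passed to 'available',
        // making it visible to `reserve`.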
        let res = sqlx::query!(
            r#"
                UPDATE queue_jobs
                SET status = 'available'
                WHERE
                    status = 'scheduled'
                    AND scheduled_at <= $1
            "#,
            now,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        let count = res.rows_affected();
        Ok(usize::try_from(count).unwrap_or(usize::MAX))
    }

    #[tracing::instrument(
        name = "db.queue_job.cleanup",
        skip_all,
        fields(
            db.query.text,
            since = since.map(tracing::field::display),
            until = %until,
            limit = limit,
        ),
        err,
    )]
    async fn cleanup(
        &mut self,
        since: Option<Ulid>,
        until: Ulid,
        limit: usize,
    ) -> Result<(usize, Option<Ulid>), Self::Error> {
        // Use ULID cursor-based pagination for completed and failed jobs.
        // We delete both completed and failed jobs in the same batch.
        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
        let res = sqlx::query_scalar!(
            r#"
                WITH to_delete AS (
                    SELECT queue_job_id
                    FROM queue_jobs
                    WHERE (status = 'completed' OR status = 'failed')
                      AND ($1::uuid IS NULL OR queue_job_id > $1)
                      AND queue_job_id <= $2
                    ORDER BY queue_job_id
                    LIMIT $3
                )
                DELETE FROM queue_jobs
                USING to_delete
                WHERE queue_jobs.queue_job_id = to_delete.queue_job_id
                RETURNING queue_jobs.queue_job_id
            "#,
            since.map(Uuid::from),
            Uuid::from(until),
            i64::try_from(limit).unwrap_or(i64::MAX)
        )
        .traced()
        .fetch_all(&mut *self.conn)
        .await?;

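        // Report how many rows were deleted, plus the highest deleted ID so
        // the caller can pass it back as `since` to continue the cursor.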
        let count = res.len();
        let max_id = res.into_iter().max();

        Ok((count, max_id.map(Ulid::from)))
    }
}