mas_storage_pg/queue/
job.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! A module containing the PostgreSQL implementation of the
7//! [`QueueJobRepository`].
8
9use async_trait::async_trait;
10use chrono::{DateTime, Duration, Utc};
11use mas_storage::{
12    Clock,
13    queue::{Job, QueueJobRepository, Worker},
14};
15use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
16use rand::RngCore;
17use sqlx::PgConnection;
18use tracing::Instrument;
19use ulid::Ulid;
20use uuid::Uuid;
21
22use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};
23
24/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
25pub struct PgQueueJobRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgQueueJobRepository<'c> {
30    /// Create a new [`PgQueueJobRepository`] from an active PostgreSQL
31    /// connection.
32    #[must_use]
33    pub fn new(conn: &'c mut PgConnection) -> Self {
34        Self { conn }
35    }
36}
37
38struct JobReservationResult {
39    queue_job_id: Uuid,
40    queue_name: String,
41    payload: serde_json::Value,
42    metadata: serde_json::Value,
43    attempt: i32,
44}
45
46impl TryFrom<JobReservationResult> for Job {
47    type Error = DatabaseInconsistencyError;
48
49    fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
50        let id = value.queue_job_id.into();
51        let queue_name = value.queue_name;
52        let payload = value.payload;
53
54        let metadata = serde_json::from_value(value.metadata).map_err(|e| {
55            DatabaseInconsistencyError::on("queue_jobs")
56                .column("metadata")
57                .row(id)
58                .source(e)
59        })?;
60
61        let attempt = value.attempt.try_into().map_err(|e| {
62            DatabaseInconsistencyError::on("queue_jobs")
63                .column("attempt")
64                .row(id)
65                .source(e)
66        })?;
67
68        Ok(Self {
69            id,
70            queue_name,
71            payload,
72            metadata,
73            attempt,
74        })
75    }
76}
77
78#[async_trait]
79impl QueueJobRepository for PgQueueJobRepository<'_> {
80    type Error = DatabaseError;
81
82    #[tracing::instrument(
83        name = "db.queue_job.schedule",
84        fields(
85            queue_job.id,
86            queue_job.queue_name = queue_name,
87            db.query.text,
88        ),
89        skip_all,
90        err,
91    )]
92    async fn schedule(
93        &mut self,
94        rng: &mut (dyn RngCore + Send),
95        clock: &dyn Clock,
96        queue_name: &str,
97        payload: serde_json::Value,
98        metadata: serde_json::Value,
99    ) -> Result<(), Self::Error> {
100        let created_at = clock.now();
101        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
102        tracing::Span::current().record("queue_job.id", tracing::field::display(id));
103
104        sqlx::query!(
105            r#"
106                INSERT INTO queue_jobs
107                    (queue_job_id, queue_name, payload, metadata, created_at)
108                VALUES ($1, $2, $3, $4, $5)
109            "#,
110            Uuid::from(id),
111            queue_name,
112            payload,
113            metadata,
114            created_at,
115        )
116        .traced()
117        .execute(&mut *self.conn)
118        .await?;
119
120        Ok(())
121    }
122
123    #[tracing::instrument(
124        name = "db.queue_job.schedule_later",
125        fields(
126            queue_job.id,
127            queue_job.queue_name = queue_name,
128            queue_job.scheduled_at = %scheduled_at,
129            db.query.text,
130        ),
131        skip_all,
132        err,
133    )]
134    async fn schedule_later(
135        &mut self,
136        rng: &mut (dyn RngCore + Send),
137        clock: &dyn Clock,
138        queue_name: &str,
139        payload: serde_json::Value,
140        metadata: serde_json::Value,
141        scheduled_at: DateTime<Utc>,
142        schedule_name: Option<&str>,
143    ) -> Result<(), Self::Error> {
144        let created_at = clock.now();
145        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
146        tracing::Span::current().record("queue_job.id", tracing::field::display(id));
147
148        sqlx::query!(
149            r#"
150                INSERT INTO queue_jobs
151                    (queue_job_id, queue_name, payload, metadata, created_at, scheduled_at, schedule_name, status)
152                VALUES ($1, $2, $3, $4, $5, $6, $7, 'scheduled')
153            "#,
154            Uuid::from(id),
155            queue_name,
156            payload,
157            metadata,
158            created_at,
159            scheduled_at,
160            schedule_name,
161        )
162        .traced()
163        .execute(&mut *self.conn)
164        .await?;
165
166        // If there was a schedule name supplied, update the queue_schedules table
167        if let Some(schedule_name) = schedule_name {
168            let span = tracing::info_span!(
169                "db.queue_job.schedule_later.update_schedules",
170                { DB_QUERY_TEXT } = tracing::field::Empty,
171            );
172
173            let res = sqlx::query!(
174                r#"
175                    UPDATE queue_schedules
176                    SET last_scheduled_at = $1,
177                        last_scheduled_job_id = $2
178                    WHERE schedule_name = $3
179                "#,
180                scheduled_at,
181                Uuid::from(id),
182                schedule_name,
183            )
184            .record(&span)
185            .execute(&mut *self.conn)
186            .instrument(span)
187            .await?;
188
189            DatabaseError::ensure_affected_rows(&res, 1)?;
190        }
191
192        Ok(())
193    }
194
195    #[tracing::instrument(
196        name = "db.queue_job.reserve",
197        skip_all,
198        fields(
199            db.query.text,
200        ),
201        err,
202    )]
203    async fn reserve(
204        &mut self,
205        clock: &dyn Clock,
206        worker: &Worker,
207        queues: &[&str],
208        count: usize,
209    ) -> Result<Vec<Job>, Self::Error> {
210        let now = clock.now();
211        let max_count = i64::try_from(count).unwrap_or(i64::MAX);
212        let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
213        let results = sqlx::query_as!(
214            JobReservationResult,
215            r#"
216                -- We first grab a few jobs that are available,
217                -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
218                -- and we don't get multiple workers grabbing the same jobs
219                WITH locked_jobs AS (
220                    SELECT queue_job_id
221                    FROM queue_jobs
222                    WHERE
223                        status = 'available'
224                        AND queue_name = ANY($1)
225                    ORDER BY queue_job_id ASC
226                    LIMIT $2
227                    FOR UPDATE
228                    SKIP LOCKED
229                )
230                -- then we update the status of those jobs to 'running', returning the job details
231                UPDATE queue_jobs
232                SET status = 'running', started_at = $3, started_by = $4
233                FROM locked_jobs
234                WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
235                RETURNING
236                    queue_jobs.queue_job_id,
237                    queue_jobs.queue_name,
238                    queue_jobs.payload,
239                    queue_jobs.metadata,
240                    queue_jobs.attempt
241            "#,
242            &queues,
243            max_count,
244            now,
245            Uuid::from(worker.id),
246        )
247        .traced()
248        .fetch_all(&mut *self.conn)
249        .await?;
250
251        let jobs = results
252            .into_iter()
253            .map(TryFrom::try_from)
254            .collect::<Result<Vec<_>, _>>()?;
255
256        Ok(jobs)
257    }
258
259    #[tracing::instrument(
260        name = "db.queue_job.mark_as_completed",
261        skip_all,
262        fields(
263            db.query.text,
264            job.id = %id,
265        ),
266        err,
267    )]
268    async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
269        let now = clock.now();
270        let res = sqlx::query!(
271            r#"
272                UPDATE queue_jobs
273                SET status = 'completed', completed_at = $1
274                WHERE queue_job_id = $2 AND status = 'running'
275            "#,
276            now,
277            Uuid::from(id),
278        )
279        .traced()
280        .execute(&mut *self.conn)
281        .await?;
282
283        DatabaseError::ensure_affected_rows(&res, 1)?;
284
285        Ok(())
286    }
287
288    #[tracing::instrument(
289        name = "db.queue_job.mark_as_failed",
290        skip_all,
291        fields(
292            db.query.text,
293            job.id = %id,
294        ),
295        err
296    )]
297    async fn mark_as_failed(
298        &mut self,
299        clock: &dyn Clock,
300        id: Ulid,
301        reason: &str,
302    ) -> Result<(), Self::Error> {
303        let now = clock.now();
304        let res = sqlx::query!(
305            r#"
306                UPDATE queue_jobs
307                SET
308                    status = 'failed',
309                    failed_at = $1,
310                    failed_reason = $2
311                WHERE
312                    queue_job_id = $3
313                    AND status = 'running'
314            "#,
315            now,
316            reason,
317            Uuid::from(id),
318        )
319        .traced()
320        .execute(&mut *self.conn)
321        .await?;
322
323        DatabaseError::ensure_affected_rows(&res, 1)?;
324
325        Ok(())
326    }
327
328    #[tracing::instrument(
329        name = "db.queue_job.retry",
330        skip_all,
331        fields(
332            db.query.text,
333            job.id = %id,
334        ),
335        err
336    )]
337    async fn retry(
338        &mut self,
339        rng: &mut (dyn RngCore + Send),
340        clock: &dyn Clock,
341        id: Ulid,
342        delay: Duration,
343    ) -> Result<(), Self::Error> {
344        let now = clock.now();
345        let scheduled_at = now + delay;
346        let new_id = Ulid::from_datetime_with_source(now.into(), rng);
347
348        let span = tracing::info_span!(
349            "db.queue_job.retry.insert_job",
350            { DB_QUERY_TEXT } = tracing::field::Empty
351        );
352        // Create a new job with the same payload and metadata, but a new ID and
353        // increment the attempt
354        // We make sure we do this only for 'failed' jobs
355        let res = sqlx::query!(
356            r#"
357                INSERT INTO queue_jobs
358                    (queue_job_id, queue_name, payload, metadata, created_at,
359                     attempt, scheduled_at, schedule_name, status)
360                SELECT $1, queue_name, payload, metadata, $2, attempt + 1, $3, schedule_name, 'scheduled'
361                FROM queue_jobs
362                WHERE queue_job_id = $4
363                  AND status = 'failed'
364            "#,
365            Uuid::from(new_id),
366            now,
367            scheduled_at,
368            Uuid::from(id),
369        )
370        .record(&span)
371        .execute(&mut *self.conn)
372        .instrument(span)
373        .await?;
374
375        DatabaseError::ensure_affected_rows(&res, 1)?;
376
377        // If that job was referenced by a schedule, update the schedule
378        let span = tracing::info_span!(
379            "db.queue_job.retry.update_schedule",
380            { DB_QUERY_TEXT } = tracing::field::Empty
381        );
382        sqlx::query!(
383            r#"
384                UPDATE queue_schedules
385                SET last_scheduled_at = $1,
386                    last_scheduled_job_id = $2
387                WHERE last_scheduled_job_id = $3
388            "#,
389            scheduled_at,
390            Uuid::from(new_id),
391            Uuid::from(id),
392        )
393        .record(&span)
394        .execute(&mut *self.conn)
395        .instrument(span)
396        .await?;
397
398        // Update the old job to point to the new attempt
399        let span = tracing::info_span!(
400            "db.queue_job.retry.update_old_job",
401            { DB_QUERY_TEXT } = tracing::field::Empty
402        );
403        let res = sqlx::query!(
404            r#"
405                UPDATE queue_jobs
406                SET next_attempt_id = $1
407                WHERE queue_job_id = $2
408            "#,
409            Uuid::from(new_id),
410            Uuid::from(id),
411        )
412        .record(&span)
413        .execute(&mut *self.conn)
414        .instrument(span)
415        .await?;
416
417        DatabaseError::ensure_affected_rows(&res, 1)?;
418
419        Ok(())
420    }
421
422    #[tracing::instrument(
423        name = "db.queue_job.schedule_available_jobs",
424        skip_all,
425        fields(
426            db.query.text,
427        ),
428        err
429    )]
430    async fn schedule_available_jobs(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
431        let now = clock.now();
432        let res = sqlx::query!(
433            r#"
434                UPDATE queue_jobs
435                SET status = 'available'
436                WHERE
437                    status = 'scheduled'
438                    AND scheduled_at <= $1
439            "#,
440            now,
441        )
442        .traced()
443        .execute(&mut *self.conn)
444        .await?;
445
446        let count = res.rows_affected();
447        Ok(usize::try_from(count).unwrap_or(usize::MAX))
448    }
449}