mas_storage_pg/queue/
worker.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! A module containing the PostgreSQL implementation of the
7//! [`QueueWorkerRepository`].
8
9use async_trait::async_trait;
10use chrono::Duration;
11use mas_storage::{
12    Clock,
13    queue::{QueueWorkerRepository, Worker},
14};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{DatabaseError, ExecuteExt};
21
22/// An implementation of [`QueueWorkerRepository`] for a PostgreSQL connection.
23pub struct PgQueueWorkerRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgQueueWorkerRepository<'c> {
28    /// Create a new [`PgQueueWorkerRepository`] from an active PostgreSQL
29    /// connection.
30    #[must_use]
31    pub fn new(conn: &'c mut PgConnection) -> Self {
32        Self { conn }
33    }
34}
35
36#[async_trait]
37impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
38    type Error = DatabaseError;
39
40    #[tracing::instrument(
41        name = "db.queue_worker.register",
42        skip_all,
43        fields(
44            worker.id,
45            db.query.text,
46        ),
47        err,
48    )]
49    async fn register(
50        &mut self,
51        rng: &mut (dyn RngCore + Send),
52        clock: &dyn Clock,
53    ) -> Result<Worker, Self::Error> {
54        let now = clock.now();
55        let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
56        tracing::Span::current().record("worker.id", tracing::field::display(worker_id));
57
58        sqlx::query!(
59            r#"
60                INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
61                VALUES ($1, $2, $2)
62            "#,
63            Uuid::from(worker_id),
64            now,
65        )
66        .traced()
67        .execute(&mut *self.conn)
68        .await?;
69
70        Ok(Worker { id: worker_id })
71    }
72
73    #[tracing::instrument(
74        name = "db.queue_worker.heartbeat",
75        skip_all,
76        fields(
77            %worker.id,
78            db.query.text,
79        ),
80        err,
81    )]
82    async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
83        let now = clock.now();
84        let res = sqlx::query!(
85            r#"
86                UPDATE queue_workers
87                SET last_seen_at = $2
88                WHERE queue_worker_id = $1 AND shutdown_at IS NULL
89            "#,
90            Uuid::from(worker.id),
91            now,
92        )
93        .traced()
94        .execute(&mut *self.conn)
95        .await?;
96
97        // If no row was updated, the worker was shutdown so we return an error
98        DatabaseError::ensure_affected_rows(&res, 1)?;
99
100        Ok(())
101    }
102
103    #[tracing::instrument(
104        name = "db.queue_worker.shutdown",
105        skip_all,
106        fields(
107            %worker.id,
108            db.query.text,
109        ),
110        err,
111    )]
112    async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
113        let now = clock.now();
114        let res = sqlx::query!(
115            r#"
116                UPDATE queue_workers
117                SET shutdown_at = $2
118                WHERE queue_worker_id = $1
119            "#,
120            Uuid::from(worker.id),
121            now,
122        )
123        .traced()
124        .execute(&mut *self.conn)
125        .await?;
126
127        DatabaseError::ensure_affected_rows(&res, 1)?;
128
129        // Remove the leader lease if we were holding it
130        let res = sqlx::query!(
131            r#"
132                DELETE FROM queue_leader
133                WHERE queue_worker_id = $1
134            "#,
135            Uuid::from(worker.id),
136        )
137        .traced()
138        .execute(&mut *self.conn)
139        .await?;
140
141        // If we were holding the leader lease, notify workers
142        if res.rows_affected() > 0 {
143            sqlx::query!(
144                r#"
145                    NOTIFY queue_leader_stepdown
146                "#,
147            )
148            .traced()
149            .execute(&mut *self.conn)
150            .await?;
151        }
152
153        Ok(())
154    }
155
156    #[tracing::instrument(
157        name = "db.queue_worker.shutdown_dead_workers",
158        skip_all,
159        fields(
160            db.query.text,
161        ),
162        err,
163    )]
164    async fn shutdown_dead_workers(
165        &mut self,
166        clock: &dyn Clock,
167        threshold: Duration,
168    ) -> Result<(), Self::Error> {
169        // Here the threshold is usually set to a few minutes, so we don't need to use
170        // the database time, as we can assume worker clocks have less than a minute
171        // skew between each other, else other things would break
172        let now = clock.now();
173        sqlx::query!(
174            r#"
175                UPDATE queue_workers
176                SET shutdown_at = $1
177                WHERE shutdown_at IS NULL
178                  AND last_seen_at < $2
179            "#,
180            now,
181            now - threshold,
182        )
183        .traced()
184        .execute(&mut *self.conn)
185        .await?;
186
187        Ok(())
188    }
189
190    #[tracing::instrument(
191        name = "db.queue_worker.remove_leader_lease_if_expired",
192        skip_all,
193        fields(
194            db.query.text,
195        ),
196        err,
197    )]
198    async fn remove_leader_lease_if_expired(
199        &mut self,
200        _clock: &dyn Clock,
201    ) -> Result<(), Self::Error> {
202        // `expires_at` is a rare exception where we use the database time, as this
203        // would be very sensitive to clock skew between workers
204        sqlx::query!(
205            r#"
206                DELETE FROM queue_leader
207                WHERE expires_at < NOW()
208            "#,
209        )
210        .traced()
211        .execute(&mut *self.conn)
212        .await?;
213
214        Ok(())
215    }
216
217    #[tracing::instrument(
218        name = "db.queue_worker.try_get_leader_lease",
219        skip_all,
220        fields(
221            %worker.id,
222            db.query.text,
223        ),
224        err,
225    )]
226    async fn try_get_leader_lease(
227        &mut self,
228        clock: &dyn Clock,
229        worker: &Worker,
230    ) -> Result<bool, Self::Error> {
231        let now = clock.now();
232        // The queue_leader table is meant to only have a single row, which conflicts on
233        // the `active` column
234
235        // If there is a conflict, we update the `expires_at` column ONLY IF the current
236        // leader is ourselves.
237
238        // `expires_at` is a rare exception where we use the database time, as this
239        // would be very sensitive to clock skew between workers
240        let res = sqlx::query!(
241            r#"
242                INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
243                VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
244                ON CONFLICT (active)
245                DO UPDATE SET expires_at = EXCLUDED.expires_at
246                WHERE queue_leader.queue_worker_id = $2
247            "#,
248            now,
249            Uuid::from(worker.id)
250        )
251        .traced()
252        .execute(&mut *self.conn)
253        .await?;
254
255        // We can then detect whether we are the leader or not by checking how many rows
256        // were affected by the upsert
257        let am_i_the_leader = res.rows_affected() == 1;
258
259        Ok(am_i_the_leader)
260    }
261}