1use async_trait::async_trait;
10use chrono::Duration;
11use mas_storage::{
12 Clock,
13 queue::{QueueWorkerRepository, Worker},
14};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{DatabaseError, ExecuteExt};
21
22pub struct PgQueueWorkerRepository<'c> {
24 conn: &'c mut PgConnection,
25}
26
27impl<'c> PgQueueWorkerRepository<'c> {
28 #[must_use]
31 pub fn new(conn: &'c mut PgConnection) -> Self {
32 Self { conn }
33 }
34}
35
36#[async_trait]
37impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
38 type Error = DatabaseError;
39
40 #[tracing::instrument(
41 name = "db.queue_worker.register",
42 skip_all,
43 fields(
44 worker.id,
45 db.query.text,
46 ),
47 err,
48 )]
49 async fn register(
50 &mut self,
51 rng: &mut (dyn RngCore + Send),
52 clock: &dyn Clock,
53 ) -> Result<Worker, Self::Error> {
54 let now = clock.now();
55 let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
56 tracing::Span::current().record("worker.id", tracing::field::display(worker_id));
57
58 sqlx::query!(
59 r#"
60 INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
61 VALUES ($1, $2, $2)
62 "#,
63 Uuid::from(worker_id),
64 now,
65 )
66 .traced()
67 .execute(&mut *self.conn)
68 .await?;
69
70 Ok(Worker { id: worker_id })
71 }
72
73 #[tracing::instrument(
74 name = "db.queue_worker.heartbeat",
75 skip_all,
76 fields(
77 %worker.id,
78 db.query.text,
79 ),
80 err,
81 )]
82 async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
83 let now = clock.now();
84 let res = sqlx::query!(
85 r#"
86 UPDATE queue_workers
87 SET last_seen_at = $2
88 WHERE queue_worker_id = $1 AND shutdown_at IS NULL
89 "#,
90 Uuid::from(worker.id),
91 now,
92 )
93 .traced()
94 .execute(&mut *self.conn)
95 .await?;
96
97 DatabaseError::ensure_affected_rows(&res, 1)?;
99
100 Ok(())
101 }
102
103 #[tracing::instrument(
104 name = "db.queue_worker.shutdown",
105 skip_all,
106 fields(
107 %worker.id,
108 db.query.text,
109 ),
110 err,
111 )]
112 async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
113 let now = clock.now();
114 let res = sqlx::query!(
115 r#"
116 UPDATE queue_workers
117 SET shutdown_at = $2
118 WHERE queue_worker_id = $1
119 "#,
120 Uuid::from(worker.id),
121 now,
122 )
123 .traced()
124 .execute(&mut *self.conn)
125 .await?;
126
127 DatabaseError::ensure_affected_rows(&res, 1)?;
128
129 let res = sqlx::query!(
131 r#"
132 DELETE FROM queue_leader
133 WHERE queue_worker_id = $1
134 "#,
135 Uuid::from(worker.id),
136 )
137 .traced()
138 .execute(&mut *self.conn)
139 .await?;
140
141 if res.rows_affected() > 0 {
143 sqlx::query!(
144 r#"
145 NOTIFY queue_leader_stepdown
146 "#,
147 )
148 .traced()
149 .execute(&mut *self.conn)
150 .await?;
151 }
152
153 Ok(())
154 }
155
156 #[tracing::instrument(
157 name = "db.queue_worker.shutdown_dead_workers",
158 skip_all,
159 fields(
160 db.query.text,
161 ),
162 err,
163 )]
164 async fn shutdown_dead_workers(
165 &mut self,
166 clock: &dyn Clock,
167 threshold: Duration,
168 ) -> Result<(), Self::Error> {
169 let now = clock.now();
173 sqlx::query!(
174 r#"
175 UPDATE queue_workers
176 SET shutdown_at = $1
177 WHERE shutdown_at IS NULL
178 AND last_seen_at < $2
179 "#,
180 now,
181 now - threshold,
182 )
183 .traced()
184 .execute(&mut *self.conn)
185 .await?;
186
187 Ok(())
188 }
189
190 #[tracing::instrument(
191 name = "db.queue_worker.remove_leader_lease_if_expired",
192 skip_all,
193 fields(
194 db.query.text,
195 ),
196 err,
197 )]
198 async fn remove_leader_lease_if_expired(
199 &mut self,
200 _clock: &dyn Clock,
201 ) -> Result<(), Self::Error> {
202 sqlx::query!(
205 r#"
206 DELETE FROM queue_leader
207 WHERE expires_at < NOW()
208 "#,
209 )
210 .traced()
211 .execute(&mut *self.conn)
212 .await?;
213
214 Ok(())
215 }
216
217 #[tracing::instrument(
218 name = "db.queue_worker.try_get_leader_lease",
219 skip_all,
220 fields(
221 %worker.id,
222 db.query.text,
223 ),
224 err,
225 )]
226 async fn try_get_leader_lease(
227 &mut self,
228 clock: &dyn Clock,
229 worker: &Worker,
230 ) -> Result<bool, Self::Error> {
231 let now = clock.now();
232 let res = sqlx::query!(
241 r#"
242 INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
243 VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
244 ON CONFLICT (active)
245 DO UPDATE SET expires_at = EXCLUDED.expires_at
246 WHERE queue_leader.queue_worker_id = $2
247 "#,
248 now,
249 Uuid::from(worker.id)
250 )
251 .traced()
252 .execute(&mut *self.conn)
253 .await?;
254
255 let am_i_the_leader = res.rows_affected() == 1;
258
259 Ok(am_i_the_leader)
260 }
261}