1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Duration, Utc};
12use mas_data_model::{Clock, UserEmail, UserRecoverySession, UserRecoveryTicket};
13use mas_storage::user::UserRecoveryRepository;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, ExecuteExt};
20
21pub struct PgUserRecoveryRepository<'c> {
23 conn: &'c mut PgConnection,
24}
25
26impl<'c> PgUserRecoveryRepository<'c> {
27 pub fn new(conn: &'c mut PgConnection) -> Self {
30 Self { conn }
31 }
32}
33
/// Raw row of the `user_recovery_sessions` table, as selected by
/// `lookup_session`.
///
/// `sqlx::query_as!` fills this struct by column name, so field names must
/// match the selected column names exactly.
struct UserRecoverySessionRow {
    user_recovery_session_id: Uuid,
    email: String,
    user_agent: String,
    // The query overrides this column's type with `"ip_address: IpAddr"`.
    ip_address: Option<IpAddr>,
    locale: String,
    created_at: DateTime<Utc>,
    // `None` while the session has not been consumed yet.
    consumed_at: Option<DateTime<Utc>>,
}
43
44impl From<UserRecoverySessionRow> for UserRecoverySession {
45 fn from(row: UserRecoverySessionRow) -> Self {
46 UserRecoverySession {
47 id: row.user_recovery_session_id.into(),
48 email: row.email,
49 user_agent: row.user_agent,
50 ip_address: row.ip_address,
51 locale: row.locale,
52 created_at: row.created_at,
53 consumed_at: row.consumed_at,
54 }
55 }
56}
57
/// Raw row of the `user_recovery_tickets` table, as selected by
/// `find_ticket`.
///
/// `sqlx::query_as!` fills this struct by column name, so field names must
/// match the selected column names exactly.
struct UserRecoveryTicketRow {
    user_recovery_ticket_id: Uuid,
    // The recovery session this ticket was issued for.
    user_recovery_session_id: Uuid,
    // The user email this ticket was issued for.
    user_email_id: Uuid,
    // The ticket string itself; `find_ticket` looks tickets up by this value.
    ticket: String,
    created_at: DateTime<Utc>,
    expires_at: DateTime<Utc>,
}
66
67impl From<UserRecoveryTicketRow> for UserRecoveryTicket {
68 fn from(row: UserRecoveryTicketRow) -> Self {
69 Self {
70 id: row.user_recovery_ticket_id.into(),
71 user_recovery_session_id: row.user_recovery_session_id.into(),
72 user_email_id: row.user_email_id.into(),
73 ticket: row.ticket,
74 created_at: row.created_at,
75 expires_at: row.expires_at,
76 }
77 }
78}
79
80#[async_trait]
81impl UserRecoveryRepository for PgUserRecoveryRepository<'_> {
82 type Error = DatabaseError;
83
84 #[tracing::instrument(
85 name = "db.user_recovery.lookup_session",
86 skip_all,
87 fields(
88 db.query.text,
89 user_recovery_session.id = %id,
90 ),
91 err,
92 )]
93 async fn lookup_session(
94 &mut self,
95 id: Ulid,
96 ) -> Result<Option<UserRecoverySession>, Self::Error> {
97 let row = sqlx::query_as!(
98 UserRecoverySessionRow,
99 r#"
100 SELECT
101 user_recovery_session_id
102 , email
103 , user_agent
104 , ip_address as "ip_address: IpAddr"
105 , locale
106 , created_at
107 , consumed_at
108 FROM user_recovery_sessions
109 WHERE user_recovery_session_id = $1
110 "#,
111 Uuid::from(id),
112 )
113 .traced()
114 .fetch_optional(&mut *self.conn)
115 .await?;
116
117 let Some(row) = row else {
118 return Ok(None);
119 };
120
121 Ok(Some(row.into()))
122 }
123
124 #[tracing::instrument(
125 name = "db.user_recovery.add_session",
126 skip_all,
127 fields(
128 db.query.text,
129 user_recovery_session.id,
130 user_recovery_session.email = email,
131 user_recovery_session.user_agent = user_agent,
132 user_recovery_session.ip_address = ip_address.map(|ip| ip.to_string()),
133 )
134 )]
135 async fn add_session(
136 &mut self,
137 rng: &mut (dyn RngCore + Send),
138 clock: &dyn Clock,
139 email: String,
140 user_agent: String,
141 ip_address: Option<IpAddr>,
142 locale: String,
143 ) -> Result<UserRecoverySession, Self::Error> {
144 let created_at = clock.now();
145 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
146 tracing::Span::current().record("user_recovery_session.id", tracing::field::display(id));
147 sqlx::query!(
148 r#"
149 INSERT INTO user_recovery_sessions (
150 user_recovery_session_id
151 , email
152 , user_agent
153 , ip_address
154 , locale
155 , created_at
156 )
157 VALUES ($1, $2, $3, $4, $5, $6)
158 "#,
159 Uuid::from(id),
160 &email,
161 &*user_agent,
162 ip_address as Option<IpAddr>,
163 &locale,
164 created_at,
165 )
166 .traced()
167 .execute(&mut *self.conn)
168 .await?;
169
170 let user_recovery_session = UserRecoverySession {
171 id,
172 email,
173 user_agent,
174 ip_address,
175 locale,
176 created_at,
177 consumed_at: None,
178 };
179
180 Ok(user_recovery_session)
181 }
182
183 #[tracing::instrument(
184 name = "db.user_recovery.find_ticket",
185 skip_all,
186 fields(
187 db.query.text,
188 user_recovery_ticket.id = ticket,
189 ),
190 err,
191 )]
192 async fn find_ticket(
193 &mut self,
194 ticket: &str,
195 ) -> Result<Option<UserRecoveryTicket>, Self::Error> {
196 let row = sqlx::query_as!(
197 UserRecoveryTicketRow,
198 r#"
199 SELECT
200 user_recovery_ticket_id
201 , user_recovery_session_id
202 , user_email_id
203 , ticket
204 , created_at
205 , expires_at
206 FROM user_recovery_tickets
207 WHERE ticket = $1
208 "#,
209 ticket,
210 )
211 .traced()
212 .fetch_optional(&mut *self.conn)
213 .await?;
214
215 let Some(row) = row else {
216 return Ok(None);
217 };
218
219 Ok(Some(row.into()))
220 }
221
222 #[tracing::instrument(
223 name = "db.user_recovery.add_ticket",
224 skip_all,
225 fields(
226 db.query.text,
227 user_recovery_ticket.id,
228 user_recovery_ticket.id = ticket,
229 %user_recovery_session.id,
230 %user_email.id,
231 )
232 )]
233 async fn add_ticket(
234 &mut self,
235 rng: &mut (dyn RngCore + Send),
236 clock: &dyn Clock,
237 user_recovery_session: &UserRecoverySession,
238 user_email: &UserEmail,
239 ticket: String,
240 ) -> Result<UserRecoveryTicket, Self::Error> {
241 let created_at = clock.now();
242 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
243 tracing::Span::current().record("user_recovery_ticket.id", tracing::field::display(id));
244
245 let expires_at = created_at + Duration::minutes(10);
247
248 sqlx::query!(
249 r#"
250 INSERT INTO user_recovery_tickets (
251 user_recovery_ticket_id
252 , user_recovery_session_id
253 , user_email_id
254 , ticket
255 , created_at
256 , expires_at
257 )
258 VALUES ($1, $2, $3, $4, $5, $6)
259 "#,
260 Uuid::from(id),
261 Uuid::from(user_recovery_session.id),
262 Uuid::from(user_email.id),
263 &ticket,
264 created_at,
265 expires_at,
266 )
267 .traced()
268 .execute(&mut *self.conn)
269 .await?;
270
271 let ticket = UserRecoveryTicket {
272 id,
273 user_recovery_session_id: user_recovery_session.id,
274 user_email_id: user_email.id,
275 ticket,
276 created_at,
277 expires_at,
278 };
279
280 Ok(ticket)
281 }
282
283 #[tracing::instrument(
284 name = "db.user_recovery.consume_ticket",
285 skip_all,
286 fields(
287 db.query.text,
288 %user_recovery_ticket.id,
289 user_email.id = %user_recovery_ticket.user_email_id,
290 %user_recovery_session.id,
291 %user_recovery_session.email,
292 ),
293 err,
294 )]
295 async fn consume_ticket(
296 &mut self,
297 clock: &dyn Clock,
298 user_recovery_ticket: UserRecoveryTicket,
299 mut user_recovery_session: UserRecoverySession,
300 ) -> Result<UserRecoverySession, Self::Error> {
301 let _ = user_recovery_ticket;
303
304 if user_recovery_session.consumed_at.is_some() {
306 return Err(DatabaseError::invalid_operation());
307 }
308
309 let consumed_at = clock.now();
310
311 let res = sqlx::query!(
312 r#"
313 UPDATE user_recovery_sessions
314 SET consumed_at = $1
315 WHERE user_recovery_session_id = $2
316 "#,
317 consumed_at,
318 Uuid::from(user_recovery_session.id),
319 )
320 .traced()
321 .execute(&mut *self.conn)
322 .await?;
323
324 user_recovery_session.consumed_at = Some(consumed_at);
325
326 DatabaseError::ensure_affected_rows(&res, 1)?;
327
328 Ok(user_recovery_session)
329 }
330
331 #[tracing::instrument(
332 name = "db.user_recovery.cleanup",
333 skip_all,
334 fields(
335 db.query.text,
336 since = since.map(tracing::field::display),
337 until = %until,
338 limit = limit,
339 ),
340 err,
341 )]
342 async fn cleanup(
343 &mut self,
344 since: Option<Ulid>,
345 until: Ulid,
346 limit: usize,
347 ) -> Result<(usize, Option<Ulid>), Self::Error> {
348 let res = sqlx::query_scalar!(
352 r#"
353 WITH to_delete AS (
354 SELECT user_recovery_session_id
355 FROM user_recovery_sessions
356 WHERE ($1::uuid IS NULL OR user_recovery_session_id > $1)
357 AND user_recovery_session_id <= $2
358 ORDER BY user_recovery_session_id
359 LIMIT $3
360 )
361 DELETE FROM user_recovery_sessions
362 USING to_delete
363 WHERE user_recovery_sessions.user_recovery_session_id = to_delete.user_recovery_session_id
364 RETURNING user_recovery_sessions.user_recovery_session_id
365 "#,
366 since.map(Uuid::from),
367 Uuid::from(until),
368 i64::try_from(limit).unwrap_or(i64::MAX)
369 )
370 .traced()
371 .fetch_all(&mut *self.conn)
372 .await?;
373
374 let count = res.len();
375 let max_id = res.into_iter().max();
376
377 Ok((count, max_id.map(Ulid::from)))
378 }
379}