1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11 UserRegistration,
12};
13use mas_storage::{
14 Clock, Page, Pagination,
15 user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError,
26 filter::{Filter, StatementExt},
27 iden::UserEmails,
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgUserEmailRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48 user_email_id: Uuid,
49 user_id: Uuid,
50 email: String,
51 created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55 fn from(e: UserEmailLookup) -> UserEmail {
56 UserEmail {
57 id: e.user_email_id.into(),
58 user_id: e.user_id.into(),
59 email: e.email,
60 created_at: e.created_at,
61 }
62 }
63}
64
65struct UserEmailAuthenticationLookup {
66 user_email_authentication_id: Uuid,
67 user_session_id: Option<Uuid>,
68 user_registration_id: Option<Uuid>,
69 email: String,
70 created_at: DateTime<Utc>,
71 completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75 fn from(value: UserEmailAuthenticationLookup) -> Self {
76 UserEmailAuthentication {
77 id: value.user_email_authentication_id.into(),
78 user_session_id: value.user_session_id.map(Ulid::from),
79 user_registration_id: value.user_registration_id.map(Ulid::from),
80 email: value.email,
81 created_at: value.created_at,
82 completed_at: value.completed_at,
83 }
84 }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88 user_email_authentication_code_id: Uuid,
89 user_email_authentication_id: Uuid,
90 code: String,
91 created_at: DateTime<Utc>,
92 expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96 fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97 UserEmailAuthenticationCode {
98 id: value.user_email_authentication_code_id.into(),
99 user_email_authentication_id: value.user_email_authentication_id.into(),
100 code: value.code,
101 created_at: value.created_at,
102 expires_at: value.expires_at,
103 }
104 }
105}
106
107impl Filter for UserEmailFilter<'_> {
108 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109 sea_query::Condition::all()
110 .add_option(self.user().map(|user| {
111 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112 }))
113 .add_option(
114 self.email()
115 .map(|email| Expr::col((UserEmails::Table, UserEmails::Email)).eq(email)),
116 )
117 }
118}
119
120#[async_trait]
121impl UserEmailRepository for PgUserEmailRepository<'_> {
122 type Error = DatabaseError;
123
124 #[tracing::instrument(
125 name = "db.user_email.lookup",
126 skip_all,
127 fields(
128 db.query.text,
129 user_email.id = %id,
130 ),
131 err,
132 )]
133 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
134 let res = sqlx::query_as!(
135 UserEmailLookup,
136 r#"
137 SELECT user_email_id
138 , user_id
139 , email
140 , created_at
141 FROM user_emails
142
143 WHERE user_email_id = $1
144 "#,
145 Uuid::from(id),
146 )
147 .traced()
148 .fetch_optional(&mut *self.conn)
149 .await?;
150
151 let Some(user_email) = res else {
152 return Ok(None);
153 };
154
155 Ok(Some(user_email.into()))
156 }
157
158 #[tracing::instrument(
159 name = "db.user_email.find",
160 skip_all,
161 fields(
162 db.query.text,
163 %user.id,
164 user_email.email = email,
165 ),
166 err,
167 )]
168 async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
169 let res = sqlx::query_as!(
170 UserEmailLookup,
171 r#"
172 SELECT user_email_id
173 , user_id
174 , email
175 , created_at
176 FROM user_emails
177
178 WHERE user_id = $1 AND email = $2
179 "#,
180 Uuid::from(user.id),
181 email,
182 )
183 .traced()
184 .fetch_optional(&mut *self.conn)
185 .await?;
186
187 let Some(user_email) = res else {
188 return Ok(None);
189 };
190
191 Ok(Some(user_email.into()))
192 }
193
194 #[tracing::instrument(
195 name = "db.user_email.all",
196 skip_all,
197 fields(
198 db.query.text,
199 %user.id,
200 ),
201 err,
202 )]
203 async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
204 let res = sqlx::query_as!(
205 UserEmailLookup,
206 r#"
207 SELECT user_email_id
208 , user_id
209 , email
210 , created_at
211 FROM user_emails
212
213 WHERE user_id = $1
214
215 ORDER BY email ASC
216 "#,
217 Uuid::from(user.id),
218 )
219 .traced()
220 .fetch_all(&mut *self.conn)
221 .await?;
222
223 Ok(res.into_iter().map(Into::into).collect())
224 }
225
226 #[tracing::instrument(
227 name = "db.user_email.list",
228 skip_all,
229 fields(
230 db.query.text,
231 ),
232 err,
233 )]
234 async fn list(
235 &mut self,
236 filter: UserEmailFilter<'_>,
237 pagination: Pagination,
238 ) -> Result<Page<UserEmail>, DatabaseError> {
239 let (sql, arguments) = Query::select()
240 .expr_as(
241 Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
242 UserEmailLookupIden::UserEmailId,
243 )
244 .expr_as(
245 Expr::col((UserEmails::Table, UserEmails::UserId)),
246 UserEmailLookupIden::UserId,
247 )
248 .expr_as(
249 Expr::col((UserEmails::Table, UserEmails::Email)),
250 UserEmailLookupIden::Email,
251 )
252 .expr_as(
253 Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
254 UserEmailLookupIden::CreatedAt,
255 )
256 .from(UserEmails::Table)
257 .apply_filter(filter)
258 .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
259 .build_sqlx(PostgresQueryBuilder);
260
261 let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
262 .traced()
263 .fetch_all(&mut *self.conn)
264 .await?;
265
266 let page = pagination.process(edges).map(UserEmail::from);
267
268 Ok(page)
269 }
270
271 #[tracing::instrument(
272 name = "db.user_email.count",
273 skip_all,
274 fields(
275 db.query.text,
276 ),
277 err,
278 )]
279 async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
280 let (sql, arguments) = Query::select()
281 .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
282 .from(UserEmails::Table)
283 .apply_filter(filter)
284 .build_sqlx(PostgresQueryBuilder);
285
286 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
287 .traced()
288 .fetch_one(&mut *self.conn)
289 .await?;
290
291 count
292 .try_into()
293 .map_err(DatabaseError::to_invalid_operation)
294 }
295
296 #[tracing::instrument(
297 name = "db.user_email.add",
298 skip_all,
299 fields(
300 db.query.text,
301 %user.id,
302 user_email.id,
303 user_email.email = email,
304 ),
305 err,
306 )]
307 async fn add(
308 &mut self,
309 rng: &mut (dyn RngCore + Send),
310 clock: &dyn Clock,
311 user: &User,
312 email: String,
313 ) -> Result<UserEmail, Self::Error> {
314 let created_at = clock.now();
315 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
316 tracing::Span::current().record("user_email.id", tracing::field::display(id));
317
318 sqlx::query!(
319 r#"
320 INSERT INTO user_emails (user_email_id, user_id, email, created_at)
321 VALUES ($1, $2, $3, $4)
322 "#,
323 Uuid::from(id),
324 Uuid::from(user.id),
325 &email,
326 created_at,
327 )
328 .traced()
329 .execute(&mut *self.conn)
330 .await?;
331
332 Ok(UserEmail {
333 id,
334 user_id: user.id,
335 email,
336 created_at,
337 })
338 }
339
340 #[tracing::instrument(
341 name = "db.user_email.remove",
342 skip_all,
343 fields(
344 db.query.text,
345 user.id = %user_email.user_id,
346 %user_email.id,
347 %user_email.email,
348 ),
349 err,
350 )]
351 async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
352 let res = sqlx::query!(
353 r#"
354 DELETE FROM user_emails
355 WHERE user_email_id = $1
356 "#,
357 Uuid::from(user_email.id),
358 )
359 .traced()
360 .execute(&mut *self.conn)
361 .await?;
362
363 DatabaseError::ensure_affected_rows(&res, 1)?;
364
365 Ok(())
366 }
367
368 #[tracing::instrument(
369 name = "db.user_email.remove_bulk",
370 skip_all,
371 fields(
372 db.query.text,
373 ),
374 err,
375 )]
376 async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
377 let (sql, arguments) = Query::delete()
378 .from_table(UserEmails::Table)
379 .apply_filter(filter)
380 .build_sqlx(PostgresQueryBuilder);
381
382 let res = sqlx::query_with(&sql, arguments)
383 .traced()
384 .execute(&mut *self.conn)
385 .await?;
386
387 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
388 }
389
390 #[tracing::instrument(
391 name = "db.user_email.add_authentication_for_session",
392 skip_all,
393 fields(
394 db.query.text,
395 %session.id,
396 user_email_authentication.id,
397 user_email_authentication.email = email,
398 ),
399 err,
400 )]
401 async fn add_authentication_for_session(
402 &mut self,
403 rng: &mut (dyn RngCore + Send),
404 clock: &dyn Clock,
405 email: String,
406 session: &BrowserSession,
407 ) -> Result<UserEmailAuthentication, Self::Error> {
408 let created_at = clock.now();
409 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
410 tracing::Span::current()
411 .record("user_email_authentication.id", tracing::field::display(id));
412
413 sqlx::query!(
414 r#"
415 INSERT INTO user_email_authentications
416 ( user_email_authentication_id
417 , user_session_id
418 , email
419 , created_at
420 )
421 VALUES ($1, $2, $3, $4)
422 "#,
423 Uuid::from(id),
424 Uuid::from(session.id),
425 &email,
426 created_at,
427 )
428 .traced()
429 .execute(&mut *self.conn)
430 .await?;
431
432 Ok(UserEmailAuthentication {
433 id,
434 user_session_id: Some(session.id),
435 user_registration_id: None,
436 email,
437 created_at,
438 completed_at: None,
439 })
440 }
441
442 #[tracing::instrument(
443 name = "db.user_email.add_authentication_for_registration",
444 skip_all,
445 fields(
446 db.query.text,
447 %user_registration.id,
448 user_email_authentication.id,
449 user_email_authentication.email = email,
450 ),
451 err,
452 )]
453 async fn add_authentication_for_registration(
454 &mut self,
455 rng: &mut (dyn RngCore + Send),
456 clock: &dyn Clock,
457 email: String,
458 user_registration: &UserRegistration,
459 ) -> Result<UserEmailAuthentication, Self::Error> {
460 let created_at = clock.now();
461 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
462 tracing::Span::current()
463 .record("user_email_authentication.id", tracing::field::display(id));
464
465 sqlx::query!(
466 r#"
467 INSERT INTO user_email_authentications
468 ( user_email_authentication_id
469 , user_registration_id
470 , email
471 , created_at
472 )
473 VALUES ($1, $2, $3, $4)
474 "#,
475 Uuid::from(id),
476 Uuid::from(user_registration.id),
477 &email,
478 created_at,
479 )
480 .traced()
481 .execute(&mut *self.conn)
482 .await?;
483
484 Ok(UserEmailAuthentication {
485 id,
486 user_session_id: None,
487 user_registration_id: Some(user_registration.id),
488 email,
489 created_at,
490 completed_at: None,
491 })
492 }
493
494 #[tracing::instrument(
495 name = "db.user_email.add_authentication_code",
496 skip_all,
497 fields(
498 db.query.text,
499 %user_email_authentication.id,
500 %user_email_authentication.email,
501 user_email_authentication_code.id,
502 user_email_authentication_code.code = code,
503 ),
504 err,
505 )]
506 async fn add_authentication_code(
507 &mut self,
508 rng: &mut (dyn RngCore + Send),
509 clock: &dyn Clock,
510 duration: chrono::Duration,
511 user_email_authentication: &UserEmailAuthentication,
512 code: String,
513 ) -> Result<UserEmailAuthenticationCode, Self::Error> {
514 let created_at = clock.now();
515 let expires_at = created_at + duration;
516 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
517 tracing::Span::current().record(
518 "user_email_authentication_code.id",
519 tracing::field::display(id),
520 );
521
522 sqlx::query!(
523 r#"
524 INSERT INTO user_email_authentication_codes
525 ( user_email_authentication_code_id
526 , user_email_authentication_id
527 , code
528 , created_at
529 , expires_at
530 )
531 VALUES ($1, $2, $3, $4, $5)
532 "#,
533 Uuid::from(id),
534 Uuid::from(user_email_authentication.id),
535 &code,
536 created_at,
537 expires_at,
538 )
539 .traced()
540 .execute(&mut *self.conn)
541 .await?;
542
543 Ok(UserEmailAuthenticationCode {
544 id,
545 user_email_authentication_id: user_email_authentication.id,
546 code,
547 created_at,
548 expires_at,
549 })
550 }
551
552 #[tracing::instrument(
553 name = "db.user_email.lookup_authentication",
554 skip_all,
555 fields(
556 db.query.text,
557 user_email_authentication.id = %id,
558 ),
559 err,
560 )]
561 async fn lookup_authentication(
562 &mut self,
563 id: Ulid,
564 ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
565 let res = sqlx::query_as!(
566 UserEmailAuthenticationLookup,
567 r#"
568 SELECT user_email_authentication_id
569 , user_session_id
570 , user_registration_id
571 , email
572 , created_at
573 , completed_at
574 FROM user_email_authentications
575 WHERE user_email_authentication_id = $1
576 "#,
577 Uuid::from(id),
578 )
579 .traced()
580 .fetch_optional(&mut *self.conn)
581 .await?;
582
583 Ok(res.map(UserEmailAuthentication::from))
584 }
585
586 #[tracing::instrument(
587 name = "db.user_email.find_authentication_by_code",
588 skip_all,
589 fields(
590 db.query.text,
591 %authentication.id,
592 user_email_authentication_code.code = code,
593 ),
594 err,
595 )]
596 async fn find_authentication_code(
597 &mut self,
598 authentication: &UserEmailAuthentication,
599 code: &str,
600 ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
601 let res = sqlx::query_as!(
602 UserEmailAuthenticationCodeLookup,
603 r#"
604 SELECT user_email_authentication_code_id
605 , user_email_authentication_id
606 , code
607 , created_at
608 , expires_at
609 FROM user_email_authentication_codes
610 WHERE user_email_authentication_id = $1
611 AND code = $2
612 "#,
613 Uuid::from(authentication.id),
614 code,
615 )
616 .traced()
617 .fetch_optional(&mut *self.conn)
618 .await?;
619
620 Ok(res.map(UserEmailAuthenticationCode::from))
621 }
622
623 #[tracing::instrument(
624 name = "db.user_email.complete_email_authentication",
625 skip_all,
626 fields(
627 db.query.text,
628 %user_email_authentication.id,
629 %user_email_authentication.email,
630 %user_email_authentication_code.id,
631 %user_email_authentication_code.code,
632 ),
633 err,
634 )]
635 async fn complete_authentication(
636 &mut self,
637 clock: &dyn Clock,
638 mut user_email_authentication: UserEmailAuthentication,
639 user_email_authentication_code: &UserEmailAuthenticationCode,
640 ) -> Result<UserEmailAuthentication, Self::Error> {
641 let completed_at = clock.now();
645
646 let res = sqlx::query!(
650 r#"
651 UPDATE user_email_authentications
652 SET completed_at = $2
653 WHERE user_email_authentication_id = $1
654 AND completed_at IS NULL
655 "#,
656 Uuid::from(user_email_authentication.id),
657 completed_at,
658 )
659 .traced()
660 .execute(&mut *self.conn)
661 .await?;
662
663 DatabaseError::ensure_affected_rows(&res, 1)?;
664
665 user_email_authentication.completed_at = Some(completed_at);
666 Ok(user_email_authentication)
667 }
668}