1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
12 UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
13};
14use mas_storage::{
15 Page, Pagination,
16 pagination::Node,
17 user::{UserEmailFilter, UserEmailRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError,
28 filter::{Filter, StatementExt},
29 iden::UserEmails,
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
/// An implementation of [`UserEmailRepository`] backed by a PostgreSQL
/// connection.
pub struct PgUserEmailRepository<'c> {
    /// Mutably borrowed connection on which every query of this repository
    /// is executed.
    conn: &'c mut PgConnection,
}
38
39impl<'c> PgUserEmailRepository<'c> {
40 pub fn new(conn: &'c mut PgConnection) -> Self {
43 Self { conn }
44 }
45}
46
/// A row of the `user_emails` table as selected by the queries in this file.
///
/// `#[enum_def]` also generates `UserEmailLookupIden`, whose variants are used
/// as column aliases by the dynamically built query in `list`; the field names
/// therefore have to keep matching the selected column names.
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserEmailLookup {
    /// Primary key, converted back into a [`Ulid`] for the domain model.
    user_email_id: Uuid,
    /// ID of the user owning this address, also converted into a [`Ulid`].
    user_id: Uuid,
    /// The email address as stored.
    email: String,
    /// When the address was added.
    created_at: DateTime<Utc>,
}
55
56impl Node<Ulid> for UserEmailLookup {
57 fn cursor(&self) -> Ulid {
58 self.user_email_id.into()
59 }
60}
61
62impl From<UserEmailLookup> for UserEmail {
63 fn from(e: UserEmailLookup) -> UserEmail {
64 UserEmail {
65 id: e.user_email_id.into(),
66 user_id: e.user_id.into(),
67 email: e.email,
68 created_at: e.created_at,
69 }
70 }
71}
72
/// A row of the `user_email_authentications` table, as selected by
/// `lookup_authentication`. Field names must match the selected columns.
struct UserEmailAuthenticationLookup {
    /// Primary key, converted into a [`Ulid`] for the domain model.
    user_email_authentication_id: Uuid,
    /// Set when the authentication was started for a browser session
    /// (see `add_authentication_for_session`).
    user_session_id: Option<Uuid>,
    /// Set when the authentication was started for a user registration
    /// (see `add_authentication_for_registration`).
    user_registration_id: Option<Uuid>,
    /// The email address being authenticated.
    email: String,
    /// When the authentication was started.
    created_at: DateTime<Utc>,
    /// `None` until one of the `complete_authentication_*` methods sets it.
    completed_at: Option<DateTime<Utc>>,
}
81
82impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
83 fn from(value: UserEmailAuthenticationLookup) -> Self {
84 UserEmailAuthentication {
85 id: value.user_email_authentication_id.into(),
86 user_session_id: value.user_session_id.map(Ulid::from),
87 user_registration_id: value.user_registration_id.map(Ulid::from),
88 email: value.email,
89 created_at: value.created_at,
90 completed_at: value.completed_at,
91 }
92 }
93}
94
/// A row of the `user_email_authentication_codes` table, as selected by
/// `find_authentication_code`. Field names must match the selected columns.
struct UserEmailAuthenticationCodeLookup {
    /// Primary key, converted into a [`Ulid`] for the domain model.
    user_email_authentication_code_id: Uuid,
    /// The authentication this code belongs to.
    user_email_authentication_id: Uuid,
    /// The code value, compared verbatim by `find_authentication_code`.
    code: String,
    /// When the code was generated.
    created_at: DateTime<Utc>,
    /// `created_at` plus the duration passed to `add_authentication_code`.
    expires_at: DateTime<Utc>,
}
102
103impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
104 fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
105 UserEmailAuthenticationCode {
106 id: value.user_email_authentication_code_id.into(),
107 user_email_authentication_id: value.user_email_authentication_id.into(),
108 code: value.code,
109 created_at: value.created_at,
110 expires_at: value.expires_at,
111 }
112 }
113}
114
115impl Filter for UserEmailFilter<'_> {
116 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
117 sea_query::Condition::all()
118 .add_option(self.user().map(|user| {
119 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
120 }))
121 .add_option(self.email().map(|email| {
122 SimpleExpr::from(Func::lower(Expr::col((
123 UserEmails::Table,
124 UserEmails::Email,
125 ))))
126 .eq(Func::lower(email))
127 }))
128 }
129}
130
131#[async_trait]
132impl UserEmailRepository for PgUserEmailRepository<'_> {
133 type Error = DatabaseError;
134
135 #[tracing::instrument(
136 name = "db.user_email.lookup",
137 skip_all,
138 fields(
139 db.query.text,
140 user_email.id = %id,
141 ),
142 err,
143 )]
144 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
145 let res = sqlx::query_as!(
146 UserEmailLookup,
147 r#"
148 SELECT user_email_id
149 , user_id
150 , email
151 , created_at
152 FROM user_emails
153
154 WHERE user_email_id = $1
155 "#,
156 Uuid::from(id),
157 )
158 .traced()
159 .fetch_optional(&mut *self.conn)
160 .await?;
161
162 let Some(user_email) = res else {
163 return Ok(None);
164 };
165
166 Ok(Some(user_email.into()))
167 }
168
169 #[tracing::instrument(
170 name = "db.user_email.find",
171 skip_all,
172 fields(
173 db.query.text,
174 %user.id,
175 user_email.email = email,
176 ),
177 err,
178 )]
179 async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
180 let res = sqlx::query_as!(
181 UserEmailLookup,
182 r#"
183 SELECT user_email_id
184 , user_id
185 , email
186 , created_at
187 FROM user_emails
188
189 WHERE user_id = $1 AND LOWER(email) = LOWER($2)
190 "#,
191 Uuid::from(user.id),
192 email,
193 )
194 .traced()
195 .fetch_optional(&mut *self.conn)
196 .await?;
197
198 let Some(user_email) = res else {
199 return Ok(None);
200 };
201
202 Ok(Some(user_email.into()))
203 }
204
205 #[tracing::instrument(
206 name = "db.user_email.find_by_email",
207 skip_all,
208 fields(
209 db.query.text,
210 user_email.email = email,
211 ),
212 err,
213 )]
214 async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
215 let res = sqlx::query_as!(
216 UserEmailLookup,
217 r#"
218 SELECT user_email_id
219 , user_id
220 , email
221 , created_at
222 FROM user_emails
223 WHERE LOWER(email) = LOWER($1)
224 "#,
225 email,
226 )
227 .traced()
228 .fetch_all(&mut *self.conn)
229 .await?;
230
231 if res.len() != 1 {
232 return Ok(None);
233 }
234
235 let Some(user_email) = res.into_iter().next() else {
236 return Ok(None);
237 };
238
239 Ok(Some(user_email.into()))
240 }
241
242 #[tracing::instrument(
243 name = "db.user_email.all",
244 skip_all,
245 fields(
246 db.query.text,
247 %user.id,
248 ),
249 err,
250 )]
251 async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
252 let res = sqlx::query_as!(
253 UserEmailLookup,
254 r#"
255 SELECT user_email_id
256 , user_id
257 , email
258 , created_at
259 FROM user_emails
260
261 WHERE user_id = $1
262
263 ORDER BY email ASC
264 "#,
265 Uuid::from(user.id),
266 )
267 .traced()
268 .fetch_all(&mut *self.conn)
269 .await?;
270
271 Ok(res.into_iter().map(Into::into).collect())
272 }
273
274 #[tracing::instrument(
275 name = "db.user_email.list",
276 skip_all,
277 fields(
278 db.query.text,
279 ),
280 err,
281 )]
282 async fn list(
283 &mut self,
284 filter: UserEmailFilter<'_>,
285 pagination: Pagination,
286 ) -> Result<Page<UserEmail>, DatabaseError> {
287 let (sql, arguments) = Query::select()
288 .expr_as(
289 Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
290 UserEmailLookupIden::UserEmailId,
291 )
292 .expr_as(
293 Expr::col((UserEmails::Table, UserEmails::UserId)),
294 UserEmailLookupIden::UserId,
295 )
296 .expr_as(
297 Expr::col((UserEmails::Table, UserEmails::Email)),
298 UserEmailLookupIden::Email,
299 )
300 .expr_as(
301 Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
302 UserEmailLookupIden::CreatedAt,
303 )
304 .from(UserEmails::Table)
305 .apply_filter(filter)
306 .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
307 .build_sqlx(PostgresQueryBuilder);
308
309 let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
310 .traced()
311 .fetch_all(&mut *self.conn)
312 .await?;
313
314 let page = pagination.process(edges).map(UserEmail::from);
315
316 Ok(page)
317 }
318
319 #[tracing::instrument(
320 name = "db.user_email.count",
321 skip_all,
322 fields(
323 db.query.text,
324 ),
325 err,
326 )]
327 async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
328 let (sql, arguments) = Query::select()
329 .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
330 .from(UserEmails::Table)
331 .apply_filter(filter)
332 .build_sqlx(PostgresQueryBuilder);
333
334 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
335 .traced()
336 .fetch_one(&mut *self.conn)
337 .await?;
338
339 count
340 .try_into()
341 .map_err(DatabaseError::to_invalid_operation)
342 }
343
344 #[tracing::instrument(
345 name = "db.user_email.add",
346 skip_all,
347 fields(
348 db.query.text,
349 %user.id,
350 user_email.id,
351 user_email.email = email,
352 ),
353 err,
354 )]
355 async fn add(
356 &mut self,
357 rng: &mut (dyn RngCore + Send),
358 clock: &dyn Clock,
359 user: &User,
360 email: String,
361 ) -> Result<UserEmail, Self::Error> {
362 let created_at = clock.now();
363 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
364 tracing::Span::current().record("user_email.id", tracing::field::display(id));
365
366 sqlx::query!(
367 r#"
368 INSERT INTO user_emails (user_email_id, user_id, email, created_at)
369 VALUES ($1, $2, $3, $4)
370 "#,
371 Uuid::from(id),
372 Uuid::from(user.id),
373 &email,
374 created_at,
375 )
376 .traced()
377 .execute(&mut *self.conn)
378 .await?;
379
380 Ok(UserEmail {
381 id,
382 user_id: user.id,
383 email,
384 created_at,
385 })
386 }
387
388 #[tracing::instrument(
389 name = "db.user_email.remove",
390 skip_all,
391 fields(
392 db.query.text,
393 user.id = %user_email.user_id,
394 %user_email.id,
395 %user_email.email,
396 ),
397 err,
398 )]
399 async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
400 let res = sqlx::query!(
401 r#"
402 DELETE FROM user_emails
403 WHERE user_email_id = $1
404 "#,
405 Uuid::from(user_email.id),
406 )
407 .traced()
408 .execute(&mut *self.conn)
409 .await?;
410
411 DatabaseError::ensure_affected_rows(&res, 1)?;
412
413 Ok(())
414 }
415
416 #[tracing::instrument(
417 name = "db.user_email.remove_bulk",
418 skip_all,
419 fields(
420 db.query.text,
421 ),
422 err,
423 )]
424 async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
425 let (sql, arguments) = Query::delete()
426 .from_table(UserEmails::Table)
427 .apply_filter(filter)
428 .build_sqlx(PostgresQueryBuilder);
429
430 let res = sqlx::query_with(&sql, arguments)
431 .traced()
432 .execute(&mut *self.conn)
433 .await?;
434
435 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
436 }
437
438 #[tracing::instrument(
439 name = "db.user_email.add_authentication_for_session",
440 skip_all,
441 fields(
442 db.query.text,
443 %session.id,
444 user_email_authentication.id,
445 user_email_authentication.email = email,
446 ),
447 err,
448 )]
449 async fn add_authentication_for_session(
450 &mut self,
451 rng: &mut (dyn RngCore + Send),
452 clock: &dyn Clock,
453 email: String,
454 session: &BrowserSession,
455 ) -> Result<UserEmailAuthentication, Self::Error> {
456 let created_at = clock.now();
457 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
458 tracing::Span::current()
459 .record("user_email_authentication.id", tracing::field::display(id));
460
461 sqlx::query!(
462 r#"
463 INSERT INTO user_email_authentications
464 ( user_email_authentication_id
465 , user_session_id
466 , email
467 , created_at
468 )
469 VALUES ($1, $2, $3, $4)
470 "#,
471 Uuid::from(id),
472 Uuid::from(session.id),
473 &email,
474 created_at,
475 )
476 .traced()
477 .execute(&mut *self.conn)
478 .await?;
479
480 Ok(UserEmailAuthentication {
481 id,
482 user_session_id: Some(session.id),
483 user_registration_id: None,
484 email,
485 created_at,
486 completed_at: None,
487 })
488 }
489
490 #[tracing::instrument(
491 name = "db.user_email.add_authentication_for_registration",
492 skip_all,
493 fields(
494 db.query.text,
495 %user_registration.id,
496 user_email_authentication.id,
497 user_email_authentication.email = email,
498 ),
499 err,
500 )]
501 async fn add_authentication_for_registration(
502 &mut self,
503 rng: &mut (dyn RngCore + Send),
504 clock: &dyn Clock,
505 email: String,
506 user_registration: &UserRegistration,
507 ) -> Result<UserEmailAuthentication, Self::Error> {
508 let created_at = clock.now();
509 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
510 tracing::Span::current()
511 .record("user_email_authentication.id", tracing::field::display(id));
512
513 sqlx::query!(
514 r#"
515 INSERT INTO user_email_authentications
516 ( user_email_authentication_id
517 , user_registration_id
518 , email
519 , created_at
520 )
521 VALUES ($1, $2, $3, $4)
522 "#,
523 Uuid::from(id),
524 Uuid::from(user_registration.id),
525 &email,
526 created_at,
527 )
528 .traced()
529 .execute(&mut *self.conn)
530 .await?;
531
532 Ok(UserEmailAuthentication {
533 id,
534 user_session_id: None,
535 user_registration_id: Some(user_registration.id),
536 email,
537 created_at,
538 completed_at: None,
539 })
540 }
541
542 #[tracing::instrument(
543 name = "db.user_email.add_authentication_code",
544 skip_all,
545 fields(
546 db.query.text,
547 %user_email_authentication.id,
548 %user_email_authentication.email,
549 user_email_authentication_code.id,
550 user_email_authentication_code.code = code,
551 ),
552 err,
553 )]
554 async fn add_authentication_code(
555 &mut self,
556 rng: &mut (dyn RngCore + Send),
557 clock: &dyn Clock,
558 duration: chrono::Duration,
559 user_email_authentication: &UserEmailAuthentication,
560 code: String,
561 ) -> Result<UserEmailAuthenticationCode, Self::Error> {
562 let created_at = clock.now();
563 let expires_at = created_at + duration;
564 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
565 tracing::Span::current().record(
566 "user_email_authentication_code.id",
567 tracing::field::display(id),
568 );
569
570 sqlx::query!(
571 r#"
572 INSERT INTO user_email_authentication_codes
573 ( user_email_authentication_code_id
574 , user_email_authentication_id
575 , code
576 , created_at
577 , expires_at
578 )
579 VALUES ($1, $2, $3, $4, $5)
580 "#,
581 Uuid::from(id),
582 Uuid::from(user_email_authentication.id),
583 &code,
584 created_at,
585 expires_at,
586 )
587 .traced()
588 .execute(&mut *self.conn)
589 .await?;
590
591 Ok(UserEmailAuthenticationCode {
592 id,
593 user_email_authentication_id: user_email_authentication.id,
594 code,
595 created_at,
596 expires_at,
597 })
598 }
599
600 #[tracing::instrument(
601 name = "db.user_email.lookup_authentication",
602 skip_all,
603 fields(
604 db.query.text,
605 user_email_authentication.id = %id,
606 ),
607 err,
608 )]
609 async fn lookup_authentication(
610 &mut self,
611 id: Ulid,
612 ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
613 let res = sqlx::query_as!(
614 UserEmailAuthenticationLookup,
615 r#"
616 SELECT user_email_authentication_id
617 , user_session_id
618 , user_registration_id
619 , email
620 , created_at
621 , completed_at
622 FROM user_email_authentications
623 WHERE user_email_authentication_id = $1
624 "#,
625 Uuid::from(id),
626 )
627 .traced()
628 .fetch_optional(&mut *self.conn)
629 .await?;
630
631 Ok(res.map(UserEmailAuthentication::from))
632 }
633
634 #[tracing::instrument(
635 name = "db.user_email.find_authentication_by_code",
636 skip_all,
637 fields(
638 db.query.text,
639 %authentication.id,
640 user_email_authentication_code.code = code,
641 ),
642 err,
643 )]
644 async fn find_authentication_code(
645 &mut self,
646 authentication: &UserEmailAuthentication,
647 code: &str,
648 ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
649 let res = sqlx::query_as!(
650 UserEmailAuthenticationCodeLookup,
651 r#"
652 SELECT user_email_authentication_code_id
653 , user_email_authentication_id
654 , code
655 , created_at
656 , expires_at
657 FROM user_email_authentication_codes
658 WHERE user_email_authentication_id = $1
659 AND code = $2
660 "#,
661 Uuid::from(authentication.id),
662 code,
663 )
664 .traced()
665 .fetch_optional(&mut *self.conn)
666 .await?;
667
668 Ok(res.map(UserEmailAuthenticationCode::from))
669 }
670
671 #[tracing::instrument(
672 name = "db.user_email.complete_email_authentication_with_code",
673 skip_all,
674 fields(
675 db.query.text,
676 %user_email_authentication.id,
677 %user_email_authentication.email,
678 %user_email_authentication_code.id,
679 %user_email_authentication_code.code,
680 ),
681 err,
682 )]
683 async fn complete_authentication_with_code(
684 &mut self,
685 clock: &dyn Clock,
686 mut user_email_authentication: UserEmailAuthentication,
687 user_email_authentication_code: &UserEmailAuthenticationCode,
688 ) -> Result<UserEmailAuthentication, Self::Error> {
689 let completed_at = clock.now();
693
694 let res = sqlx::query!(
698 r#"
699 UPDATE user_email_authentications
700 SET completed_at = $2
701 WHERE user_email_authentication_id = $1
702 AND completed_at IS NULL
703 "#,
704 Uuid::from(user_email_authentication.id),
705 completed_at,
706 )
707 .traced()
708 .execute(&mut *self.conn)
709 .await?;
710
711 DatabaseError::ensure_affected_rows(&res, 1)?;
712
713 user_email_authentication.completed_at = Some(completed_at);
714 Ok(user_email_authentication)
715 }
716
717 #[tracing::instrument(
718 name = "db.user_email.complete_email_authentication_with_upstream",
719 skip_all,
720 fields(
721 db.query.text,
722 %user_email_authentication.id,
723 %user_email_authentication.email,
724 %upstream_oauth_authorization_session.id,
725 ),
726 err,
727 )]
728 async fn complete_authentication_with_upstream(
729 &mut self,
730 clock: &dyn Clock,
731 mut user_email_authentication: UserEmailAuthentication,
732 upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
733 ) -> Result<UserEmailAuthentication, Self::Error> {
734 let completed_at = clock.now();
738
739 let res = sqlx::query!(
743 r#"
744 UPDATE user_email_authentications
745 SET completed_at = $2
746 WHERE user_email_authentication_id = $1
747 AND completed_at IS NULL
748 "#,
749 Uuid::from(user_email_authentication.id),
750 completed_at,
751 )
752 .traced()
753 .execute(&mut *self.conn)
754 .await?;
755
756 DatabaseError::ensure_affected_rows(&res, 1)?;
757
758 user_email_authentication.completed_at = Some(completed_at);
759 Ok(user_email_authentication)
760 }
761
762 #[tracing::instrument(
763 name = "db.user_email.cleanup_authentications",
764 skip_all,
765 fields(
766 db.query.text,
767 since = since.map(tracing::field::display),
768 until = %until,
769 limit = limit,
770 ),
771 err,
772 )]
773 async fn cleanup_authentications(
774 &mut self,
775 since: Option<Ulid>,
776 until: Ulid,
777 limit: usize,
778 ) -> Result<(usize, Option<Ulid>), Self::Error> {
779 let res = sqlx::query_scalar!(
783 r#"
784 WITH
785 to_delete AS (
786 SELECT user_email_authentication_id
787 FROM user_email_authentications
788 WHERE ($1::uuid IS NULL OR user_email_authentication_id > $1)
789 AND user_email_authentication_id <= $2
790 ORDER BY user_email_authentication_id
791 LIMIT $3
792 ),
793 deleted_codes AS (
794 DELETE FROM user_email_authentication_codes
795 USING to_delete
796 WHERE user_email_authentication_codes.user_email_authentication_id = to_delete.user_email_authentication_id
797 RETURNING user_email_authentication_codes.user_email_authentication_code_id
798 )
799 DELETE FROM user_email_authentications
800 USING to_delete
801 WHERE user_email_authentications.user_email_authentication_id = to_delete.user_email_authentication_id
802 RETURNING user_email_authentications.user_email_authentication_id
803 "#,
804 since.map(Uuid::from),
805 Uuid::from(until),
806 i64::try_from(limit).unwrap_or(i64::MAX)
807 )
808 .traced()
809 .fetch_all(&mut *self.conn)
810 .await?;
811
812 let count = res.len();
813 let max_id = res.into_iter().max();
814
815 Ok((count, max_id.map(Ulid::from)))
816 }
817}