mas_storage_pg/user/
email.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail,
12    UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration,
13};
14use mas_storage::{
15    Page, Pagination,
16    pagination::Node,
17    user::{UserEmailFilter, UserEmailRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError,
28    filter::{Filter, StatementExt},
29    iden::UserEmails,
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
34/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
35pub struct PgUserEmailRepository<'c> {
36    conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUserEmailRepository<'c> {
40    /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL
41    /// connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// A row of the `user_emails` table as selected by the queries in this file.
///
/// `#[enum_def]` generates a `UserEmailLookupIden` enum from these field
/// names; the dynamically-built `list` query aliases its selected columns with
/// it so that `sqlx::FromRow` can map each row back onto this struct.
/// Renaming a field therefore also renames the corresponding alias.
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserEmailLookup {
    // Identifier of the email row; converted to/from a `Ulid` by this file.
    user_email_id: Uuid,
    // Identifier of the owning user.
    user_id: Uuid,
    // The address as stored (lookups in this file compare it case-insensitively).
    email: String,
    created_at: DateTime<Utc>,
}
55
56impl Node<Ulid> for UserEmailLookup {
57    fn cursor(&self) -> Ulid {
58        self.user_email_id.into()
59    }
60}
61
62impl From<UserEmailLookup> for UserEmail {
63    fn from(e: UserEmailLookup) -> UserEmail {
64        UserEmail {
65            id: e.user_email_id.into(),
66            user_id: e.user_id.into(),
67            email: e.email,
68            created_at: e.created_at,
69        }
70    }
71}
72
73struct UserEmailAuthenticationLookup {
74    user_email_authentication_id: Uuid,
75    user_session_id: Option<Uuid>,
76    user_registration_id: Option<Uuid>,
77    email: String,
78    created_at: DateTime<Utc>,
79    completed_at: Option<DateTime<Utc>>,
80}
81
82impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
83    fn from(value: UserEmailAuthenticationLookup) -> Self {
84        UserEmailAuthentication {
85            id: value.user_email_authentication_id.into(),
86            user_session_id: value.user_session_id.map(Ulid::from),
87            user_registration_id: value.user_registration_id.map(Ulid::from),
88            email: value.email,
89            created_at: value.created_at,
90            completed_at: value.completed_at,
91        }
92    }
93}
94
95struct UserEmailAuthenticationCodeLookup {
96    user_email_authentication_code_id: Uuid,
97    user_email_authentication_id: Uuid,
98    code: String,
99    created_at: DateTime<Utc>,
100    expires_at: DateTime<Utc>,
101}
102
103impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
104    fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
105        UserEmailAuthenticationCode {
106            id: value.user_email_authentication_code_id.into(),
107            user_email_authentication_id: value.user_email_authentication_id.into(),
108            code: value.code,
109            created_at: value.created_at,
110            expires_at: value.expires_at,
111        }
112    }
113}
114
115impl Filter for UserEmailFilter<'_> {
116    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
117        sea_query::Condition::all()
118            .add_option(self.user().map(|user| {
119                Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
120            }))
121            .add_option(self.email().map(|email| {
122                SimpleExpr::from(Func::lower(Expr::col((
123                    UserEmails::Table,
124                    UserEmails::Email,
125                ))))
126                .eq(Func::lower(email))
127            }))
128    }
129}
130
131#[async_trait]
132impl UserEmailRepository for PgUserEmailRepository<'_> {
133    type Error = DatabaseError;
134
135    #[tracing::instrument(
136        name = "db.user_email.lookup",
137        skip_all,
138        fields(
139            db.query.text,
140            user_email.id = %id,
141        ),
142        err,
143    )]
144    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
145        let res = sqlx::query_as!(
146            UserEmailLookup,
147            r#"
148                SELECT user_email_id
149                     , user_id
150                     , email
151                     , created_at
152                FROM user_emails
153
154                WHERE user_email_id = $1
155            "#,
156            Uuid::from(id),
157        )
158        .traced()
159        .fetch_optional(&mut *self.conn)
160        .await?;
161
162        let Some(user_email) = res else {
163            return Ok(None);
164        };
165
166        Ok(Some(user_email.into()))
167    }
168
169    #[tracing::instrument(
170        name = "db.user_email.find",
171        skip_all,
172        fields(
173            db.query.text,
174            %user.id,
175            user_email.email = email,
176        ),
177        err,
178    )]
179    async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
180        let res = sqlx::query_as!(
181            UserEmailLookup,
182            r#"
183                SELECT user_email_id
184                     , user_id
185                     , email
186                     , created_at
187                FROM user_emails
188
189                WHERE user_id = $1 AND LOWER(email) = LOWER($2)
190            "#,
191            Uuid::from(user.id),
192            email,
193        )
194        .traced()
195        .fetch_optional(&mut *self.conn)
196        .await?;
197
198        let Some(user_email) = res else {
199            return Ok(None);
200        };
201
202        Ok(Some(user_email.into()))
203    }
204
205    #[tracing::instrument(
206        name = "db.user_email.find_by_email",
207        skip_all,
208        fields(
209            db.query.text,
210            user_email.email = email,
211        ),
212        err,
213    )]
214    async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
215        let res = sqlx::query_as!(
216            UserEmailLookup,
217            r#"
218                SELECT user_email_id
219                     , user_id
220                     , email
221                     , created_at
222                FROM user_emails
223                WHERE LOWER(email) = LOWER($1)
224            "#,
225            email,
226        )
227        .traced()
228        .fetch_all(&mut *self.conn)
229        .await?;
230
231        if res.len() != 1 {
232            return Ok(None);
233        }
234
235        let Some(user_email) = res.into_iter().next() else {
236            return Ok(None);
237        };
238
239        Ok(Some(user_email.into()))
240    }
241
242    #[tracing::instrument(
243        name = "db.user_email.all",
244        skip_all,
245        fields(
246            db.query.text,
247            %user.id,
248        ),
249        err,
250    )]
251    async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
252        let res = sqlx::query_as!(
253            UserEmailLookup,
254            r#"
255                SELECT user_email_id
256                     , user_id
257                     , email
258                     , created_at
259                FROM user_emails
260
261                WHERE user_id = $1
262
263                ORDER BY email ASC
264            "#,
265            Uuid::from(user.id),
266        )
267        .traced()
268        .fetch_all(&mut *self.conn)
269        .await?;
270
271        Ok(res.into_iter().map(Into::into).collect())
272    }
273
274    #[tracing::instrument(
275        name = "db.user_email.list",
276        skip_all,
277        fields(
278            db.query.text,
279        ),
280        err,
281    )]
282    async fn list(
283        &mut self,
284        filter: UserEmailFilter<'_>,
285        pagination: Pagination,
286    ) -> Result<Page<UserEmail>, DatabaseError> {
287        let (sql, arguments) = Query::select()
288            .expr_as(
289                Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
290                UserEmailLookupIden::UserEmailId,
291            )
292            .expr_as(
293                Expr::col((UserEmails::Table, UserEmails::UserId)),
294                UserEmailLookupIden::UserId,
295            )
296            .expr_as(
297                Expr::col((UserEmails::Table, UserEmails::Email)),
298                UserEmailLookupIden::Email,
299            )
300            .expr_as(
301                Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
302                UserEmailLookupIden::CreatedAt,
303            )
304            .from(UserEmails::Table)
305            .apply_filter(filter)
306            .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
307            .build_sqlx(PostgresQueryBuilder);
308
309        let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
310            .traced()
311            .fetch_all(&mut *self.conn)
312            .await?;
313
314        let page = pagination.process(edges).map(UserEmail::from);
315
316        Ok(page)
317    }
318
319    #[tracing::instrument(
320        name = "db.user_email.count",
321        skip_all,
322        fields(
323            db.query.text,
324        ),
325        err,
326    )]
327    async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
328        let (sql, arguments) = Query::select()
329            .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
330            .from(UserEmails::Table)
331            .apply_filter(filter)
332            .build_sqlx(PostgresQueryBuilder);
333
334        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
335            .traced()
336            .fetch_one(&mut *self.conn)
337            .await?;
338
339        count
340            .try_into()
341            .map_err(DatabaseError::to_invalid_operation)
342    }
343
344    #[tracing::instrument(
345        name = "db.user_email.add",
346        skip_all,
347        fields(
348            db.query.text,
349            %user.id,
350            user_email.id,
351            user_email.email = email,
352        ),
353        err,
354    )]
355    async fn add(
356        &mut self,
357        rng: &mut (dyn RngCore + Send),
358        clock: &dyn Clock,
359        user: &User,
360        email: String,
361    ) -> Result<UserEmail, Self::Error> {
362        let created_at = clock.now();
363        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
364        tracing::Span::current().record("user_email.id", tracing::field::display(id));
365
366        sqlx::query!(
367            r#"
368                INSERT INTO user_emails (user_email_id, user_id, email, created_at)
369                VALUES ($1, $2, $3, $4)
370            "#,
371            Uuid::from(id),
372            Uuid::from(user.id),
373            &email,
374            created_at,
375        )
376        .traced()
377        .execute(&mut *self.conn)
378        .await?;
379
380        Ok(UserEmail {
381            id,
382            user_id: user.id,
383            email,
384            created_at,
385        })
386    }
387
388    #[tracing::instrument(
389        name = "db.user_email.remove",
390        skip_all,
391        fields(
392            db.query.text,
393            user.id = %user_email.user_id,
394            %user_email.id,
395            %user_email.email,
396        ),
397        err,
398    )]
399    async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
400        let res = sqlx::query!(
401            r#"
402                DELETE FROM user_emails
403                WHERE user_email_id = $1
404            "#,
405            Uuid::from(user_email.id),
406        )
407        .traced()
408        .execute(&mut *self.conn)
409        .await?;
410
411        DatabaseError::ensure_affected_rows(&res, 1)?;
412
413        Ok(())
414    }
415
416    #[tracing::instrument(
417        name = "db.user_email.remove_bulk",
418        skip_all,
419        fields(
420            db.query.text,
421        ),
422        err,
423    )]
424    async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
425        let (sql, arguments) = Query::delete()
426            .from_table(UserEmails::Table)
427            .apply_filter(filter)
428            .build_sqlx(PostgresQueryBuilder);
429
430        let res = sqlx::query_with(&sql, arguments)
431            .traced()
432            .execute(&mut *self.conn)
433            .await?;
434
435        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
436    }
437
438    #[tracing::instrument(
439        name = "db.user_email.add_authentication_for_session",
440        skip_all,
441        fields(
442            db.query.text,
443            %session.id,
444            user_email_authentication.id,
445            user_email_authentication.email = email,
446        ),
447        err,
448    )]
449    async fn add_authentication_for_session(
450        &mut self,
451        rng: &mut (dyn RngCore + Send),
452        clock: &dyn Clock,
453        email: String,
454        session: &BrowserSession,
455    ) -> Result<UserEmailAuthentication, Self::Error> {
456        let created_at = clock.now();
457        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
458        tracing::Span::current()
459            .record("user_email_authentication.id", tracing::field::display(id));
460
461        sqlx::query!(
462            r#"
463                INSERT INTO user_email_authentications
464                  ( user_email_authentication_id
465                  , user_session_id
466                  , email
467                  , created_at
468                  )
469                VALUES ($1, $2, $3, $4)
470            "#,
471            Uuid::from(id),
472            Uuid::from(session.id),
473            &email,
474            created_at,
475        )
476        .traced()
477        .execute(&mut *self.conn)
478        .await?;
479
480        Ok(UserEmailAuthentication {
481            id,
482            user_session_id: Some(session.id),
483            user_registration_id: None,
484            email,
485            created_at,
486            completed_at: None,
487        })
488    }
489
490    #[tracing::instrument(
491        name = "db.user_email.add_authentication_for_registration",
492        skip_all,
493        fields(
494            db.query.text,
495            %user_registration.id,
496            user_email_authentication.id,
497            user_email_authentication.email = email,
498        ),
499        err,
500    )]
501    async fn add_authentication_for_registration(
502        &mut self,
503        rng: &mut (dyn RngCore + Send),
504        clock: &dyn Clock,
505        email: String,
506        user_registration: &UserRegistration,
507    ) -> Result<UserEmailAuthentication, Self::Error> {
508        let created_at = clock.now();
509        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
510        tracing::Span::current()
511            .record("user_email_authentication.id", tracing::field::display(id));
512
513        sqlx::query!(
514            r#"
515                INSERT INTO user_email_authentications
516                  ( user_email_authentication_id
517                  , user_registration_id
518                  , email
519                  , created_at
520                  )
521                VALUES ($1, $2, $3, $4)
522            "#,
523            Uuid::from(id),
524            Uuid::from(user_registration.id),
525            &email,
526            created_at,
527        )
528        .traced()
529        .execute(&mut *self.conn)
530        .await?;
531
532        Ok(UserEmailAuthentication {
533            id,
534            user_session_id: None,
535            user_registration_id: Some(user_registration.id),
536            email,
537            created_at,
538            completed_at: None,
539        })
540    }
541
542    #[tracing::instrument(
543        name = "db.user_email.add_authentication_code",
544        skip_all,
545        fields(
546            db.query.text,
547            %user_email_authentication.id,
548            %user_email_authentication.email,
549            user_email_authentication_code.id,
550            user_email_authentication_code.code = code,
551        ),
552        err,
553    )]
554    async fn add_authentication_code(
555        &mut self,
556        rng: &mut (dyn RngCore + Send),
557        clock: &dyn Clock,
558        duration: chrono::Duration,
559        user_email_authentication: &UserEmailAuthentication,
560        code: String,
561    ) -> Result<UserEmailAuthenticationCode, Self::Error> {
562        let created_at = clock.now();
563        let expires_at = created_at + duration;
564        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
565        tracing::Span::current().record(
566            "user_email_authentication_code.id",
567            tracing::field::display(id),
568        );
569
570        sqlx::query!(
571            r#"
572                INSERT INTO user_email_authentication_codes
573                  ( user_email_authentication_code_id
574                  , user_email_authentication_id
575                  , code
576                  , created_at
577                  , expires_at
578                  )
579                VALUES ($1, $2, $3, $4, $5)
580            "#,
581            Uuid::from(id),
582            Uuid::from(user_email_authentication.id),
583            &code,
584            created_at,
585            expires_at,
586        )
587        .traced()
588        .execute(&mut *self.conn)
589        .await?;
590
591        Ok(UserEmailAuthenticationCode {
592            id,
593            user_email_authentication_id: user_email_authentication.id,
594            code,
595            created_at,
596            expires_at,
597        })
598    }
599
600    #[tracing::instrument(
601        name = "db.user_email.lookup_authentication",
602        skip_all,
603        fields(
604            db.query.text,
605            user_email_authentication.id = %id,
606        ),
607        err,
608    )]
609    async fn lookup_authentication(
610        &mut self,
611        id: Ulid,
612    ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
613        let res = sqlx::query_as!(
614            UserEmailAuthenticationLookup,
615            r#"
616                SELECT user_email_authentication_id
617                     , user_session_id
618                     , user_registration_id
619                     , email
620                     , created_at
621                     , completed_at
622                FROM user_email_authentications
623                WHERE user_email_authentication_id = $1
624            "#,
625            Uuid::from(id),
626        )
627        .traced()
628        .fetch_optional(&mut *self.conn)
629        .await?;
630
631        Ok(res.map(UserEmailAuthentication::from))
632    }
633
634    #[tracing::instrument(
635        name = "db.user_email.find_authentication_by_code",
636        skip_all,
637        fields(
638            db.query.text,
639            %authentication.id,
640            user_email_authentication_code.code = code,
641        ),
642        err,
643    )]
644    async fn find_authentication_code(
645        &mut self,
646        authentication: &UserEmailAuthentication,
647        code: &str,
648    ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
649        let res = sqlx::query_as!(
650            UserEmailAuthenticationCodeLookup,
651            r#"
652                SELECT user_email_authentication_code_id
653                     , user_email_authentication_id
654                     , code
655                     , created_at
656                     , expires_at
657                FROM user_email_authentication_codes
658                WHERE user_email_authentication_id = $1
659                  AND code = $2
660            "#,
661            Uuid::from(authentication.id),
662            code,
663        )
664        .traced()
665        .fetch_optional(&mut *self.conn)
666        .await?;
667
668        Ok(res.map(UserEmailAuthenticationCode::from))
669    }
670
671    #[tracing::instrument(
672        name = "db.user_email.complete_email_authentication_with_code",
673        skip_all,
674        fields(
675            db.query.text,
676            %user_email_authentication.id,
677            %user_email_authentication.email,
678            %user_email_authentication_code.id,
679            %user_email_authentication_code.code,
680        ),
681        err,
682    )]
683    async fn complete_authentication_with_code(
684        &mut self,
685        clock: &dyn Clock,
686        mut user_email_authentication: UserEmailAuthentication,
687        user_email_authentication_code: &UserEmailAuthenticationCode,
688    ) -> Result<UserEmailAuthentication, Self::Error> {
689        // We technically don't use the authentication code here (other than
690        // recording it in the span), but this is to make sure the caller has
691        // fetched one before calling this
692        let completed_at = clock.now();
693
694        // We'll assume the caller has checked that completed_at is None, so in case
695        // they haven't, the update will not affect any rows, which will raise
696        // an error
697        let res = sqlx::query!(
698            r#"
699                UPDATE user_email_authentications
700                SET completed_at = $2
701                WHERE user_email_authentication_id = $1
702                  AND completed_at IS NULL
703            "#,
704            Uuid::from(user_email_authentication.id),
705            completed_at,
706        )
707        .traced()
708        .execute(&mut *self.conn)
709        .await?;
710
711        DatabaseError::ensure_affected_rows(&res, 1)?;
712
713        user_email_authentication.completed_at = Some(completed_at);
714        Ok(user_email_authentication)
715    }
716
717    #[tracing::instrument(
718        name = "db.user_email.complete_email_authentication_with_upstream",
719        skip_all,
720        fields(
721            db.query.text,
722            %user_email_authentication.id,
723            %user_email_authentication.email,
724            %upstream_oauth_authorization_session.id,
725        ),
726        err,
727    )]
728    async fn complete_authentication_with_upstream(
729        &mut self,
730        clock: &dyn Clock,
731        mut user_email_authentication: UserEmailAuthentication,
732        upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
733    ) -> Result<UserEmailAuthentication, Self::Error> {
734        // We technically don't use the upstream_oauth_authorization_session here (other
735        // than recording it in the span), but this is to make sure the caller
736        // has fetched one before calling this
737        let completed_at = clock.now();
738
739        // We'll assume the caller has checked that completed_at is None, so in case
740        // they haven't, the update will not affect any rows, which will raise
741        // an error
742        let res = sqlx::query!(
743            r#"
744                UPDATE user_email_authentications
745                SET completed_at = $2
746                WHERE user_email_authentication_id = $1
747                  AND completed_at IS NULL
748            "#,
749            Uuid::from(user_email_authentication.id),
750            completed_at,
751        )
752        .traced()
753        .execute(&mut *self.conn)
754        .await?;
755
756        DatabaseError::ensure_affected_rows(&res, 1)?;
757
758        user_email_authentication.completed_at = Some(completed_at);
759        Ok(user_email_authentication)
760    }
761
762    #[tracing::instrument(
763        name = "db.user_email.cleanup_authentications",
764        skip_all,
765        fields(
766            db.query.text,
767            since = since.map(tracing::field::display),
768            until = %until,
769            limit = limit,
770        ),
771        err,
772    )]
773    async fn cleanup_authentications(
774        &mut self,
775        since: Option<Ulid>,
776        until: Ulid,
777        limit: usize,
778    ) -> Result<(usize, Option<Ulid>), Self::Error> {
779        // Use ULID cursor-based pagination. Since ULIDs contain a timestamp,
780        // we can efficiently delete old authentications without needing an index.
781        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
782        let res = sqlx::query_scalar!(
783            r#"
784                WITH
785                  to_delete AS (
786                    SELECT user_email_authentication_id
787                    FROM user_email_authentications
788                    WHERE ($1::uuid IS NULL OR user_email_authentication_id > $1)
789                      AND user_email_authentication_id <= $2
790                    ORDER BY user_email_authentication_id
791                    LIMIT $3
792                  ),
793                  deleted_codes AS (
794                    DELETE FROM user_email_authentication_codes
795                    USING to_delete
796                    WHERE user_email_authentication_codes.user_email_authentication_id = to_delete.user_email_authentication_id
797                    RETURNING user_email_authentication_codes.user_email_authentication_code_id
798                  )
799                DELETE FROM user_email_authentications
800                USING to_delete
801                WHERE user_email_authentications.user_email_authentication_id = to_delete.user_email_authentication_id
802                RETURNING user_email_authentications.user_email_authentication_id
803            "#,
804            since.map(Uuid::from),
805            Uuid::from(until),
806            i64::try_from(limit).unwrap_or(i64::MAX)
807        )
808        .traced()
809        .fetch_all(&mut *self.conn)
810        .await?;
811
812        let count = res.len();
813        let max_id = res.into_iter().max();
814
815        Ok((count, max_id.map(Ulid::from)))
816    }
817}