mas_storage_pg/user/email.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11    UserRegistration,
12};
13use mas_storage::{
14    Clock, Page, Pagination,
15    user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError,
26    filter::{Filter, StatementExt},
27    iden::UserEmails,
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
33pub struct PgUserEmailRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38    /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48    user_email_id: Uuid,
49    user_id: Uuid,
50    email: String,
51    created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55    fn from(e: UserEmailLookup) -> UserEmail {
56        UserEmail {
57            id: e.user_email_id.into(),
58            user_id: e.user_id.into(),
59            email: e.email,
60            created_at: e.created_at,
61        }
62    }
63}
64
65struct UserEmailAuthenticationLookup {
66    user_email_authentication_id: Uuid,
67    user_session_id: Option<Uuid>,
68    user_registration_id: Option<Uuid>,
69    email: String,
70    created_at: DateTime<Utc>,
71    completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75    fn from(value: UserEmailAuthenticationLookup) -> Self {
76        UserEmailAuthentication {
77            id: value.user_email_authentication_id.into(),
78            user_session_id: value.user_session_id.map(Ulid::from),
79            user_registration_id: value.user_registration_id.map(Ulid::from),
80            email: value.email,
81            created_at: value.created_at,
82            completed_at: value.completed_at,
83        }
84    }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88    user_email_authentication_code_id: Uuid,
89    user_email_authentication_id: Uuid,
90    code: String,
91    created_at: DateTime<Utc>,
92    expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96    fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97        UserEmailAuthenticationCode {
98            id: value.user_email_authentication_code_id.into(),
99            user_email_authentication_id: value.user_email_authentication_id.into(),
100            code: value.code,
101            created_at: value.created_at,
102            expires_at: value.expires_at,
103        }
104    }
105}
106
107impl Filter for UserEmailFilter<'_> {
108    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109        sea_query::Condition::all()
110            .add_option(self.user().map(|user| {
111                Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112            }))
113            .add_option(
114                self.email()
115                    .map(|email| Expr::col((UserEmails::Table, UserEmails::Email)).eq(email)),
116            )
117    }
118}
119
120#[async_trait]
121impl UserEmailRepository for PgUserEmailRepository<'_> {
122    type Error = DatabaseError;
123
124    #[tracing::instrument(
125        name = "db.user_email.lookup",
126        skip_all,
127        fields(
128            db.query.text,
129            user_email.id = %id,
130        ),
131        err,
132    )]
133    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
134        let res = sqlx::query_as!(
135            UserEmailLookup,
136            r#"
137                SELECT user_email_id
138                     , user_id
139                     , email
140                     , created_at
141                FROM user_emails
142
143                WHERE user_email_id = $1
144            "#,
145            Uuid::from(id),
146        )
147        .traced()
148        .fetch_optional(&mut *self.conn)
149        .await?;
150
151        let Some(user_email) = res else {
152            return Ok(None);
153        };
154
155        Ok(Some(user_email.into()))
156    }
157
158    #[tracing::instrument(
159        name = "db.user_email.find",
160        skip_all,
161        fields(
162            db.query.text,
163            %user.id,
164            user_email.email = email,
165        ),
166        err,
167    )]
168    async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
169        let res = sqlx::query_as!(
170            UserEmailLookup,
171            r#"
172                SELECT user_email_id
173                     , user_id
174                     , email
175                     , created_at
176                FROM user_emails
177
178                WHERE user_id = $1 AND email = $2
179            "#,
180            Uuid::from(user.id),
181            email,
182        )
183        .traced()
184        .fetch_optional(&mut *self.conn)
185        .await?;
186
187        let Some(user_email) = res else {
188            return Ok(None);
189        };
190
191        Ok(Some(user_email.into()))
192    }
193
194    #[tracing::instrument(
195        name = "db.user_email.all",
196        skip_all,
197        fields(
198            db.query.text,
199            %user.id,
200        ),
201        err,
202    )]
203    async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
204        let res = sqlx::query_as!(
205            UserEmailLookup,
206            r#"
207                SELECT user_email_id
208                     , user_id
209                     , email
210                     , created_at
211                FROM user_emails
212
213                WHERE user_id = $1
214
215                ORDER BY email ASC
216            "#,
217            Uuid::from(user.id),
218        )
219        .traced()
220        .fetch_all(&mut *self.conn)
221        .await?;
222
223        Ok(res.into_iter().map(Into::into).collect())
224    }
225
226    #[tracing::instrument(
227        name = "db.user_email.list",
228        skip_all,
229        fields(
230            db.query.text,
231        ),
232        err,
233    )]
234    async fn list(
235        &mut self,
236        filter: UserEmailFilter<'_>,
237        pagination: Pagination,
238    ) -> Result<Page<UserEmail>, DatabaseError> {
239        let (sql, arguments) = Query::select()
240            .expr_as(
241                Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
242                UserEmailLookupIden::UserEmailId,
243            )
244            .expr_as(
245                Expr::col((UserEmails::Table, UserEmails::UserId)),
246                UserEmailLookupIden::UserId,
247            )
248            .expr_as(
249                Expr::col((UserEmails::Table, UserEmails::Email)),
250                UserEmailLookupIden::Email,
251            )
252            .expr_as(
253                Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
254                UserEmailLookupIden::CreatedAt,
255            )
256            .from(UserEmails::Table)
257            .apply_filter(filter)
258            .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
259            .build_sqlx(PostgresQueryBuilder);
260
261        let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
262            .traced()
263            .fetch_all(&mut *self.conn)
264            .await?;
265
266        let page = pagination.process(edges).map(UserEmail::from);
267
268        Ok(page)
269    }
270
271    #[tracing::instrument(
272        name = "db.user_email.count",
273        skip_all,
274        fields(
275            db.query.text,
276        ),
277        err,
278    )]
279    async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
280        let (sql, arguments) = Query::select()
281            .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
282            .from(UserEmails::Table)
283            .apply_filter(filter)
284            .build_sqlx(PostgresQueryBuilder);
285
286        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
287            .traced()
288            .fetch_one(&mut *self.conn)
289            .await?;
290
291        count
292            .try_into()
293            .map_err(DatabaseError::to_invalid_operation)
294    }
295
296    #[tracing::instrument(
297        name = "db.user_email.add",
298        skip_all,
299        fields(
300            db.query.text,
301            %user.id,
302            user_email.id,
303            user_email.email = email,
304        ),
305        err,
306    )]
307    async fn add(
308        &mut self,
309        rng: &mut (dyn RngCore + Send),
310        clock: &dyn Clock,
311        user: &User,
312        email: String,
313    ) -> Result<UserEmail, Self::Error> {
314        let created_at = clock.now();
315        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
316        tracing::Span::current().record("user_email.id", tracing::field::display(id));
317
318        sqlx::query!(
319            r#"
320                INSERT INTO user_emails (user_email_id, user_id, email, created_at)
321                VALUES ($1, $2, $3, $4)
322            "#,
323            Uuid::from(id),
324            Uuid::from(user.id),
325            &email,
326            created_at,
327        )
328        .traced()
329        .execute(&mut *self.conn)
330        .await?;
331
332        Ok(UserEmail {
333            id,
334            user_id: user.id,
335            email,
336            created_at,
337        })
338    }
339
340    #[tracing::instrument(
341        name = "db.user_email.remove",
342        skip_all,
343        fields(
344            db.query.text,
345            user.id = %user_email.user_id,
346            %user_email.id,
347            %user_email.email,
348        ),
349        err,
350    )]
351    async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
352        let res = sqlx::query!(
353            r#"
354                DELETE FROM user_emails
355                WHERE user_email_id = $1
356            "#,
357            Uuid::from(user_email.id),
358        )
359        .traced()
360        .execute(&mut *self.conn)
361        .await?;
362
363        DatabaseError::ensure_affected_rows(&res, 1)?;
364
365        Ok(())
366    }
367
368    #[tracing::instrument(
369        name = "db.user_email.remove_bulk",
370        skip_all,
371        fields(
372            db.query.text,
373        ),
374        err,
375    )]
376    async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
377        let (sql, arguments) = Query::delete()
378            .from_table(UserEmails::Table)
379            .apply_filter(filter)
380            .build_sqlx(PostgresQueryBuilder);
381
382        let res = sqlx::query_with(&sql, arguments)
383            .traced()
384            .execute(&mut *self.conn)
385            .await?;
386
387        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
388    }
389
390    #[tracing::instrument(
391        name = "db.user_email.add_authentication_for_session",
392        skip_all,
393        fields(
394            db.query.text,
395            %session.id,
396            user_email_authentication.id,
397            user_email_authentication.email = email,
398        ),
399        err,
400    )]
401    async fn add_authentication_for_session(
402        &mut self,
403        rng: &mut (dyn RngCore + Send),
404        clock: &dyn Clock,
405        email: String,
406        session: &BrowserSession,
407    ) -> Result<UserEmailAuthentication, Self::Error> {
408        let created_at = clock.now();
409        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
410        tracing::Span::current()
411            .record("user_email_authentication.id", tracing::field::display(id));
412
413        sqlx::query!(
414            r#"
415                INSERT INTO user_email_authentications
416                  ( user_email_authentication_id
417                  , user_session_id
418                  , email
419                  , created_at
420                  )
421                VALUES ($1, $2, $3, $4)
422            "#,
423            Uuid::from(id),
424            Uuid::from(session.id),
425            &email,
426            created_at,
427        )
428        .traced()
429        .execute(&mut *self.conn)
430        .await?;
431
432        Ok(UserEmailAuthentication {
433            id,
434            user_session_id: Some(session.id),
435            user_registration_id: None,
436            email,
437            created_at,
438            completed_at: None,
439        })
440    }
441
442    #[tracing::instrument(
443        name = "db.user_email.add_authentication_for_registration",
444        skip_all,
445        fields(
446            db.query.text,
447            %user_registration.id,
448            user_email_authentication.id,
449            user_email_authentication.email = email,
450        ),
451        err,
452    )]
453    async fn add_authentication_for_registration(
454        &mut self,
455        rng: &mut (dyn RngCore + Send),
456        clock: &dyn Clock,
457        email: String,
458        user_registration: &UserRegistration,
459    ) -> Result<UserEmailAuthentication, Self::Error> {
460        let created_at = clock.now();
461        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
462        tracing::Span::current()
463            .record("user_email_authentication.id", tracing::field::display(id));
464
465        sqlx::query!(
466            r#"
467                INSERT INTO user_email_authentications
468                  ( user_email_authentication_id
469                  , user_registration_id
470                  , email
471                  , created_at
472                  )
473                VALUES ($1, $2, $3, $4)
474            "#,
475            Uuid::from(id),
476            Uuid::from(user_registration.id),
477            &email,
478            created_at,
479        )
480        .traced()
481        .execute(&mut *self.conn)
482        .await?;
483
484        Ok(UserEmailAuthentication {
485            id,
486            user_session_id: None,
487            user_registration_id: Some(user_registration.id),
488            email,
489            created_at,
490            completed_at: None,
491        })
492    }
493
494    #[tracing::instrument(
495        name = "db.user_email.add_authentication_code",
496        skip_all,
497        fields(
498            db.query.text,
499            %user_email_authentication.id,
500            %user_email_authentication.email,
501            user_email_authentication_code.id,
502            user_email_authentication_code.code = code,
503        ),
504        err,
505    )]
506    async fn add_authentication_code(
507        &mut self,
508        rng: &mut (dyn RngCore + Send),
509        clock: &dyn Clock,
510        duration: chrono::Duration,
511        user_email_authentication: &UserEmailAuthentication,
512        code: String,
513    ) -> Result<UserEmailAuthenticationCode, Self::Error> {
514        let created_at = clock.now();
515        let expires_at = created_at + duration;
516        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
517        tracing::Span::current().record(
518            "user_email_authentication_code.id",
519            tracing::field::display(id),
520        );
521
522        sqlx::query!(
523            r#"
524                INSERT INTO user_email_authentication_codes
525                  ( user_email_authentication_code_id
526                  , user_email_authentication_id
527                  , code
528                  , created_at
529                  , expires_at
530                  )
531                VALUES ($1, $2, $3, $4, $5)
532            "#,
533            Uuid::from(id),
534            Uuid::from(user_email_authentication.id),
535            &code,
536            created_at,
537            expires_at,
538        )
539        .traced()
540        .execute(&mut *self.conn)
541        .await?;
542
543        Ok(UserEmailAuthenticationCode {
544            id,
545            user_email_authentication_id: user_email_authentication.id,
546            code,
547            created_at,
548            expires_at,
549        })
550    }
551
552    #[tracing::instrument(
553        name = "db.user_email.lookup_authentication",
554        skip_all,
555        fields(
556            db.query.text,
557            user_email_authentication.id = %id,
558        ),
559        err,
560    )]
561    async fn lookup_authentication(
562        &mut self,
563        id: Ulid,
564    ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
565        let res = sqlx::query_as!(
566            UserEmailAuthenticationLookup,
567            r#"
568                SELECT user_email_authentication_id
569                     , user_session_id
570                     , user_registration_id
571                     , email
572                     , created_at
573                     , completed_at
574                FROM user_email_authentications
575                WHERE user_email_authentication_id = $1
576            "#,
577            Uuid::from(id),
578        )
579        .traced()
580        .fetch_optional(&mut *self.conn)
581        .await?;
582
583        Ok(res.map(UserEmailAuthentication::from))
584    }
585
586    #[tracing::instrument(
587        name = "db.user_email.find_authentication_by_code",
588        skip_all,
589        fields(
590            db.query.text,
591            %authentication.id,
592            user_email_authentication_code.code = code,
593        ),
594        err,
595    )]
596    async fn find_authentication_code(
597        &mut self,
598        authentication: &UserEmailAuthentication,
599        code: &str,
600    ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
601        let res = sqlx::query_as!(
602            UserEmailAuthenticationCodeLookup,
603            r#"
604                SELECT user_email_authentication_code_id
605                     , user_email_authentication_id
606                     , code
607                     , created_at
608                     , expires_at
609                FROM user_email_authentication_codes
610                WHERE user_email_authentication_id = $1
611                  AND code = $2
612            "#,
613            Uuid::from(authentication.id),
614            code,
615        )
616        .traced()
617        .fetch_optional(&mut *self.conn)
618        .await?;
619
620        Ok(res.map(UserEmailAuthenticationCode::from))
621    }
622
623    #[tracing::instrument(
624        name = "db.user_email.complete_email_authentication",
625        skip_all,
626        fields(
627            db.query.text,
628            %user_email_authentication.id,
629            %user_email_authentication.email,
630            %user_email_authentication_code.id,
631            %user_email_authentication_code.code,
632        ),
633        err,
634    )]
635    async fn complete_authentication(
636        &mut self,
637        clock: &dyn Clock,
638        mut user_email_authentication: UserEmailAuthentication,
639        user_email_authentication_code: &UserEmailAuthenticationCode,
640    ) -> Result<UserEmailAuthentication, Self::Error> {
641        // We technically don't use the authentication code here (other than
642        // recording it in the span), but this is to make sure the caller has
643        // fetched one before calling this
644        let completed_at = clock.now();
645
646        // We'll assume the caller has checked that completed_at is None, so in case
647        // they haven't, the update will not affect any rows, which will raise
648        // an error
649        let res = sqlx::query!(
650            r#"
651                UPDATE user_email_authentications
652                SET completed_at = $2
653                WHERE user_email_authentication_id = $1
654                  AND completed_at IS NULL
655            "#,
656            Uuid::from(user_email_authentication.id),
657            completed_at,
658        )
659        .traced()
660        .execute(&mut *self.conn)
661        .await?;
662
663        DatabaseError::ensure_affected_rows(&res, 1)?;
664
665        user_email_authentication.completed_at = Some(completed_at);
666        Ok(user_email_authentication)
667    }
668}