mas_storage_pg/user/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13    Clock,
14    user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24    DatabaseError,
25    filter::{Filter, StatementExt},
26    iden::Users,
27    pagination::QueryBuilderExt,
28    tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod session;
36mod terms;
37
38#[cfg(test)]
39mod tests;
40
41pub use self::{
42    email::PgUserEmailRepository, password::PgUserPasswordRepository,
43    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
44    session::PgBrowserSessionRepository, terms::PgUserTermsRepository,
45};
46
47/// An implementation of [`UserRepository`] for a PostgreSQL connection
48pub struct PgUserRepository<'c> {
49    conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
54    pub fn new(conn: &'c mut PgConnection) -> Self {
55        Self { conn }
56    }
57}
58
mod priv_ {
    // The enum_def macro generates a public enum, which we don't want, because it
    // triggers the missing docs warning
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    // Raw row shape of the `users` table columns selected by the repository.
    // `#[enum_def]` also derives `UserLookupIden` from these field names; the
    // `list` query uses it as column aliases so that `FromRow` can map the
    // dynamically-built query back onto this struct. Keep the field names in
    // sync with the SELECT lists in the parent module.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // `users.user_id`, converted to a ULID when building a `User`
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // Set by `lock`, cleared by `unlock`; `None` means not locked
        pub(super) locked_at: Option<DateTime<Utc>>,
        // Set by `deactivate`; `None` means the user is not deactivated
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        // Updated by `set_can_request_admin`
        pub(super) can_request_admin: bool,
    }
}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83    fn from(value: UserLookup) -> Self {
84        let id = value.user_id.into();
85        Self {
86            id,
87            username: value.username,
88            sub: id.to_string(),
89            created_at: value.created_at,
90            locked_at: value.locked_at,
91            deactivated_at: value.deactivated_at,
92            can_request_admin: value.can_request_admin,
93        }
94    }
95}
96
97impl Filter for UserFilter<'_> {
98    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99        sea_query::Condition::all()
100            .add_option(self.state().map(|state| {
101                match state {
102                    mas_storage::user::UserState::Deactivated => {
103                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
104                    }
105                    mas_storage::user::UserState::Locked => {
106                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
107                    }
108                    mas_storage::user::UserState::Active => {
109                        Expr::col((Users::Table, Users::LockedAt))
110                            .is_null()
111                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
112                    }
113                }
114            }))
115            .add_option(self.can_request_admin().map(|can_request_admin| {
116                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
117            }))
118    }
119}
120
121#[async_trait]
122impl UserRepository for PgUserRepository<'_> {
123    type Error = DatabaseError;
124
125    #[tracing::instrument(
126        name = "db.user.lookup",
127        skip_all,
128        fields(
129            db.query.text,
130            user.id = %id,
131        ),
132        err,
133    )]
134    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
135        let res = sqlx::query_as!(
136            UserLookup,
137            r#"
138                SELECT user_id
139                     , username
140                     , created_at
141                     , locked_at
142                     , deactivated_at
143                     , can_request_admin
144                FROM users
145                WHERE user_id = $1
146            "#,
147            Uuid::from(id),
148        )
149        .traced()
150        .fetch_optional(&mut *self.conn)
151        .await?;
152
153        let Some(res) = res else { return Ok(None) };
154
155        Ok(Some(res.into()))
156    }
157
158    #[tracing::instrument(
159        name = "db.user.find_by_username",
160        skip_all,
161        fields(
162            db.query.text,
163            user.username = username,
164        ),
165        err,
166    )]
167    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
168        let res = sqlx::query_as!(
169            UserLookup,
170            r#"
171                SELECT user_id
172                     , username
173                     , created_at
174                     , locked_at
175                     , deactivated_at
176                     , can_request_admin
177                FROM users
178                WHERE username = $1
179            "#,
180            username,
181        )
182        .traced()
183        .fetch_optional(&mut *self.conn)
184        .await?;
185
186        let Some(res) = res else { return Ok(None) };
187
188        Ok(Some(res.into()))
189    }
190
191    #[tracing::instrument(
192        name = "db.user.add",
193        skip_all,
194        fields(
195            db.query.text,
196            user.username = username,
197            user.id,
198        ),
199        err,
200    )]
201    async fn add(
202        &mut self,
203        rng: &mut (dyn RngCore + Send),
204        clock: &dyn Clock,
205        username: String,
206    ) -> Result<User, Self::Error> {
207        let created_at = clock.now();
208        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
209        tracing::Span::current().record("user.id", tracing::field::display(id));
210
211        let res = sqlx::query!(
212            r#"
213                INSERT INTO users (user_id, username, created_at)
214                VALUES ($1, $2, $3)
215                ON CONFLICT (username) DO NOTHING
216            "#,
217            Uuid::from(id),
218            username,
219            created_at,
220        )
221        .traced()
222        .execute(&mut *self.conn)
223        .await?;
224
225        // If the user already exists, want to return an error but not poison the
226        // transaction
227        DatabaseError::ensure_affected_rows(&res, 1)?;
228
229        Ok(User {
230            id,
231            username,
232            sub: id.to_string(),
233            created_at,
234            locked_at: None,
235            deactivated_at: None,
236            can_request_admin: false,
237        })
238    }
239
240    #[tracing::instrument(
241        name = "db.user.exists",
242        skip_all,
243        fields(
244            db.query.text,
245            user.username = username,
246        ),
247        err,
248    )]
249    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
250        let exists = sqlx::query_scalar!(
251            r#"
252                SELECT EXISTS(
253                    SELECT 1 FROM users WHERE username = $1
254                ) AS "exists!"
255            "#,
256            username
257        )
258        .traced()
259        .fetch_one(&mut *self.conn)
260        .await?;
261
262        Ok(exists)
263    }
264
265    #[tracing::instrument(
266        name = "db.user.lock",
267        skip_all,
268        fields(
269            db.query.text,
270            %user.id,
271        ),
272        err,
273    )]
274    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
275        if user.locked_at.is_some() {
276            return Ok(user);
277        }
278
279        let locked_at = clock.now();
280        let res = sqlx::query!(
281            r#"
282                UPDATE users
283                SET locked_at = $1
284                WHERE user_id = $2
285            "#,
286            locked_at,
287            Uuid::from(user.id),
288        )
289        .traced()
290        .execute(&mut *self.conn)
291        .await?;
292
293        DatabaseError::ensure_affected_rows(&res, 1)?;
294
295        user.locked_at = Some(locked_at);
296
297        Ok(user)
298    }
299
300    #[tracing::instrument(
301        name = "db.user.unlock",
302        skip_all,
303        fields(
304            db.query.text,
305            %user.id,
306        ),
307        err,
308    )]
309    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
310        if user.locked_at.is_none() {
311            return Ok(user);
312        }
313
314        let res = sqlx::query!(
315            r#"
316                UPDATE users
317                SET locked_at = NULL
318                WHERE user_id = $1
319            "#,
320            Uuid::from(user.id),
321        )
322        .traced()
323        .execute(&mut *self.conn)
324        .await?;
325
326        DatabaseError::ensure_affected_rows(&res, 1)?;
327
328        user.locked_at = None;
329
330        Ok(user)
331    }
332
333    #[tracing::instrument(
334        name = "db.user.deactivate",
335        skip_all,
336        fields(
337            db.query.text,
338            %user.id,
339        ),
340        err,
341    )]
342    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
343        if user.deactivated_at.is_some() {
344            return Ok(user);
345        }
346
347        let deactivated_at = clock.now();
348        let res = sqlx::query!(
349            r#"
350                UPDATE users
351                SET deactivated_at = $2
352                WHERE user_id = $1
353                  AND deactivated_at IS NULL
354            "#,
355            Uuid::from(user.id),
356            deactivated_at,
357        )
358        .traced()
359        .execute(&mut *self.conn)
360        .await?;
361
362        DatabaseError::ensure_affected_rows(&res, 1)?;
363
364        user.deactivated_at = Some(user.created_at);
365
366        Ok(user)
367    }
368
369    #[tracing::instrument(
370        name = "db.user.set_can_request_admin",
371        skip_all,
372        fields(
373            db.query.text,
374            %user.id,
375            user.can_request_admin = can_request_admin,
376        ),
377        err,
378    )]
379    async fn set_can_request_admin(
380        &mut self,
381        mut user: User,
382        can_request_admin: bool,
383    ) -> Result<User, Self::Error> {
384        let res = sqlx::query!(
385            r#"
386                UPDATE users
387                SET can_request_admin = $2
388                WHERE user_id = $1
389            "#,
390            Uuid::from(user.id),
391            can_request_admin,
392        )
393        .traced()
394        .execute(&mut *self.conn)
395        .await?;
396
397        DatabaseError::ensure_affected_rows(&res, 1)?;
398
399        user.can_request_admin = can_request_admin;
400
401        Ok(user)
402    }
403
404    #[tracing::instrument(
405        name = "db.user.list",
406        skip_all,
407        fields(
408            db.query.text,
409        ),
410        err,
411    )]
412    async fn list(
413        &mut self,
414        filter: UserFilter<'_>,
415        pagination: mas_storage::Pagination,
416    ) -> Result<mas_storage::Page<User>, Self::Error> {
417        let (sql, arguments) = Query::select()
418            .expr_as(
419                Expr::col((Users::Table, Users::UserId)),
420                UserLookupIden::UserId,
421            )
422            .expr_as(
423                Expr::col((Users::Table, Users::Username)),
424                UserLookupIden::Username,
425            )
426            .expr_as(
427                Expr::col((Users::Table, Users::CreatedAt)),
428                UserLookupIden::CreatedAt,
429            )
430            .expr_as(
431                Expr::col((Users::Table, Users::LockedAt)),
432                UserLookupIden::LockedAt,
433            )
434            .expr_as(
435                Expr::col((Users::Table, Users::DeactivatedAt)),
436                UserLookupIden::DeactivatedAt,
437            )
438            .expr_as(
439                Expr::col((Users::Table, Users::CanRequestAdmin)),
440                UserLookupIden::CanRequestAdmin,
441            )
442            .from(Users::Table)
443            .apply_filter(filter)
444            .generate_pagination((Users::Table, Users::UserId), pagination)
445            .build_sqlx(PostgresQueryBuilder);
446
447        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
448            .traced()
449            .fetch_all(&mut *self.conn)
450            .await?;
451
452        let page = pagination.process(edges).map(User::from);
453
454        Ok(page)
455    }
456
457    #[tracing::instrument(
458        name = "db.user.count",
459        skip_all,
460        fields(
461            db.query.text,
462        ),
463        err,
464    )]
465    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
466        let (sql, arguments) = Query::select()
467            .expr(Expr::col((Users::Table, Users::UserId)).count())
468            .from(Users::Table)
469            .apply_filter(filter)
470            .build_sqlx(PostgresQueryBuilder);
471
472        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
473            .traced()
474            .fetch_one(&mut *self.conn)
475            .await?;
476
477        count
478            .try_into()
479            .map_err(DatabaseError::to_invalid_operation)
480    }
481
482    #[tracing::instrument(
483        name = "db.user.acquire_lock_for_sync",
484        skip_all,
485        fields(
486            db.query.text,
487            user.id = %user.id,
488        ),
489        err,
490    )]
491    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
492        // XXX: this lock isn't stictly scoped to users, but as we don't use many
493        // postgres advisory locks, it's fine for now. Later on, we could use row-level
494        // locks to make sure we don't get into trouble
495
496        // Convert the user ID to a u128 and grab the lower 64 bits
497        // As this includes 64bit of the random part of the ULID, it should be random
498        // enough to not collide
499        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
500
501        // Use a PG advisory lock, which will be released when the transaction is
502        // committed or rolled back
503        sqlx::query!(
504            r#"
505                SELECT pg_advisory_xact_lock($1)
506            "#,
507            lock_id,
508        )
509        .traced()
510        .execute(&mut *self.conn)
511        .await?;
512
513        Ok(())
514    }
515}