mas_storage_pg/user/
password.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Password, User};
10use mas_storage::{Clock, user::UserPasswordRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
17
/// An implementation of [`UserPasswordRepository`] for a PostgreSQL connection
///
/// The repository holds a mutable borrow of the connection for the lifetime
/// `'c`; every query issued by the repository runs on that connection.
pub struct PgUserPasswordRepository<'c> {
    /// The borrowed PostgreSQL connection the queries are executed on
    conn: &'c mut PgConnection,
}
22
23impl<'c> PgUserPasswordRepository<'c> {
24    /// Create a new [`PgUserPasswordRepository`] from an active PostgreSQL
25    /// connection
26    pub fn new(conn: &'c mut PgConnection) -> Self {
27        Self { conn }
28    }
29}
30
/// Raw row shape produced by the `user_passwords` lookup query
///
/// This is the target type of the `sqlx::query_as!` call in
/// [`UserPasswordRepository::active`]; its fields mirror the selected columns
/// and are converted into a [`Password`] afterwards.
struct UserPasswordLookup {
    /// `user_password_id` column, converted to a [`Ulid`] by the caller
    user_password_id: Uuid,
    /// `hashed_password` column, passed through unchanged
    hashed_password: String,
    /// `version` column; read as an `i32` and narrowed with a fallible
    /// conversion when building the [`Password`]
    version: i32,
    /// `upgraded_from_id` column, `None` when the column is `NULL`
    upgraded_from_id: Option<Uuid>,
    /// `created_at` column
    created_at: DateTime<Utc>,
}
38
39#[async_trait]
40impl UserPasswordRepository for PgUserPasswordRepository<'_> {
41    type Error = DatabaseError;
42
43    #[tracing::instrument(
44        name = "db.user_password.active",
45        skip_all,
46        fields(
47            db.query.text,
48            %user.id,
49            %user.username,
50        ),
51        err,
52    )]
53    async fn active(&mut self, user: &User) -> Result<Option<Password>, Self::Error> {
54        let res = sqlx::query_as!(
55            UserPasswordLookup,
56            r#"
57                SELECT up.user_password_id
58                     , up.hashed_password
59                     , up.version
60                     , up.upgraded_from_id
61                     , up.created_at
62                FROM user_passwords up
63                WHERE up.user_id = $1
64                ORDER BY up.created_at DESC
65                LIMIT 1
66            "#,
67            Uuid::from(user.id),
68        )
69        .traced()
70        .fetch_optional(&mut *self.conn)
71        .await?;
72
73        let Some(res) = res else { return Ok(None) };
74
75        let id = Ulid::from(res.user_password_id);
76
77        let version = res.version.try_into().map_err(|e| {
78            DatabaseInconsistencyError::on("user_passwords")
79                .column("version")
80                .row(id)
81                .source(e)
82        })?;
83
84        let upgraded_from_id = res.upgraded_from_id.map(Ulid::from);
85        let created_at = res.created_at;
86        let hashed_password = res.hashed_password;
87
88        Ok(Some(Password {
89            id,
90            hashed_password,
91            version,
92            upgraded_from_id,
93            created_at,
94        }))
95    }
96
97    #[tracing::instrument(
98        name = "db.user_password.add",
99        skip_all,
100        fields(
101            db.query.text,
102            %user.id,
103            %user.username,
104            user_password.id,
105            user_password.version = version,
106        ),
107        err,
108    )]
109    async fn add(
110        &mut self,
111        rng: &mut (dyn RngCore + Send),
112        clock: &dyn Clock,
113        user: &User,
114        version: u16,
115        hashed_password: String,
116        upgraded_from: Option<&Password>,
117    ) -> Result<Password, Self::Error> {
118        let created_at = clock.now();
119        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
120        tracing::Span::current().record("user_password.id", tracing::field::display(id));
121
122        let upgraded_from_id = upgraded_from.map(|p| p.id);
123
124        sqlx::query!(
125            r#"
126                INSERT INTO user_passwords
127                    (user_password_id, user_id, hashed_password, version, upgraded_from_id, created_at)
128                VALUES ($1, $2, $3, $4, $5, $6)
129            "#,
130            Uuid::from(id),
131            Uuid::from(user.id),
132            hashed_password,
133            i32::from(version),
134            upgraded_from_id.map(Uuid::from),
135            created_at,
136        )
137        .traced()
138        .execute(&mut *self.conn)
139        .await?;
140
141        Ok(Password {
142            id,
143            hashed_password,
144            version,
145            upgraded_from_id,
146            created_at,
147        })
148    }
149}