// mas_storage_pg/user/recovery.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Duration, Utc};
12use mas_data_model::{Clock, UserEmail, UserRecoverySession, UserRecoveryTicket};
13use mas_storage::user::UserRecoveryRepository;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, ExecuteExt};
20
21/// An implementation of [`UserRecoveryRepository`] for a PostgreSQL connection
22pub struct PgUserRecoveryRepository<'c> {
23    conn: &'c mut PgConnection,
24}
25
26impl<'c> PgUserRecoveryRepository<'c> {
27    /// Create a new [`PgUserRecoveryRepository`] from an active PostgreSQL
28    /// connection
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34struct UserRecoverySessionRow {
35    user_recovery_session_id: Uuid,
36    email: String,
37    user_agent: String,
38    ip_address: Option<IpAddr>,
39    locale: String,
40    created_at: DateTime<Utc>,
41    consumed_at: Option<DateTime<Utc>>,
42}
43
44impl From<UserRecoverySessionRow> for UserRecoverySession {
45    fn from(row: UserRecoverySessionRow) -> Self {
46        UserRecoverySession {
47            id: row.user_recovery_session_id.into(),
48            email: row.email,
49            user_agent: row.user_agent,
50            ip_address: row.ip_address,
51            locale: row.locale,
52            created_at: row.created_at,
53            consumed_at: row.consumed_at,
54        }
55    }
56}
57
58struct UserRecoveryTicketRow {
59    user_recovery_ticket_id: Uuid,
60    user_recovery_session_id: Uuid,
61    user_email_id: Uuid,
62    ticket: String,
63    created_at: DateTime<Utc>,
64    expires_at: DateTime<Utc>,
65}
66
67impl From<UserRecoveryTicketRow> for UserRecoveryTicket {
68    fn from(row: UserRecoveryTicketRow) -> Self {
69        Self {
70            id: row.user_recovery_ticket_id.into(),
71            user_recovery_session_id: row.user_recovery_session_id.into(),
72            user_email_id: row.user_email_id.into(),
73            ticket: row.ticket,
74            created_at: row.created_at,
75            expires_at: row.expires_at,
76        }
77    }
78}
79
80#[async_trait]
81impl UserRecoveryRepository for PgUserRecoveryRepository<'_> {
82    type Error = DatabaseError;
83
84    #[tracing::instrument(
85        name = "db.user_recovery.lookup_session",
86        skip_all,
87        fields(
88            db.query.text,
89            user_recovery_session.id = %id,
90        ),
91        err,
92    )]
93    async fn lookup_session(
94        &mut self,
95        id: Ulid,
96    ) -> Result<Option<UserRecoverySession>, Self::Error> {
97        let row = sqlx::query_as!(
98            UserRecoverySessionRow,
99            r#"
100                SELECT
101                      user_recovery_session_id
102                    , email
103                    , user_agent
104                    , ip_address as "ip_address: IpAddr"
105                    , locale
106                    , created_at
107                    , consumed_at
108                FROM user_recovery_sessions
109                WHERE user_recovery_session_id = $1
110            "#,
111            Uuid::from(id),
112        )
113        .traced()
114        .fetch_optional(&mut *self.conn)
115        .await?;
116
117        let Some(row) = row else {
118            return Ok(None);
119        };
120
121        Ok(Some(row.into()))
122    }
123
124    #[tracing::instrument(
125        name = "db.user_recovery.add_session",
126        skip_all,
127        fields(
128            db.query.text,
129            user_recovery_session.id,
130            user_recovery_session.email = email,
131            user_recovery_session.user_agent = user_agent,
132            user_recovery_session.ip_address = ip_address.map(|ip| ip.to_string()),
133        )
134    )]
135    async fn add_session(
136        &mut self,
137        rng: &mut (dyn RngCore + Send),
138        clock: &dyn Clock,
139        email: String,
140        user_agent: String,
141        ip_address: Option<IpAddr>,
142        locale: String,
143    ) -> Result<UserRecoverySession, Self::Error> {
144        let created_at = clock.now();
145        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
146        tracing::Span::current().record("user_recovery_session.id", tracing::field::display(id));
147        sqlx::query!(
148            r#"
149                INSERT INTO user_recovery_sessions (
150                      user_recovery_session_id
151                    , email
152                    , user_agent
153                    , ip_address
154                    , locale
155                    , created_at
156                )
157                VALUES ($1, $2, $3, $4, $5, $6)
158            "#,
159            Uuid::from(id),
160            &email,
161            &*user_agent,
162            ip_address as Option<IpAddr>,
163            &locale,
164            created_at,
165        )
166        .traced()
167        .execute(&mut *self.conn)
168        .await?;
169
170        let user_recovery_session = UserRecoverySession {
171            id,
172            email,
173            user_agent,
174            ip_address,
175            locale,
176            created_at,
177            consumed_at: None,
178        };
179
180        Ok(user_recovery_session)
181    }
182
183    #[tracing::instrument(
184        name = "db.user_recovery.find_ticket",
185        skip_all,
186        fields(
187            db.query.text,
188            user_recovery_ticket.id = ticket,
189        ),
190        err,
191    )]
192    async fn find_ticket(
193        &mut self,
194        ticket: &str,
195    ) -> Result<Option<UserRecoveryTicket>, Self::Error> {
196        let row = sqlx::query_as!(
197            UserRecoveryTicketRow,
198            r#"
199                SELECT
200                      user_recovery_ticket_id
201                    , user_recovery_session_id
202                    , user_email_id
203                    , ticket
204                    , created_at
205                    , expires_at
206                FROM user_recovery_tickets
207                WHERE ticket = $1
208            "#,
209            ticket,
210        )
211        .traced()
212        .fetch_optional(&mut *self.conn)
213        .await?;
214
215        let Some(row) = row else {
216            return Ok(None);
217        };
218
219        Ok(Some(row.into()))
220    }
221
222    #[tracing::instrument(
223        name = "db.user_recovery.add_ticket",
224        skip_all,
225        fields(
226            db.query.text,
227            user_recovery_ticket.id,
228            user_recovery_ticket.id = ticket,
229            %user_recovery_session.id,
230            %user_email.id,
231        )
232    )]
233    async fn add_ticket(
234        &mut self,
235        rng: &mut (dyn RngCore + Send),
236        clock: &dyn Clock,
237        user_recovery_session: &UserRecoverySession,
238        user_email: &UserEmail,
239        ticket: String,
240    ) -> Result<UserRecoveryTicket, Self::Error> {
241        let created_at = clock.now();
242        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
243        tracing::Span::current().record("user_recovery_ticket.id", tracing::field::display(id));
244
245        // TODO: move that to a parameter
246        let expires_at = created_at + Duration::minutes(10);
247
248        sqlx::query!(
249            r#"
250                INSERT INTO user_recovery_tickets (
251                      user_recovery_ticket_id
252                    , user_recovery_session_id
253                    , user_email_id
254                    , ticket
255                    , created_at
256                    , expires_at
257                )
258                VALUES ($1, $2, $3, $4, $5, $6)
259            "#,
260            Uuid::from(id),
261            Uuid::from(user_recovery_session.id),
262            Uuid::from(user_email.id),
263            &ticket,
264            created_at,
265            expires_at,
266        )
267        .traced()
268        .execute(&mut *self.conn)
269        .await?;
270
271        let ticket = UserRecoveryTicket {
272            id,
273            user_recovery_session_id: user_recovery_session.id,
274            user_email_id: user_email.id,
275            ticket,
276            created_at,
277            expires_at,
278        };
279
280        Ok(ticket)
281    }
282
283    #[tracing::instrument(
284        name = "db.user_recovery.consume_ticket",
285        skip_all,
286        fields(
287            db.query.text,
288            %user_recovery_ticket.id,
289            user_email.id = %user_recovery_ticket.user_email_id,
290            %user_recovery_session.id,
291            %user_recovery_session.email,
292        ),
293        err,
294    )]
295    async fn consume_ticket(
296        &mut self,
297        clock: &dyn Clock,
298        user_recovery_ticket: UserRecoveryTicket,
299        mut user_recovery_session: UserRecoverySession,
300    ) -> Result<UserRecoverySession, Self::Error> {
301        // We don't really use the ticket, we just want to make sure we drop it
302        let _ = user_recovery_ticket;
303
304        // This should have been checked by the caller
305        if user_recovery_session.consumed_at.is_some() {
306            return Err(DatabaseError::invalid_operation());
307        }
308
309        let consumed_at = clock.now();
310
311        let res = sqlx::query!(
312            r#"
313                UPDATE user_recovery_sessions
314                SET consumed_at = $1
315                WHERE user_recovery_session_id = $2
316            "#,
317            consumed_at,
318            Uuid::from(user_recovery_session.id),
319        )
320        .traced()
321        .execute(&mut *self.conn)
322        .await?;
323
324        user_recovery_session.consumed_at = Some(consumed_at);
325
326        DatabaseError::ensure_affected_rows(&res, 1)?;
327
328        Ok(user_recovery_session)
329    }
330
331    #[tracing::instrument(
332        name = "db.user_recovery.cleanup",
333        skip_all,
334        fields(
335            db.query.text,
336            since = since.map(tracing::field::display),
337            until = %until,
338            limit = limit,
339        ),
340        err,
341    )]
342    async fn cleanup(
343        &mut self,
344        since: Option<Ulid>,
345        until: Ulid,
346        limit: usize,
347    ) -> Result<(usize, Option<Ulid>), Self::Error> {
348        // Use ULID cursor-based pagination. Since ULIDs contain a timestamp,
349        // we can efficiently delete old sessions without needing an index.
350        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
351        let res = sqlx::query_scalar!(
352            r#"
353                WITH to_delete AS (
354                    SELECT user_recovery_session_id
355                    FROM user_recovery_sessions
356                    WHERE ($1::uuid IS NULL OR user_recovery_session_id > $1)
357                    AND user_recovery_session_id <= $2
358                    ORDER BY user_recovery_session_id
359                    LIMIT $3
360                )
361                DELETE FROM user_recovery_sessions
362                USING to_delete
363                WHERE user_recovery_sessions.user_recovery_session_id = to_delete.user_recovery_session_id
364                RETURNING user_recovery_sessions.user_recovery_session_id
365            "#,
366            since.map(Uuid::from),
367            Uuid::from(until),
368            i64::try_from(limit).unwrap_or(i64::MAX)
369        )
370        .traced()
371        .fetch_all(&mut *self.conn)
372        .await?;
373
374        let count = res.len();
375        let max_id = res.into_iter().max();
376
377        Ok((count, max_id.map(Ulid::from)))
378    }
379}