mas_storage_pg/user/
session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13    Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
14    UpstreamOAuthAuthorizationSession, User,
15};
16use mas_storage::{
17    Page, Pagination,
18    pagination::Node,
19    user::{BrowserSessionFilter, BrowserSessionRepository},
20};
21use rand::RngCore;
22use sea_query::{Expr, PostgresQueryBuilder, Query};
23use sea_query_binder::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29    DatabaseError, DatabaseInconsistencyError,
30    filter::StatementExt,
31    iden::{UpstreamOAuthAuthorizationSessions, UserSessions, Users},
32    pagination::QueryBuilderExt,
33    tracing::ExecuteExt,
34};
35
36/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
37/// connection
38pub struct PgBrowserSessionRepository<'c> {
39    conn: &'c mut PgConnection,
40}
41
42impl<'c> PgBrowserSessionRepository<'c> {
43    /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL
44    /// connection
45    pub fn new(conn: &'c mut PgConnection) -> Self {
46        Self { conn }
47    }
48}
49
/// One row of the `user_sessions` / `users` join, as read by the `lookup` and
/// `list` queries.
///
/// Field names double as SQL column aliases: `lookup` selects into them by
/// name through `query_as!`, while `list` aliases its columns with the
/// `SessionLookupIden` variants that `enum_def` generates from these names and
/// decodes the rows through `FromRow`.
// The shared `user_session_` / `user_` prefixes mirror those SQL aliases,
// which is why `clippy::struct_field_names` is allowed here.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns from `user_sessions`.
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns from the owning `users` row.
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
68
69impl Node<Ulid> for SessionLookup {
70    fn cursor(&self) -> Ulid {
71        self.user_id.into()
72    }
73}
74
75impl TryFrom<SessionLookup> for BrowserSession {
76    type Error = DatabaseInconsistencyError;
77
78    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
79        let id = Ulid::from(value.user_id);
80        let user = User {
81            id,
82            username: value.user_username,
83            sub: id.to_string(),
84            created_at: value.user_created_at,
85            locked_at: value.user_locked_at,
86            deactivated_at: value.user_deactivated_at,
87            can_request_admin: value.user_can_request_admin,
88            is_guest: value.user_is_guest,
89        };
90
91        Ok(BrowserSession {
92            id: value.user_session_id.into(),
93            user,
94            created_at: value.user_session_created_at,
95            finished_at: value.user_session_finished_at,
96            user_agent: value.user_session_user_agent,
97            last_active_at: value.user_session_last_active_at,
98            last_active_ip: value.user_session_last_active_ip,
99        })
100    }
101}
102
/// Raw `user_session_authentications` row, as selected by
/// `get_last_authentication`.
///
/// At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` is expected to be set. The
/// `TryFrom` conversion into [`Authentication`] rejects rows with both set.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    // Set by `authenticate_with_password`.
    user_password_id: Option<Uuid>,
    // Set by `authenticate_with_upstream`.
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
109
110impl TryFrom<AuthenticationLookup> for Authentication {
111    type Error = DatabaseInconsistencyError;
112
113    fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
114        let id = Ulid::from(value.user_session_authentication_id);
115        let authentication_method = match (
116            value.user_password_id.map(Into::into),
117            value
118                .upstream_oauth_authorization_session_id
119                .map(Into::into),
120        ) {
121            (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
122            (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
123                upstream_oauth2_session_id,
124            },
125            (None, None) => AuthenticationMethod::Unknown,
126            _ => {
127                return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
128            }
129        };
130
131        Ok(Authentication {
132            id,
133            created_at: value.created_at,
134            authentication_method,
135        })
136    }
137}
138
139impl crate::filter::Filter for BrowserSessionFilter<'_> {
140    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
141        sea_query::Condition::all()
142            .add_option(self.user().map(|user| {
143                Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
144            }))
145            .add_option(self.state().map(|state| {
146                if state.is_active() {
147                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
148                } else {
149                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
150                }
151            }))
152            .add_option(self.last_active_after().map(|last_active_after| {
153                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
154            }))
155            .add_option(self.last_active_before().map(|last_active_before| {
156                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
157            }))
158            .add_option(self.linked_to_upstream_sessions().map(|filter| {
159                Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
160                    Query::select()
161                        .expr(Expr::col((
162                            UpstreamOAuthAuthorizationSessions::Table,
163                            UpstreamOAuthAuthorizationSessions::UserSessionId,
164                        )))
165                        .from(UpstreamOAuthAuthorizationSessions::Table)
166                        .apply_filter(filter)
167                        .take(),
168                )
169            }))
170    }
171}
172
173#[async_trait]
174impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
175    type Error = DatabaseError;
176
177    #[tracing::instrument(
178        name = "db.browser_session.lookup",
179        skip_all,
180        fields(
181            db.query.text,
182            user_session.id = %id,
183        ),
184        err,
185    )]
186    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
187        let res = sqlx::query_as!(
188            SessionLookup,
189            r#"
190                SELECT s.user_session_id
191                     , s.created_at            AS "user_session_created_at"
192                     , s.finished_at           AS "user_session_finished_at"
193                     , s.user_agent            AS "user_session_user_agent"
194                     , s.last_active_at        AS "user_session_last_active_at"
195                     , s.last_active_ip        AS "user_session_last_active_ip: IpAddr"
196                     , u.user_id
197                     , u.username              AS "user_username"
198                     , u.created_at            AS "user_created_at"
199                     , u.locked_at             AS "user_locked_at"
200                     , u.deactivated_at        AS "user_deactivated_at"
201                     , u.can_request_admin     AS "user_can_request_admin"
202                     , u.is_guest              AS "user_is_guest"
203                FROM user_sessions s
204                INNER JOIN users u
205                    USING (user_id)
206                WHERE s.user_session_id = $1
207            "#,
208            Uuid::from(id),
209        )
210        .traced()
211        .fetch_optional(&mut *self.conn)
212        .await?;
213
214        let Some(res) = res else { return Ok(None) };
215
216        Ok(Some(res.try_into()?))
217    }
218
219    #[tracing::instrument(
220        name = "db.browser_session.add",
221        skip_all,
222        fields(
223            db.query.text,
224            %user.id,
225            user_session.id,
226        ),
227        err,
228    )]
229    async fn add(
230        &mut self,
231        rng: &mut (dyn RngCore + Send),
232        clock: &dyn Clock,
233        user: &User,
234        user_agent: Option<String>,
235    ) -> Result<BrowserSession, Self::Error> {
236        let created_at = clock.now();
237        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
238        tracing::Span::current().record("user_session.id", tracing::field::display(id));
239
240        sqlx::query!(
241            r#"
242                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
243                VALUES ($1, $2, $3, $4)
244            "#,
245            Uuid::from(id),
246            Uuid::from(user.id),
247            created_at,
248            user_agent.as_deref(),
249        )
250        .traced()
251        .execute(&mut *self.conn)
252        .await?;
253
254        let session = BrowserSession {
255            id,
256            // XXX
257            user: user.clone(),
258            created_at,
259            finished_at: None,
260            user_agent,
261            last_active_at: None,
262            last_active_ip: None,
263        };
264
265        Ok(session)
266    }
267
268    #[tracing::instrument(
269        name = "db.browser_session.finish",
270        skip_all,
271        fields(
272            db.query.text,
273            %user_session.id,
274        ),
275        err,
276    )]
277    async fn finish(
278        &mut self,
279        clock: &dyn Clock,
280        mut user_session: BrowserSession,
281    ) -> Result<BrowserSession, Self::Error> {
282        let finished_at = clock.now();
283        let res = sqlx::query!(
284            r#"
285                UPDATE user_sessions
286                SET finished_at = $1
287                WHERE user_session_id = $2
288            "#,
289            finished_at,
290            Uuid::from(user_session.id),
291        )
292        .traced()
293        .execute(&mut *self.conn)
294        .await?;
295
296        user_session.finished_at = Some(finished_at);
297
298        DatabaseError::ensure_affected_rows(&res, 1)?;
299
300        Ok(user_session)
301    }
302
303    #[tracing::instrument(
304        name = "db.browser_session.finish_bulk",
305        skip_all,
306        fields(
307            db.query.text,
308        ),
309        err,
310    )]
311    async fn finish_bulk(
312        &mut self,
313        clock: &dyn Clock,
314        filter: BrowserSessionFilter<'_>,
315    ) -> Result<usize, Self::Error> {
316        let finished_at = clock.now();
317        let (sql, arguments) = sea_query::Query::update()
318            .table(UserSessions::Table)
319            .value(UserSessions::FinishedAt, finished_at)
320            .apply_filter(filter)
321            .build_sqlx(PostgresQueryBuilder);
322
323        let res = sqlx::query_with(&sql, arguments)
324            .traced()
325            .execute(&mut *self.conn)
326            .await?;
327
328        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
329    }
330
331    #[tracing::instrument(
332        name = "db.browser_session.list",
333        skip_all,
334        fields(
335            db.query.text,
336        ),
337        err,
338    )]
339    async fn list(
340        &mut self,
341        filter: BrowserSessionFilter<'_>,
342        pagination: Pagination,
343    ) -> Result<Page<BrowserSession>, Self::Error> {
344        let (sql, arguments) = sea_query::Query::select()
345            .expr_as(
346                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
347                SessionLookupIden::UserSessionId,
348            )
349            .expr_as(
350                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
351                SessionLookupIden::UserSessionCreatedAt,
352            )
353            .expr_as(
354                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
355                SessionLookupIden::UserSessionFinishedAt,
356            )
357            .expr_as(
358                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
359                SessionLookupIden::UserSessionUserAgent,
360            )
361            .expr_as(
362                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
363                SessionLookupIden::UserSessionLastActiveAt,
364            )
365            .expr_as(
366                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
367                SessionLookupIden::UserSessionLastActiveIp,
368            )
369            .expr_as(
370                Expr::col((Users::Table, Users::UserId)),
371                SessionLookupIden::UserId,
372            )
373            .expr_as(
374                Expr::col((Users::Table, Users::Username)),
375                SessionLookupIden::UserUsername,
376            )
377            .expr_as(
378                Expr::col((Users::Table, Users::CreatedAt)),
379                SessionLookupIden::UserCreatedAt,
380            )
381            .expr_as(
382                Expr::col((Users::Table, Users::LockedAt)),
383                SessionLookupIden::UserLockedAt,
384            )
385            .expr_as(
386                Expr::col((Users::Table, Users::DeactivatedAt)),
387                SessionLookupIden::UserDeactivatedAt,
388            )
389            .expr_as(
390                Expr::col((Users::Table, Users::CanRequestAdmin)),
391                SessionLookupIden::UserCanRequestAdmin,
392            )
393            .expr_as(
394                Expr::col((Users::Table, Users::IsGuest)),
395                SessionLookupIden::UserIsGuest,
396            )
397            .from(UserSessions::Table)
398            .inner_join(
399                Users::Table,
400                Expr::col((UserSessions::Table, UserSessions::UserId))
401                    .equals((Users::Table, Users::UserId)),
402            )
403            .apply_filter(filter)
404            .generate_pagination(
405                (UserSessions::Table, UserSessions::UserSessionId),
406                pagination,
407            )
408            .build_sqlx(PostgresQueryBuilder);
409
410        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
411            .traced()
412            .fetch_all(&mut *self.conn)
413            .await?;
414
415        let page = pagination
416            .process(edges)
417            .try_map(BrowserSession::try_from)?;
418
419        Ok(page)
420    }
421
422    #[tracing::instrument(
423        name = "db.browser_session.count",
424        skip_all,
425        fields(
426            db.query.text,
427        ),
428        err,
429    )]
430    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
431        let (sql, arguments) = sea_query::Query::select()
432            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
433            .from(UserSessions::Table)
434            .apply_filter(filter)
435            .build_sqlx(PostgresQueryBuilder);
436
437        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
438            .traced()
439            .fetch_one(&mut *self.conn)
440            .await?;
441
442        count
443            .try_into()
444            .map_err(DatabaseError::to_invalid_operation)
445    }
446
447    #[tracing::instrument(
448        name = "db.browser_session.authenticate_with_password",
449        skip_all,
450        fields(
451            db.query.text,
452            %user_session.id,
453            %user_password.id,
454            user_session_authentication.id,
455        ),
456        err,
457    )]
458    async fn authenticate_with_password(
459        &mut self,
460        rng: &mut (dyn RngCore + Send),
461        clock: &dyn Clock,
462        user_session: &BrowserSession,
463        user_password: &Password,
464    ) -> Result<Authentication, Self::Error> {
465        let created_at = clock.now();
466        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
467        tracing::Span::current().record(
468            "user_session_authentication.id",
469            tracing::field::display(id),
470        );
471
472        sqlx::query!(
473            r#"
474                INSERT INTO user_session_authentications
475                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
476                VALUES ($1, $2, $3, $4)
477            "#,
478            Uuid::from(id),
479            Uuid::from(user_session.id),
480            created_at,
481            Uuid::from(user_password.id),
482        )
483        .traced()
484        .execute(&mut *self.conn)
485        .await?;
486
487        Ok(Authentication {
488            id,
489            created_at,
490            authentication_method: AuthenticationMethod::Password {
491                user_password_id: user_password.id,
492            },
493        })
494    }
495
496    #[tracing::instrument(
497        name = "db.browser_session.authenticate_with_upstream",
498        skip_all,
499        fields(
500            db.query.text,
501            %user_session.id,
502            %upstream_oauth_session.id,
503            user_session_authentication.id,
504        ),
505        err,
506    )]
507    async fn authenticate_with_upstream(
508        &mut self,
509        rng: &mut (dyn RngCore + Send),
510        clock: &dyn Clock,
511        user_session: &BrowserSession,
512        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
513    ) -> Result<Authentication, Self::Error> {
514        let created_at = clock.now();
515        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
516        tracing::Span::current().record(
517            "user_session_authentication.id",
518            tracing::field::display(id),
519        );
520
521        sqlx::query!(
522            r#"
523                INSERT INTO user_session_authentications
524                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
525                VALUES ($1, $2, $3, $4)
526            "#,
527            Uuid::from(id),
528            Uuid::from(user_session.id),
529            created_at,
530            Uuid::from(upstream_oauth_session.id),
531        )
532        .traced()
533        .execute(&mut *self.conn)
534        .await?;
535
536        Ok(Authentication {
537            id,
538            created_at,
539            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
540                upstream_oauth2_session_id: upstream_oauth_session.id,
541            },
542        })
543    }
544
545    #[tracing::instrument(
546        name = "db.browser_session.get_last_authentication",
547        skip_all,
548        fields(
549            db.query.text,
550            %user_session.id,
551        ),
552        err,
553    )]
554    async fn get_last_authentication(
555        &mut self,
556        user_session: &BrowserSession,
557    ) -> Result<Option<Authentication>, Self::Error> {
558        let authentication = sqlx::query_as!(
559            AuthenticationLookup,
560            r#"
561                SELECT user_session_authentication_id
562                     , created_at
563                     , user_password_id
564                     , upstream_oauth_authorization_session_id
565                FROM user_session_authentications
566                WHERE user_session_id = $1
567                ORDER BY created_at DESC
568                LIMIT 1
569            "#,
570            Uuid::from(user_session.id),
571        )
572        .traced()
573        .fetch_optional(&mut *self.conn)
574        .await?;
575
576        let Some(authentication) = authentication else {
577            return Ok(None);
578        };
579
580        let authentication = Authentication::try_from(authentication)?;
581        Ok(Some(authentication))
582    }
583
584    #[tracing::instrument(
585        name = "db.browser_session.record_batch_activity",
586        skip_all,
587        fields(
588            db.query.text,
589        ),
590        err,
591    )]
592    async fn record_batch_activity(
593        &mut self,
594        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
595    ) -> Result<(), Self::Error> {
596        // Sort the activity by ID, so that when batching the updates, Postgres
597        // locks the rows in a stable order, preventing deadlocks
598        activities.sort_unstable();
599        let mut ids = Vec::with_capacity(activities.len());
600        let mut last_activities = Vec::with_capacity(activities.len());
601        let mut ips = Vec::with_capacity(activities.len());
602
603        for (id, last_activity, ip) in activities {
604            ids.push(Uuid::from(id));
605            last_activities.push(last_activity);
606            ips.push(ip);
607        }
608
609        let res = sqlx::query!(
610            r#"
611                UPDATE user_sessions
612                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
613                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
614                FROM (
615                    SELECT *
616                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
617                        AS t(user_session_id, last_active_at, last_active_ip)
618                ) AS t
619                WHERE user_sessions.user_session_id = t.user_session_id
620            "#,
621            &ids,
622            &last_activities,
623            &ips as &[Option<IpAddr>],
624        )
625        .traced()
626        .execute(&mut *self.conn)
627        .await?;
628
629        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
630
631        Ok(())
632    }
633
634    #[tracing::instrument(
635        name = "db.browser_session.cleanup_finished",
636        skip_all,
637        fields(
638            db.query.text,
639            since = since.map(tracing::field::display),
640            until = %until,
641            limit = limit,
642        ),
643        err,
644    )]
645    async fn cleanup_finished(
646        &mut self,
647        since: Option<DateTime<Utc>>,
648        until: DateTime<Utc>,
649        limit: usize,
650    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
651        let res = sqlx::query!(
652            r#"
653                WITH
654                    to_delete AS (
655                        SELECT user_session_id, finished_at
656                        FROM user_sessions us
657                        WHERE us.finished_at IS NOT NULL
658                          AND ($1::timestamptz IS NULL OR us.finished_at >= $1)
659                          AND us.finished_at < $2
660                          -- Only delete if no oauth2_sessions reference this user_session
661                          AND NOT EXISTS (
662                              SELECT 1 FROM oauth2_sessions os
663                              WHERE os.user_session_id = us.user_session_id
664                          )
665                          -- Only delete if no compat_sessions reference this user_session
666                          AND NOT EXISTS (
667                              SELECT 1 FROM compat_sessions cs
668                              WHERE cs.user_session_id = us.user_session_id
669                          )
670                        ORDER BY us.finished_at ASC
671                        LIMIT $3
672                        FOR UPDATE OF us
673                    ),
674                    deleted_authentications AS (
675                        DELETE FROM user_session_authentications USING to_delete
676                        WHERE user_session_authentications.user_session_id = to_delete.user_session_id
677                    ),
678                    deleted_sessions AS (
679                        DELETE FROM user_sessions USING to_delete
680                        WHERE user_sessions.user_session_id = to_delete.user_session_id
681                        RETURNING user_sessions.finished_at
682                    )
683                SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
684            "#,
685            since,
686            until,
687            i64::try_from(limit).unwrap_or(i64::MAX),
688        )
689        .traced()
690        .fetch_one(&mut *self.conn)
691        .await?;
692
693        Ok((
694            res.count.try_into().unwrap_or(usize::MAX),
695            res.last_finished_at,
696        ))
697    }
698
699    #[tracing::instrument(
700        name = "db.browser_session.cleanup_inactive_ips",
701        skip_all,
702        fields(
703            db.query.text,
704            since = since.map(tracing::field::display),
705            threshold = %threshold,
706            limit = limit,
707        ),
708        err,
709    )]
710    async fn cleanup_inactive_ips(
711        &mut self,
712        since: Option<DateTime<Utc>>,
713        threshold: DateTime<Utc>,
714        limit: usize,
715    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
716        let res = sqlx::query!(
717            r#"
718                WITH to_update AS (
719                    SELECT user_session_id, last_active_at
720                    FROM user_sessions
721                    WHERE last_active_ip IS NOT NULL
722                      AND last_active_at IS NOT NULL
723                      AND ($1::timestamptz IS NULL OR last_active_at >= $1)
724                      AND last_active_at < $2
725                    ORDER BY last_active_at ASC
726                    LIMIT $3
727                    FOR UPDATE
728                ),
729                updated AS (
730                    UPDATE user_sessions
731                    SET last_active_ip = NULL
732                    FROM to_update
733                    WHERE user_sessions.user_session_id = to_update.user_session_id
734                    RETURNING user_sessions.last_active_at
735                )
736                SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
737            "#,
738            since,
739            threshold,
740            i64::try_from(limit).unwrap_or(i64::MAX),
741        )
742        .traced()
743        .fetch_one(&mut *self.conn)
744        .await?;
745
746        Ok((
747            res.count.try_into().unwrap_or(usize::MAX),
748            res.last_active_at,
749        ))
750    }
751}