mas_storage_pg/oauth2/
session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Session, SessionState, User};
13use mas_storage::{
14    Page, Pagination,
15    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
16    pagination::Node,
17};
18use oauth2_types::scope::{Scope, ScopeToken};
19use rand::RngCore;
20use sea_query::{
21    Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
22    extension::postgres::PgExpr,
23};
24use sea_query_binder::SqlxBinder;
25use sqlx::PgConnection;
26use ulid::Ulid;
27use uuid::Uuid;
28
29use crate::{
30    DatabaseError, DatabaseInconsistencyError,
31    filter::{Filter, StatementExt},
32    iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
33    pagination::QueryBuilderExt,
34    tracing::ExecuteExt,
35};
36
37/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
38pub struct PgOAuth2SessionRepository<'c> {
39    conn: &'c mut PgConnection,
40}
41
42impl<'c> PgOAuth2SessionRepository<'c> {
43    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
44    /// connection
45    pub fn new(conn: &'c mut PgConnection) -> Self {
46        Self { conn }
47    }
48}
49
50#[derive(sqlx::FromRow)]
51#[enum_def]
52struct OAuthSessionLookup {
53    oauth2_session_id: Uuid,
54    user_id: Option<Uuid>,
55    user_session_id: Option<Uuid>,
56    oauth2_client_id: Uuid,
57    scope_list: Vec<String>,
58    created_at: DateTime<Utc>,
59    finished_at: Option<DateTime<Utc>>,
60    user_agent: Option<String>,
61    last_active_at: Option<DateTime<Utc>>,
62    last_active_ip: Option<IpAddr>,
63    human_name: Option<String>,
64}
65
66impl Node<Ulid> for OAuthSessionLookup {
67    fn cursor(&self) -> Ulid {
68        self.oauth2_session_id.into()
69    }
70}
71
72impl TryFrom<OAuthSessionLookup> for Session {
73    type Error = DatabaseInconsistencyError;
74
75    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
76        let id = Ulid::from(value.oauth2_session_id);
77        let scope: Result<Scope, _> = value
78            .scope_list
79            .iter()
80            .map(|s| s.parse::<ScopeToken>())
81            .collect();
82        let scope = scope.map_err(|e| {
83            DatabaseInconsistencyError::on("oauth2_sessions")
84                .column("scope")
85                .row(id)
86                .source(e)
87        })?;
88
89        let state = match value.finished_at {
90            None => SessionState::Valid,
91            Some(finished_at) => SessionState::Finished { finished_at },
92        };
93
94        Ok(Session {
95            id,
96            state,
97            created_at: value.created_at,
98            client_id: value.oauth2_client_id.into(),
99            user_id: value.user_id.map(Ulid::from),
100            user_session_id: value.user_session_id.map(Ulid::from),
101            scope,
102            user_agent: value.user_agent,
103            last_active_at: value.last_active_at,
104            last_active_ip: value.last_active_ip,
105            human_name: value.human_name,
106        })
107    }
108}
109
110impl Filter for OAuth2SessionFilter<'_> {
111    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
112        sea_query::Condition::all()
113            .add_option(self.user().map(|user| {
114                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
115            }))
116            .add_option(self.client().map(|client| {
117                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
118                    .eq(Uuid::from(client.id))
119            }))
120            .add_option(self.client_kind().map(|client_kind| {
121                // This builds either a:
122                // `WHERE oauth2_client_id = ANY(...)`
123                // or a `WHERE oauth2_client_id <> ALL(...)`
124                let static_clients = Query::select()
125                    .expr(Expr::col((
126                        OAuth2Clients::Table,
127                        OAuth2Clients::OAuth2ClientId,
128                    )))
129                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
130                    .from(OAuth2Clients::Table)
131                    .take();
132                if client_kind.is_static() {
133                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
134                        .eq(Expr::any(static_clients))
135                } else {
136                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
137                        .ne(Expr::all(static_clients))
138                }
139            }))
140            .add_option(self.device().map(|device| -> SimpleExpr {
141                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
142                    Condition::any()
143                        .add(
144                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
145                                OAuth2Sessions::Table,
146                                OAuth2Sessions::ScopeList,
147                            )))),
148                        )
149                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
150                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
151                        )))
152                        .into()
153                } else {
154                    // If the device ID can't be encoded as a scope token, match no rows
155                    Expr::val(false).into()
156                }
157            }))
158            .add_option(self.browser_session().map(|browser_session| {
159                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
160                    .eq(Uuid::from(browser_session.id))
161            }))
162            .add_option(self.browser_session_filter().map(|browser_session_filter| {
163                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
164                    Query::select()
165                        .expr(Expr::col((
166                            UserSessions::Table,
167                            UserSessions::UserSessionId,
168                        )))
169                        .apply_filter(browser_session_filter)
170                        .from(UserSessions::Table)
171                        .take(),
172                )
173            }))
174            .add_option(self.state().map(|state| {
175                if state.is_active() {
176                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
177                } else {
178                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
179                }
180            }))
181            .add_option(self.scope().map(|scope| {
182                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
183                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
184            }))
185            .add_option(self.any_user().map(|any_user| {
186                if any_user {
187                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
188                } else {
189                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
190                }
191            }))
192            .add_option(self.last_active_after().map(|last_active_after| {
193                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
194                    .gt(last_active_after)
195            }))
196            .add_option(self.last_active_before().map(|last_active_before| {
197                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
198                    .lt(last_active_before)
199            }))
200    }
201}
202
203#[async_trait]
204impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
205    type Error = DatabaseError;
206
207    #[tracing::instrument(
208        name = "db.oauth2_session.lookup",
209        skip_all,
210        fields(
211            db.query.text,
212            session.id = %id,
213        ),
214        err,
215    )]
216    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
217        let res = sqlx::query_as!(
218            OAuthSessionLookup,
219            r#"
220                SELECT oauth2_session_id
221                     , user_id
222                     , user_session_id
223                     , oauth2_client_id
224                     , scope_list
225                     , created_at
226                     , finished_at
227                     , user_agent
228                     , last_active_at
229                     , last_active_ip as "last_active_ip: IpAddr"
230                     , human_name
231                FROM oauth2_sessions
232
233                WHERE oauth2_session_id = $1
234            "#,
235            Uuid::from(id),
236        )
237        .traced()
238        .fetch_optional(&mut *self.conn)
239        .await?;
240
241        let Some(session) = res else { return Ok(None) };
242
243        Ok(Some(session.try_into()?))
244    }
245
246    #[tracing::instrument(
247        name = "db.oauth2_session.add",
248        skip_all,
249        fields(
250            db.query.text,
251            %client.id,
252            session.id,
253            session.scope = %scope,
254        ),
255        err,
256    )]
257    async fn add(
258        &mut self,
259        rng: &mut (dyn RngCore + Send),
260        clock: &dyn Clock,
261        client: &Client,
262        user: Option<&User>,
263        user_session: Option<&BrowserSession>,
264        scope: Scope,
265    ) -> Result<Session, Self::Error> {
266        let created_at = clock.now();
267        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
268        tracing::Span::current().record("session.id", tracing::field::display(id));
269
270        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
271
272        sqlx::query!(
273            r#"
274                INSERT INTO oauth2_sessions
275                    ( oauth2_session_id
276                    , user_id
277                    , user_session_id
278                    , oauth2_client_id
279                    , scope_list
280                    , created_at
281                    )
282                VALUES ($1, $2, $3, $4, $5, $6)
283            "#,
284            Uuid::from(id),
285            user.map(|u| Uuid::from(u.id)),
286            user_session.map(|s| Uuid::from(s.id)),
287            Uuid::from(client.id),
288            &scope_list,
289            created_at,
290        )
291        .traced()
292        .execute(&mut *self.conn)
293        .await?;
294
295        Ok(Session {
296            id,
297            state: SessionState::Valid,
298            created_at,
299            user_id: user.map(|u| u.id),
300            user_session_id: user_session.map(|s| s.id),
301            client_id: client.id,
302            scope,
303            user_agent: None,
304            last_active_at: None,
305            last_active_ip: None,
306            human_name: None,
307        })
308    }
309
310    #[tracing::instrument(
311        name = "db.oauth2_session.finish_bulk",
312        skip_all,
313        fields(
314            db.query.text,
315        ),
316        err,
317    )]
318    async fn finish_bulk(
319        &mut self,
320        clock: &dyn Clock,
321        filter: OAuth2SessionFilter<'_>,
322    ) -> Result<usize, Self::Error> {
323        let finished_at = clock.now();
324        let (sql, arguments) = Query::update()
325            .table(OAuth2Sessions::Table)
326            .value(OAuth2Sessions::FinishedAt, finished_at)
327            .apply_filter(filter)
328            .build_sqlx(PostgresQueryBuilder);
329
330        let res = sqlx::query_with(&sql, arguments)
331            .traced()
332            .execute(&mut *self.conn)
333            .await?;
334
335        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
336    }
337
338    #[tracing::instrument(
339        name = "db.oauth2_session.finish",
340        skip_all,
341        fields(
342            db.query.text,
343            %session.id,
344            %session.scope,
345            client.id = %session.client_id,
346        ),
347        err,
348    )]
349    async fn finish(
350        &mut self,
351        clock: &dyn Clock,
352        session: Session,
353    ) -> Result<Session, Self::Error> {
354        let finished_at = clock.now();
355        let res = sqlx::query!(
356            r#"
357                UPDATE oauth2_sessions
358                SET finished_at = $2
359                WHERE oauth2_session_id = $1
360            "#,
361            Uuid::from(session.id),
362            finished_at,
363        )
364        .traced()
365        .execute(&mut *self.conn)
366        .await?;
367
368        DatabaseError::ensure_affected_rows(&res, 1)?;
369
370        session
371            .finish(finished_at)
372            .map_err(DatabaseError::to_invalid_operation)
373    }
374
375    #[tracing::instrument(
376        name = "db.oauth2_session.list",
377        skip_all,
378        fields(
379            db.query.text,
380        ),
381        err,
382    )]
383    async fn list(
384        &mut self,
385        filter: OAuth2SessionFilter<'_>,
386        pagination: Pagination,
387    ) -> Result<Page<Session>, Self::Error> {
388        let (sql, arguments) = Query::select()
389            .expr_as(
390                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
391                OAuthSessionLookupIden::Oauth2SessionId,
392            )
393            .expr_as(
394                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
395                OAuthSessionLookupIden::UserId,
396            )
397            .expr_as(
398                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
399                OAuthSessionLookupIden::UserSessionId,
400            )
401            .expr_as(
402                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
403                OAuthSessionLookupIden::Oauth2ClientId,
404            )
405            .expr_as(
406                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
407                OAuthSessionLookupIden::ScopeList,
408            )
409            .expr_as(
410                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
411                OAuthSessionLookupIden::CreatedAt,
412            )
413            .expr_as(
414                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
415                OAuthSessionLookupIden::FinishedAt,
416            )
417            .expr_as(
418                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
419                OAuthSessionLookupIden::UserAgent,
420            )
421            .expr_as(
422                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
423                OAuthSessionLookupIden::LastActiveAt,
424            )
425            .expr_as(
426                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
427                OAuthSessionLookupIden::LastActiveIp,
428            )
429            .expr_as(
430                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
431                OAuthSessionLookupIden::HumanName,
432            )
433            .from(OAuth2Sessions::Table)
434            .apply_filter(filter)
435            .generate_pagination(
436                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
437                pagination,
438            )
439            .build_sqlx(PostgresQueryBuilder);
440
441        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
442            .traced()
443            .fetch_all(&mut *self.conn)
444            .await?;
445
446        let page = pagination.process(edges).try_map(Session::try_from)?;
447
448        Ok(page)
449    }
450
451    #[tracing::instrument(
452        name = "db.oauth2_session.count",
453        skip_all,
454        fields(
455            db.query.text,
456        ),
457        err,
458    )]
459    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
460        let (sql, arguments) = Query::select()
461            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
462            .from(OAuth2Sessions::Table)
463            .apply_filter(filter)
464            .build_sqlx(PostgresQueryBuilder);
465
466        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
467            .traced()
468            .fetch_one(&mut *self.conn)
469            .await?;
470
471        count
472            .try_into()
473            .map_err(DatabaseError::to_invalid_operation)
474    }
475
476    #[tracing::instrument(
477        name = "db.oauth2_session.record_batch_activity",
478        skip_all,
479        fields(
480            db.query.text,
481        ),
482        err,
483    )]
484    async fn record_batch_activity(
485        &mut self,
486        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
487    ) -> Result<(), Self::Error> {
488        // Sort the activity by ID, so that when batching the updates, Postgres
489        // locks the rows in a stable order, preventing deadlocks
490        activities.sort_unstable();
491        let mut ids = Vec::with_capacity(activities.len());
492        let mut last_activities = Vec::with_capacity(activities.len());
493        let mut ips = Vec::with_capacity(activities.len());
494
495        for (id, last_activity, ip) in activities {
496            ids.push(Uuid::from(id));
497            last_activities.push(last_activity);
498            ips.push(ip);
499        }
500
501        let res = sqlx::query!(
502            r#"
503                UPDATE oauth2_sessions
504                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
505                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
506                FROM (
507                    SELECT *
508                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
509                        AS t(oauth2_session_id, last_active_at, last_active_ip)
510                ) AS t
511                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
512            "#,
513            &ids,
514            &last_activities,
515            &ips as &[Option<IpAddr>],
516        )
517        .traced()
518        .execute(&mut *self.conn)
519        .await?;
520
521        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
522
523        Ok(())
524    }
525
526    #[tracing::instrument(
527        name = "db.oauth2_session.record_user_agent",
528        skip_all,
529        fields(
530            db.query.text,
531            %session.id,
532            %session.scope,
533            client.id = %session.client_id,
534            session.user_agent = user_agent,
535        ),
536        err,
537    )]
538    async fn record_user_agent(
539        &mut self,
540        mut session: Session,
541        user_agent: String,
542    ) -> Result<Session, Self::Error> {
543        let res = sqlx::query!(
544            r#"
545                UPDATE oauth2_sessions
546                SET user_agent = $2
547                WHERE oauth2_session_id = $1
548            "#,
549            Uuid::from(session.id),
550            &*user_agent,
551        )
552        .traced()
553        .execute(&mut *self.conn)
554        .await?;
555
556        session.user_agent = Some(user_agent);
557
558        DatabaseError::ensure_affected_rows(&res, 1)?;
559
560        Ok(session)
561    }
562
563    #[tracing::instrument(
564        name = "repository.oauth2_session.set_human_name",
565        skip(self),
566        fields(
567            client.id = %session.client_id,
568            session.human_name = ?human_name,
569        ),
570        err,
571    )]
572    async fn set_human_name(
573        &mut self,
574        mut session: Session,
575        human_name: Option<String>,
576    ) -> Result<Session, Self::Error> {
577        let res = sqlx::query!(
578            r#"
579                UPDATE oauth2_sessions
580                SET human_name = $2
581                WHERE oauth2_session_id = $1
582            "#,
583            Uuid::from(session.id),
584            human_name.as_deref(),
585        )
586        .traced()
587        .execute(&mut *self.conn)
588        .await?;
589
590        session.human_name = human_name;
591
592        DatabaseError::ensure_affected_rows(&res, 1)?;
593
594        Ok(session)
595    }
596
597    #[tracing::instrument(
598        name = "db.oauth2_session.cleanup_finished",
599        skip_all,
600        fields(
601            db.query.text,
602            since = since.map(tracing::field::display),
603            until = %until,
604            limit = limit,
605        ),
606        err,
607    )]
608    async fn cleanup_finished(
609        &mut self,
610        since: Option<DateTime<Utc>>,
611        until: DateTime<Utc>,
612        limit: usize,
613    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
614        let res = sqlx::query!(
615            r#"
616                WITH
617                    to_delete AS (
618                        SELECT oauth2_session_id, finished_at
619                        FROM oauth2_sessions
620                        WHERE finished_at IS NOT NULL
621                          AND ($1::timestamptz IS NULL OR finished_at >= $1)
622                          AND finished_at < $2
623                        ORDER BY finished_at ASC
624                        LIMIT $3
625                        FOR UPDATE
626                    ),
627                    deleted_refresh_tokens AS (
628                        DELETE FROM oauth2_refresh_tokens USING to_delete
629                        WHERE oauth2_refresh_tokens.oauth2_session_id = to_delete.oauth2_session_id
630                    ),
631                    deleted_access_tokens AS (
632                        DELETE FROM oauth2_access_tokens USING to_delete
633                        WHERE oauth2_access_tokens.oauth2_session_id = to_delete.oauth2_session_id
634                    ),
635                    deleted_sessions AS (
636                        DELETE FROM oauth2_sessions USING to_delete
637                        WHERE oauth2_sessions.oauth2_session_id = to_delete.oauth2_session_id
638                        RETURNING oauth2_sessions.finished_at
639                    )
640                SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
641            "#,
642            since,
643            until,
644            i64::try_from(limit).unwrap_or(i64::MAX),
645        )
646        .traced()
647        .fetch_one(&mut *self.conn)
648        .await?;
649
650        Ok((
651            res.count.try_into().unwrap_or(usize::MAX),
652            res.last_finished_at,
653        ))
654    }
655
656    #[tracing::instrument(
657        name = "db.oauth2_session.cleanup_inactive_ips",
658        skip_all,
659        fields(
660            db.query.text,
661            since = since.map(tracing::field::display),
662            threshold = %threshold,
663            limit = limit,
664        ),
665        err,
666    )]
667    async fn cleanup_inactive_ips(
668        &mut self,
669        since: Option<DateTime<Utc>>,
670        threshold: DateTime<Utc>,
671        limit: usize,
672    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
673        let res = sqlx::query!(
674            r#"
675                WITH to_update AS (
676                    SELECT oauth2_session_id, last_active_at
677                    FROM oauth2_sessions
678                    WHERE last_active_ip IS NOT NULL
679                      AND last_active_at IS NOT NULL
680                      AND ($1::timestamptz IS NULL OR last_active_at >= $1)
681                      AND last_active_at < $2
682                    ORDER BY last_active_at ASC
683                    LIMIT $3
684                    FOR UPDATE
685                ),
686                updated AS (
687                    UPDATE oauth2_sessions
688                    SET last_active_ip = NULL
689                    FROM to_update
690                    WHERE oauth2_sessions.oauth2_session_id = to_update.oauth2_session_id
691                    RETURNING oauth2_sessions.last_active_at
692                )
693                SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
694            "#,
695            since,
696            threshold,
697            i64::try_from(limit).unwrap_or(i64::MAX),
698        )
699        .traced()
700        .fetch_one(&mut *self.conn)
701        .await?;
702
703        Ok((
704            res.count.try_into().unwrap_or(usize::MAX),
705            res.last_active_at,
706        ))
707    }
708}