mas_storage_pg/upstream_oauth2/session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    BrowserSession, Clock, UpstreamOAuthAuthorizationSession,
12    UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
13};
14use mas_storage::{
15    Page, Pagination,
16    pagination::Node,
17    upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError, DatabaseInconsistencyError,
28    filter::{Filter, StatementExt},
29    iden::UpstreamOAuthAuthorizationSessions,
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
34impl Filter for UpstreamOAuthSessionFilter<'_> {
35    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
36        sea_query::Condition::all()
37            .add_option(self.provider().map(|provider| {
38                Expr::col((
39                    UpstreamOAuthAuthorizationSessions::Table,
40                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
41                ))
42                .eq(Uuid::from(provider.id))
43            }))
44            .add_option(self.sub_claim().map(|sub| {
45                Expr::col((
46                    UpstreamOAuthAuthorizationSessions::Table,
47                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
48                ))
49                .cast_json_field("sub")
50                .eq(sub)
51            }))
52            .add_option(self.sid_claim().map(|sid| {
53                Expr::col((
54                    UpstreamOAuthAuthorizationSessions::Table,
55                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
56                ))
57                .cast_json_field("sid")
58                .eq(sid)
59            }))
60    }
61}
62
63/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
64/// connection
65pub struct PgUpstreamOAuthSessionRepository<'c> {
66    conn: &'c mut PgConnection,
67}
68
69impl<'c> PgUpstreamOAuthSessionRepository<'c> {
70    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
71    /// PostgreSQL connection
72    pub fn new(conn: &'c mut PgConnection) -> Self {
73        Self { conn }
74    }
75}
76
77#[derive(sqlx::FromRow)]
78#[enum_def]
79struct SessionLookup {
80    upstream_oauth_authorization_session_id: Uuid,
81    upstream_oauth_provider_id: Uuid,
82    upstream_oauth_link_id: Option<Uuid>,
83    state: String,
84    code_challenge_verifier: Option<String>,
85    nonce: Option<String>,
86    id_token: Option<String>,
87    id_token_claims: Option<serde_json::Value>,
88    userinfo: Option<serde_json::Value>,
89    created_at: DateTime<Utc>,
90    completed_at: Option<DateTime<Utc>>,
91    consumed_at: Option<DateTime<Utc>>,
92    extra_callback_parameters: Option<serde_json::Value>,
93    unlinked_at: Option<DateTime<Utc>>,
94}
95
96impl Node<Ulid> for SessionLookup {
97    fn cursor(&self) -> Ulid {
98        self.upstream_oauth_authorization_session_id.into()
99    }
100}
101
102impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
103    type Error = DatabaseInconsistencyError;
104
105    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
106        let id = value.upstream_oauth_authorization_session_id.into();
107        let state = match (
108            value.upstream_oauth_link_id,
109            value.id_token,
110            value.id_token_claims,
111            value.extra_callback_parameters,
112            value.userinfo,
113            value.completed_at,
114            value.consumed_at,
115            value.unlinked_at,
116        ) {
117            (None, None, None, None, None, None, None, None) => {
118                UpstreamOAuthAuthorizationSessionState::Pending
119            }
120            (
121                Some(link_id),
122                id_token,
123                id_token_claims,
124                extra_callback_parameters,
125                userinfo,
126                Some(completed_at),
127                None,
128                None,
129            ) => UpstreamOAuthAuthorizationSessionState::Completed {
130                completed_at,
131                link_id: link_id.into(),
132                id_token,
133                id_token_claims,
134                extra_callback_parameters,
135                userinfo,
136            },
137            (
138                Some(link_id),
139                id_token,
140                id_token_claims,
141                extra_callback_parameters,
142                userinfo,
143                Some(completed_at),
144                Some(consumed_at),
145                None,
146            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
147                completed_at,
148                link_id: link_id.into(),
149                id_token,
150                id_token_claims,
151                extra_callback_parameters,
152                userinfo,
153                consumed_at,
154            },
155            (
156                _,
157                id_token,
158                id_token_claims,
159                _,
160                _,
161                Some(completed_at),
162                consumed_at,
163                Some(unlinked_at),
164            ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
165                completed_at,
166                id_token,
167                id_token_claims,
168                consumed_at,
169                unlinked_at,
170            },
171            _ => {
172                return Err(DatabaseInconsistencyError::on(
173                    "upstream_oauth_authorization_sessions",
174                )
175                .row(id));
176            }
177        };
178
179        Ok(Self {
180            id,
181            provider_id: value.upstream_oauth_provider_id.into(),
182            state_str: value.state,
183            nonce: value.nonce,
184            code_challenge_verifier: value.code_challenge_verifier,
185            created_at: value.created_at,
186            state,
187        })
188    }
189}
190
191#[async_trait]
192impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
193    type Error = DatabaseError;
194
195    #[tracing::instrument(
196        name = "db.upstream_oauth_authorization_session.lookup",
197        skip_all,
198        fields(
199            db.query.text,
200            upstream_oauth_provider.id = %id,
201        ),
202        err,
203    )]
204    async fn lookup(
205        &mut self,
206        id: Ulid,
207    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
208        let res = sqlx::query_as!(
209            SessionLookup,
210            r#"
211                SELECT
212                    upstream_oauth_authorization_session_id,
213                    upstream_oauth_provider_id,
214                    upstream_oauth_link_id,
215                    state,
216                    code_challenge_verifier,
217                    nonce,
218                    id_token,
219                    id_token_claims,
220                    extra_callback_parameters,
221                    userinfo,
222                    created_at,
223                    completed_at,
224                    consumed_at,
225                    unlinked_at
226                FROM upstream_oauth_authorization_sessions
227                WHERE upstream_oauth_authorization_session_id = $1
228            "#,
229            Uuid::from(id),
230        )
231        .traced()
232        .fetch_optional(&mut *self.conn)
233        .await?;
234
235        let Some(res) = res else { return Ok(None) };
236
237        Ok(Some(res.try_into()?))
238    }
239
240    #[tracing::instrument(
241        name = "db.upstream_oauth_authorization_session.add",
242        skip_all,
243        fields(
244            db.query.text,
245            %upstream_oauth_provider.id,
246            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
247            %upstream_oauth_provider.client_id,
248            upstream_oauth_authorization_session.id,
249        ),
250        err,
251    )]
252    async fn add(
253        &mut self,
254        rng: &mut (dyn RngCore + Send),
255        clock: &dyn Clock,
256        upstream_oauth_provider: &UpstreamOAuthProvider,
257        state_str: String,
258        code_challenge_verifier: Option<String>,
259        nonce: Option<String>,
260    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
261        let created_at = clock.now();
262        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
263        tracing::Span::current().record(
264            "upstream_oauth_authorization_session.id",
265            tracing::field::display(id),
266        );
267
268        sqlx::query!(
269            r#"
270                INSERT INTO upstream_oauth_authorization_sessions (
271                    upstream_oauth_authorization_session_id,
272                    upstream_oauth_provider_id,
273                    state,
274                    code_challenge_verifier,
275                    nonce,
276                    created_at,
277                    completed_at,
278                    consumed_at,
279                    id_token,
280                    userinfo
281                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
282            "#,
283            Uuid::from(id),
284            Uuid::from(upstream_oauth_provider.id),
285            &state_str,
286            code_challenge_verifier.as_deref(),
287            nonce,
288            created_at,
289        )
290        .traced()
291        .execute(&mut *self.conn)
292        .await?;
293
294        Ok(UpstreamOAuthAuthorizationSession {
295            id,
296            state: UpstreamOAuthAuthorizationSessionState::default(),
297            provider_id: upstream_oauth_provider.id,
298            state_str,
299            code_challenge_verifier,
300            nonce,
301            created_at,
302        })
303    }
304
305    #[tracing::instrument(
306        name = "db.upstream_oauth_authorization_session.complete_with_link",
307        skip_all,
308        fields(
309            db.query.text,
310            %upstream_oauth_authorization_session.id,
311            %upstream_oauth_link.id,
312        ),
313        err,
314    )]
315    async fn complete_with_link(
316        &mut self,
317        clock: &dyn Clock,
318        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
319        upstream_oauth_link: &UpstreamOAuthLink,
320        id_token: Option<String>,
321        id_token_claims: Option<serde_json::Value>,
322        extra_callback_parameters: Option<serde_json::Value>,
323        userinfo: Option<serde_json::Value>,
324    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
325        let completed_at = clock.now();
326
327        sqlx::query!(
328            r#"
329                UPDATE upstream_oauth_authorization_sessions
330                SET upstream_oauth_link_id = $1
331                  , completed_at = $2
332                  , id_token = $3
333                  , id_token_claims = $4
334                  , extra_callback_parameters = $5
335                  , userinfo = $6
336                WHERE upstream_oauth_authorization_session_id = $7
337            "#,
338            Uuid::from(upstream_oauth_link.id),
339            completed_at,
340            id_token,
341            id_token_claims,
342            extra_callback_parameters,
343            userinfo,
344            Uuid::from(upstream_oauth_authorization_session.id),
345        )
346        .traced()
347        .execute(&mut *self.conn)
348        .await?;
349
350        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
351            .complete(
352                completed_at,
353                upstream_oauth_link,
354                id_token,
355                id_token_claims,
356                extra_callback_parameters,
357                userinfo,
358            )
359            .map_err(DatabaseError::to_invalid_operation)?;
360
361        Ok(upstream_oauth_authorization_session)
362    }
363
364    /// Mark a session as consumed
365    #[tracing::instrument(
366        name = "db.upstream_oauth_authorization_session.consume",
367        skip_all,
368        fields(
369            db.query.text,
370            %upstream_oauth_authorization_session.id,
371        ),
372        err,
373    )]
374    async fn consume(
375        &mut self,
376        clock: &dyn Clock,
377        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
378        browser_session: &BrowserSession,
379    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
380        let consumed_at = clock.now();
381        sqlx::query!(
382            r#"
383                UPDATE upstream_oauth_authorization_sessions
384                SET consumed_at = $1,
385                    user_session_id = $2
386                WHERE upstream_oauth_authorization_session_id = $3
387            "#,
388            consumed_at,
389            Uuid::from(browser_session.id),
390            Uuid::from(upstream_oauth_authorization_session.id),
391        )
392        .traced()
393        .execute(&mut *self.conn)
394        .await?;
395
396        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
397            .consume(consumed_at)
398            .map_err(DatabaseError::to_invalid_operation)?;
399
400        Ok(upstream_oauth_authorization_session)
401    }
402
403    #[tracing::instrument(
404        name = "db.upstream_oauth_authorization_session.list",
405        skip_all,
406        fields(
407            db.query.text,
408        ),
409        err,
410    )]
411    async fn list(
412        &mut self,
413        filter: UpstreamOAuthSessionFilter<'_>,
414        pagination: Pagination,
415    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
416        let (sql, arguments) = Query::select()
417            .expr_as(
418                Expr::col((
419                    UpstreamOAuthAuthorizationSessions::Table,
420                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
421                )),
422                SessionLookupIden::UpstreamOauthAuthorizationSessionId,
423            )
424            .expr_as(
425                Expr::col((
426                    UpstreamOAuthAuthorizationSessions::Table,
427                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
428                )),
429                SessionLookupIden::UpstreamOauthProviderId,
430            )
431            .expr_as(
432                Expr::col((
433                    UpstreamOAuthAuthorizationSessions::Table,
434                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
435                )),
436                SessionLookupIden::UpstreamOauthLinkId,
437            )
438            .expr_as(
439                Expr::col((
440                    UpstreamOAuthAuthorizationSessions::Table,
441                    UpstreamOAuthAuthorizationSessions::State,
442                )),
443                SessionLookupIden::State,
444            )
445            .expr_as(
446                Expr::col((
447                    UpstreamOAuthAuthorizationSessions::Table,
448                    UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
449                )),
450                SessionLookupIden::CodeChallengeVerifier,
451            )
452            .expr_as(
453                Expr::col((
454                    UpstreamOAuthAuthorizationSessions::Table,
455                    UpstreamOAuthAuthorizationSessions::Nonce,
456                )),
457                SessionLookupIden::Nonce,
458            )
459            .expr_as(
460                Expr::col((
461                    UpstreamOAuthAuthorizationSessions::Table,
462                    UpstreamOAuthAuthorizationSessions::IdToken,
463                )),
464                SessionLookupIden::IdToken,
465            )
466            .expr_as(
467                Expr::col((
468                    UpstreamOAuthAuthorizationSessions::Table,
469                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
470                )),
471                SessionLookupIden::IdTokenClaims,
472            )
473            .expr_as(
474                Expr::col((
475                    UpstreamOAuthAuthorizationSessions::Table,
476                    UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
477                )),
478                SessionLookupIden::ExtraCallbackParameters,
479            )
480            .expr_as(
481                Expr::col((
482                    UpstreamOAuthAuthorizationSessions::Table,
483                    UpstreamOAuthAuthorizationSessions::Userinfo,
484                )),
485                SessionLookupIden::Userinfo,
486            )
487            .expr_as(
488                Expr::col((
489                    UpstreamOAuthAuthorizationSessions::Table,
490                    UpstreamOAuthAuthorizationSessions::CreatedAt,
491                )),
492                SessionLookupIden::CreatedAt,
493            )
494            .expr_as(
495                Expr::col((
496                    UpstreamOAuthAuthorizationSessions::Table,
497                    UpstreamOAuthAuthorizationSessions::CompletedAt,
498                )),
499                SessionLookupIden::CompletedAt,
500            )
501            .expr_as(
502                Expr::col((
503                    UpstreamOAuthAuthorizationSessions::Table,
504                    UpstreamOAuthAuthorizationSessions::ConsumedAt,
505                )),
506                SessionLookupIden::ConsumedAt,
507            )
508            .expr_as(
509                Expr::col((
510                    UpstreamOAuthAuthorizationSessions::Table,
511                    UpstreamOAuthAuthorizationSessions::UnlinkedAt,
512                )),
513                SessionLookupIden::UnlinkedAt,
514            )
515            .from(UpstreamOAuthAuthorizationSessions::Table)
516            .apply_filter(filter)
517            .generate_pagination(
518                (
519                    UpstreamOAuthAuthorizationSessions::Table,
520                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
521                ),
522                pagination,
523            )
524            .build_sqlx(PostgresQueryBuilder);
525
526        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
527            .traced()
528            .fetch_all(&mut *self.conn)
529            .await?;
530
531        let page = pagination
532            .process(edges)
533            .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
534
535        Ok(page)
536    }
537
538    #[tracing::instrument(
539        name = "db.upstream_oauth_authorization_session.count",
540        skip_all,
541        fields(
542            db.query.text,
543        ),
544        err,
545    )]
546    async fn count(
547        &mut self,
548        filter: UpstreamOAuthSessionFilter<'_>,
549    ) -> Result<usize, Self::Error> {
550        let (sql, arguments) = Query::select()
551            .expr(
552                Expr::col((
553                    UpstreamOAuthAuthorizationSessions::Table,
554                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
555                ))
556                .count(),
557            )
558            .from(UpstreamOAuthAuthorizationSessions::Table)
559            .apply_filter(filter)
560            .build_sqlx(PostgresQueryBuilder);
561
562        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
563            .traced()
564            .fetch_one(&mut *self.conn)
565            .await?;
566
567        count
568            .try_into()
569            .map_err(DatabaseError::to_invalid_operation)
570    }
571
572    #[tracing::instrument(
573        name = "db.upstream_oauth_authorization_session.cleanup",
574        skip_all,
575        fields(
576            db.query.text,
577            since = since.map(tracing::field::display),
578            until = %until,
579            limit = limit,
580        ),
581        err,
582    )]
583    async fn cleanup_orphaned(
584        &mut self,
585        since: Option<Ulid>,
586        until: Ulid,
587        limit: usize,
588    ) -> Result<(usize, Option<Ulid>), Self::Error> {
589        // Use ULID cursor-based pagination for pending sessions only.
590        // We only delete sessions that are not yet completed.
591        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
592        let res = sqlx::query_scalar!(
593            r#"
594                WITH to_delete AS (
595                    SELECT upstream_oauth_authorization_session_id
596                    FROM upstream_oauth_authorization_sessions
597                    WHERE ($1::uuid IS NULL OR upstream_oauth_authorization_session_id > $1)
598                      AND upstream_oauth_authorization_session_id <= $2
599                      AND user_session_id IS NULL
600                    ORDER BY upstream_oauth_authorization_session_id
601                    LIMIT $3
602                )
603                DELETE FROM upstream_oauth_authorization_sessions
604                USING to_delete
605                WHERE upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id = to_delete.upstream_oauth_authorization_session_id
606                RETURNING upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id
607            "#,
608            since.map(Uuid::from),
609            Uuid::from(until),
610            i64::try_from(limit).unwrap_or(i64::MAX)
611        )
612        .traced()
613        .fetch_all(&mut *self.conn)
614        .await?;
615
616        let count = res.len();
617        let max_id = res.into_iter().max();
618
619        Ok((count, max_id.map(Ulid::from)))
620    }
621}