mas_storage_pg/upstream_oauth2/
link.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{Clock, UpstreamOAuthLink, UpstreamOAuthProvider, User};
11use mas_storage::{
12    Page, Pagination,
13    pagination::Node,
14    upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
15};
16use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use tracing::Instrument;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError,
27    filter::{Filter, StatementExt},
28    iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
34/// connection
35pub struct PgUpstreamOAuthLinkRepository<'c> {
36    conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthLinkRepository<'c> {
40    /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL
41    /// connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// Raw row shape shared by every query in this file that reads from
/// `upstream_oauth_links`.
///
/// `#[enum_def]` generates a `LinkLookupIden` enum from these field names; the
/// dynamically built `list` query uses it to alias the selected columns so that
/// `sqlx::FromRow` can map them back onto this struct by name. Renaming a field
/// therefore changes both the generated identifiers and the expected column
/// aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
    /// Primary key of the link; converted to/from a `Ulid` at the edges.
    upstream_oauth_link_id: Uuid,
    /// ID of the upstream provider this link belongs to.
    upstream_oauth_provider_id: Uuid,
    /// Local user the link is associated with. `add` inserts `NULL` here and
    /// `associate_to_user` fills it in.
    user_id: Option<Uuid>,
    /// Upstream subject; `find_by_subject` looks links up by this together
    /// with the provider ID.
    subject: String,
    /// Optional human-readable account name, stored as passed to `add`.
    human_account_name: Option<String>,
    /// Creation timestamp, taken from the clock in `add`.
    created_at: DateTime<Utc>,
}
57
58impl Node<Ulid> for LinkLookup {
59    fn cursor(&self) -> Ulid {
60        self.upstream_oauth_link_id.into()
61    }
62}
63
64impl From<LinkLookup> for UpstreamOAuthLink {
65    fn from(value: LinkLookup) -> Self {
66        UpstreamOAuthLink {
67            id: Ulid::from(value.upstream_oauth_link_id),
68            provider_id: Ulid::from(value.upstream_oauth_provider_id),
69            user_id: value.user_id.map(Ulid::from),
70            subject: value.subject,
71            human_account_name: value.human_account_name,
72            created_at: value.created_at,
73        }
74    }
75}
76
77impl Filter for UpstreamOAuthLinkFilter<'_> {
78    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
79        sea_query::Condition::all()
80            .add_option(self.user().map(|user| {
81                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
82                    .eq(Uuid::from(user.id))
83            }))
84            .add_option(self.provider().map(|provider| {
85                Expr::col((
86                    UpstreamOAuthLinks::Table,
87                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
88                ))
89                .eq(Uuid::from(provider.id))
90            }))
91            .add_option(self.provider_enabled().map(|enabled| {
92                Expr::col((
93                    UpstreamOAuthLinks::Table,
94                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
95                ))
96                .eq(Expr::any(
97                    Query::select()
98                        .expr(Expr::col((
99                            UpstreamOAuthProviders::Table,
100                            UpstreamOAuthProviders::UpstreamOAuthProviderId,
101                        )))
102                        .from(UpstreamOAuthProviders::Table)
103                        .and_where(
104                            Expr::col((
105                                UpstreamOAuthProviders::Table,
106                                UpstreamOAuthProviders::DisabledAt,
107                            ))
108                            .is_null()
109                            .eq(enabled),
110                        )
111                        .take(),
112                ))
113            }))
114            .add_option(self.subject().map(|subject| {
115                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
116            }))
117    }
118}
119
120#[async_trait]
121impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
122    type Error = DatabaseError;
123
124    #[tracing::instrument(
125        name = "db.upstream_oauth_link.lookup",
126        skip_all,
127        fields(
128            db.query.text,
129            upstream_oauth_link.id = %id,
130        ),
131        err,
132    )]
133    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
134        let res = sqlx::query_as!(
135            LinkLookup,
136            r#"
137                SELECT
138                    upstream_oauth_link_id,
139                    upstream_oauth_provider_id,
140                    user_id,
141                    subject,
142                    human_account_name,
143                    created_at
144                FROM upstream_oauth_links
145                WHERE upstream_oauth_link_id = $1
146            "#,
147            Uuid::from(id),
148        )
149        .traced()
150        .fetch_optional(&mut *self.conn)
151        .await?
152        .map(Into::into);
153
154        Ok(res)
155    }
156
157    #[tracing::instrument(
158        name = "db.upstream_oauth_link.find_by_subject",
159        skip_all,
160        fields(
161            db.query.text,
162            upstream_oauth_link.subject = subject,
163            %upstream_oauth_provider.id,
164            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
165            %upstream_oauth_provider.client_id,
166        ),
167        err,
168    )]
169    async fn find_by_subject(
170        &mut self,
171        upstream_oauth_provider: &UpstreamOAuthProvider,
172        subject: &str,
173    ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
174        let res = sqlx::query_as!(
175            LinkLookup,
176            r#"
177                SELECT
178                    upstream_oauth_link_id,
179                    upstream_oauth_provider_id,
180                    user_id,
181                    subject,
182                    human_account_name,
183                    created_at
184                FROM upstream_oauth_links
185                WHERE upstream_oauth_provider_id = $1
186                  AND subject = $2
187            "#,
188            Uuid::from(upstream_oauth_provider.id),
189            subject,
190        )
191        .traced()
192        .fetch_optional(&mut *self.conn)
193        .await?
194        .map(Into::into);
195
196        Ok(res)
197    }
198
199    #[tracing::instrument(
200        name = "db.upstream_oauth_link.add",
201        skip_all,
202        fields(
203            db.query.text,
204            upstream_oauth_link.id,
205            upstream_oauth_link.subject = subject,
206            upstream_oauth_link.human_account_name = human_account_name,
207            %upstream_oauth_provider.id,
208            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
209            %upstream_oauth_provider.client_id,
210        ),
211        err,
212    )]
213    async fn add(
214        &mut self,
215        rng: &mut (dyn RngCore + Send),
216        clock: &dyn Clock,
217        upstream_oauth_provider: &UpstreamOAuthProvider,
218        subject: String,
219        human_account_name: Option<String>,
220    ) -> Result<UpstreamOAuthLink, Self::Error> {
221        let created_at = clock.now();
222        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
223        tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
224
225        sqlx::query!(
226            r#"
227                INSERT INTO upstream_oauth_links (
228                    upstream_oauth_link_id,
229                    upstream_oauth_provider_id,
230                    user_id,
231                    subject,
232                    human_account_name,
233                    created_at
234                ) VALUES ($1, $2, NULL, $3, $4, $5)
235            "#,
236            Uuid::from(id),
237            Uuid::from(upstream_oauth_provider.id),
238            &subject,
239            human_account_name.as_deref(),
240            created_at,
241        )
242        .traced()
243        .execute(&mut *self.conn)
244        .await?;
245
246        Ok(UpstreamOAuthLink {
247            id,
248            provider_id: upstream_oauth_provider.id,
249            user_id: None,
250            subject,
251            human_account_name,
252            created_at,
253        })
254    }
255
256    #[tracing::instrument(
257        name = "db.upstream_oauth_link.associate_to_user",
258        skip_all,
259        fields(
260            db.query.text,
261            %upstream_oauth_link.id,
262            %upstream_oauth_link.subject,
263            %user.id,
264            %user.username,
265        ),
266        err,
267    )]
268    async fn associate_to_user(
269        &mut self,
270        upstream_oauth_link: &UpstreamOAuthLink,
271        user: &User,
272    ) -> Result<(), Self::Error> {
273        sqlx::query!(
274            r#"
275                UPDATE upstream_oauth_links
276                SET user_id = $1
277                WHERE upstream_oauth_link_id = $2
278            "#,
279            Uuid::from(user.id),
280            Uuid::from(upstream_oauth_link.id),
281        )
282        .traced()
283        .execute(&mut *self.conn)
284        .await?;
285
286        Ok(())
287    }
288
289    #[tracing::instrument(
290        name = "db.upstream_oauth_link.list",
291        skip_all,
292        fields(
293            db.query.text,
294        ),
295        err,
296    )]
297    async fn list(
298        &mut self,
299        filter: UpstreamOAuthLinkFilter<'_>,
300        pagination: Pagination,
301    ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
302        let (sql, arguments) = Query::select()
303            .expr_as(
304                Expr::col((
305                    UpstreamOAuthLinks::Table,
306                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
307                )),
308                LinkLookupIden::UpstreamOauthLinkId,
309            )
310            .expr_as(
311                Expr::col((
312                    UpstreamOAuthLinks::Table,
313                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
314                )),
315                LinkLookupIden::UpstreamOauthProviderId,
316            )
317            .expr_as(
318                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
319                LinkLookupIden::UserId,
320            )
321            .expr_as(
322                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
323                LinkLookupIden::Subject,
324            )
325            .expr_as(
326                Expr::col((
327                    UpstreamOAuthLinks::Table,
328                    UpstreamOAuthLinks::HumanAccountName,
329                )),
330                LinkLookupIden::HumanAccountName,
331            )
332            .expr_as(
333                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
334                LinkLookupIden::CreatedAt,
335            )
336            .from(UpstreamOAuthLinks::Table)
337            .apply_filter(filter)
338            .generate_pagination(
339                (
340                    UpstreamOAuthLinks::Table,
341                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
342                ),
343                pagination,
344            )
345            .build_sqlx(PostgresQueryBuilder);
346
347        let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
348            .traced()
349            .fetch_all(&mut *self.conn)
350            .await?;
351
352        let page = pagination.process(edges).map(UpstreamOAuthLink::from);
353
354        Ok(page)
355    }
356
357    #[tracing::instrument(
358        name = "db.upstream_oauth_link.count",
359        skip_all,
360        fields(
361            db.query.text,
362        ),
363        err,
364    )]
365    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
366        let (sql, arguments) = Query::select()
367            .expr(
368                Expr::col((
369                    UpstreamOAuthLinks::Table,
370                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
371                ))
372                .count(),
373            )
374            .from(UpstreamOAuthLinks::Table)
375            .apply_filter(filter)
376            .build_sqlx(PostgresQueryBuilder);
377
378        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
379            .traced()
380            .fetch_one(&mut *self.conn)
381            .await?;
382
383        count
384            .try_into()
385            .map_err(DatabaseError::to_invalid_operation)
386    }
387
388    #[tracing::instrument(
389        name = "db.upstream_oauth_link.remove",
390        skip_all,
391        fields(
392            db.query.text,
393            upstream_oauth_link.id,
394            upstream_oauth_link.provider_id,
395            %upstream_oauth_link.subject,
396        ),
397        err,
398    )]
399    async fn remove(
400        &mut self,
401        clock: &dyn Clock,
402        upstream_oauth_link: UpstreamOAuthLink,
403    ) -> Result<(), Self::Error> {
404        // Unlink the authorization sessions first, as they have a foreign key
405        // constraint on the links.
406        let span = tracing::info_span!(
407            "db.upstream_oauth_link.remove.unlink",
408            { DB_QUERY_TEXT } = tracing::field::Empty
409        );
410        sqlx::query!(
411            r#"
412                UPDATE upstream_oauth_authorization_sessions SET
413                    upstream_oauth_link_id = NULL,
414                    unlinked_at = $2
415                WHERE upstream_oauth_link_id = $1
416            "#,
417            Uuid::from(upstream_oauth_link.id),
418            clock.now()
419        )
420        .record(&span)
421        .execute(&mut *self.conn)
422        .instrument(span)
423        .await?;
424
425        // Then delete the link itself
426        let span = tracing::info_span!(
427            "db.upstream_oauth_link.remove.delete",
428            { DB_QUERY_TEXT } = tracing::field::Empty
429        );
430        let res = sqlx::query!(
431            r#"
432                DELETE FROM upstream_oauth_links
433                WHERE upstream_oauth_link_id = $1
434            "#,
435            Uuid::from(upstream_oauth_link.id),
436        )
437        .record(&span)
438        .execute(&mut *self.conn)
439        .instrument(span)
440        .await?;
441
442        DatabaseError::ensure_affected_rows(&res, 1)?;
443
444        Ok(())
445    }
446
447    #[tracing::instrument(
448        name = "db.upstream_oauth_link.cleanup_orphaned",
449        skip_all,
450        fields(
451            db.query.text,
452            since = since.map(tracing::field::display),
453            until = %until,
454            limit = limit,
455        ),
456        err,
457    )]
458    async fn cleanup_orphaned(
459        &mut self,
460        since: Option<Ulid>,
461        until: Ulid,
462        limit: usize,
463    ) -> Result<(usize, Option<Ulid>), Self::Error> {
464        // Use ULID cursor-based pagination for orphaned links only.
465        // We only delete links that have no user associated with them.
466        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
467        let res = sqlx::query_scalar!(
468            r#"
469                WITH
470                  to_delete AS (
471                    SELECT upstream_oauth_link_id
472                    FROM upstream_oauth_links
473                    WHERE user_id IS NULL
474                    AND ($1::uuid IS NULL OR upstream_oauth_link_id > $1)
475                    AND upstream_oauth_link_id <= $2
476                    ORDER BY upstream_oauth_link_id
477                    LIMIT $3
478                  ),
479                  deleted_sessions AS (
480                    DELETE FROM upstream_oauth_authorization_sessions
481                    USING to_delete
482                    WHERE upstream_oauth_authorization_sessions.upstream_oauth_link_id = to_delete.upstream_oauth_link_id
483                  )
484                DELETE FROM upstream_oauth_links
485                USING to_delete
486                WHERE upstream_oauth_links.upstream_oauth_link_id = to_delete.upstream_oauth_link_id
487                RETURNING upstream_oauth_links.upstream_oauth_link_id
488            "#,
489            since.map(Uuid::from),
490            Uuid::from(until),
491            i64::try_from(limit).unwrap_or(i64::MAX)
492        )
493        .traced()
494        .fetch_all(&mut *self.conn)
495        .await?;
496
497        let count = res.len();
498        let max_id = res.into_iter().max();
499
500        Ok((count, max_id.map(Ulid::from)))
501    }
502}