mas_storage_pg/upstream_oauth2/
link.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22    DatabaseError,
23    filter::{Filter, StatementExt},
24    iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
25    pagination::QueryBuilderExt,
26    tracing::ExecuteExt,
27};
28
29/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
30/// connection
31pub struct PgUpstreamOAuthLinkRepository<'c> {
32    conn: &'c mut PgConnection,
33}
34
35impl<'c> PgUpstreamOAuthLinkRepository<'c> {
36    /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL
37    /// connection
38    pub fn new(conn: &'c mut PgConnection) -> Self {
39        Self { conn }
40    }
41}
42
43#[derive(sqlx::FromRow)]
44#[enum_def]
45struct LinkLookup {
46    upstream_oauth_link_id: Uuid,
47    upstream_oauth_provider_id: Uuid,
48    user_id: Option<Uuid>,
49    subject: String,
50    human_account_name: Option<String>,
51    created_at: DateTime<Utc>,
52}
53
54impl From<LinkLookup> for UpstreamOAuthLink {
55    fn from(value: LinkLookup) -> Self {
56        UpstreamOAuthLink {
57            id: Ulid::from(value.upstream_oauth_link_id),
58            provider_id: Ulid::from(value.upstream_oauth_provider_id),
59            user_id: value.user_id.map(Ulid::from),
60            subject: value.subject,
61            human_account_name: value.human_account_name,
62            created_at: value.created_at,
63        }
64    }
65}
66
67impl Filter for UpstreamOAuthLinkFilter<'_> {
68    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
69        sea_query::Condition::all()
70            .add_option(self.user().map(|user| {
71                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
72                    .eq(Uuid::from(user.id))
73            }))
74            .add_option(self.provider().map(|provider| {
75                Expr::col((
76                    UpstreamOAuthLinks::Table,
77                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
78                ))
79                .eq(Uuid::from(provider.id))
80            }))
81            .add_option(self.provider_enabled().map(|enabled| {
82                Expr::col((
83                    UpstreamOAuthLinks::Table,
84                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
85                ))
86                .eq(Expr::any(
87                    Query::select()
88                        .expr(Expr::col((
89                            UpstreamOAuthProviders::Table,
90                            UpstreamOAuthProviders::UpstreamOAuthProviderId,
91                        )))
92                        .from(UpstreamOAuthProviders::Table)
93                        .and_where(
94                            Expr::col((
95                                UpstreamOAuthProviders::Table,
96                                UpstreamOAuthProviders::DisabledAt,
97                            ))
98                            .is_null()
99                            .eq(enabled),
100                        )
101                        .take(),
102                ))
103            }))
104            .add_option(self.subject().map(|subject| {
105                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
106            }))
107    }
108}
109
110#[async_trait]
111impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
112    type Error = DatabaseError;
113
114    #[tracing::instrument(
115        name = "db.upstream_oauth_link.lookup",
116        skip_all,
117        fields(
118            db.query.text,
119            upstream_oauth_link.id = %id,
120        ),
121        err,
122    )]
123    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
124        let res = sqlx::query_as!(
125            LinkLookup,
126            r#"
127                SELECT
128                    upstream_oauth_link_id,
129                    upstream_oauth_provider_id,
130                    user_id,
131                    subject,
132                    human_account_name,
133                    created_at
134                FROM upstream_oauth_links
135                WHERE upstream_oauth_link_id = $1
136            "#,
137            Uuid::from(id),
138        )
139        .traced()
140        .fetch_optional(&mut *self.conn)
141        .await?
142        .map(Into::into);
143
144        Ok(res)
145    }
146
147    #[tracing::instrument(
148        name = "db.upstream_oauth_link.find_by_subject",
149        skip_all,
150        fields(
151            db.query.text,
152            upstream_oauth_link.subject = subject,
153            %upstream_oauth_provider.id,
154            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
155            %upstream_oauth_provider.client_id,
156        ),
157        err,
158    )]
159    async fn find_by_subject(
160        &mut self,
161        upstream_oauth_provider: &UpstreamOAuthProvider,
162        subject: &str,
163    ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
164        let res = sqlx::query_as!(
165            LinkLookup,
166            r#"
167                SELECT
168                    upstream_oauth_link_id,
169                    upstream_oauth_provider_id,
170                    user_id,
171                    subject,
172                    human_account_name,
173                    created_at
174                FROM upstream_oauth_links
175                WHERE upstream_oauth_provider_id = $1
176                  AND subject = $2
177            "#,
178            Uuid::from(upstream_oauth_provider.id),
179            subject,
180        )
181        .traced()
182        .fetch_optional(&mut *self.conn)
183        .await?
184        .map(Into::into);
185
186        Ok(res)
187    }
188
189    #[tracing::instrument(
190        name = "db.upstream_oauth_link.add",
191        skip_all,
192        fields(
193            db.query.text,
194            upstream_oauth_link.id,
195            upstream_oauth_link.subject = subject,
196            upstream_oauth_link.human_account_name = human_account_name,
197            %upstream_oauth_provider.id,
198            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
199            %upstream_oauth_provider.client_id,
200        ),
201        err,
202    )]
203    async fn add(
204        &mut self,
205        rng: &mut (dyn RngCore + Send),
206        clock: &dyn Clock,
207        upstream_oauth_provider: &UpstreamOAuthProvider,
208        subject: String,
209        human_account_name: Option<String>,
210    ) -> Result<UpstreamOAuthLink, Self::Error> {
211        let created_at = clock.now();
212        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
213        tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
214
215        sqlx::query!(
216            r#"
217                INSERT INTO upstream_oauth_links (
218                    upstream_oauth_link_id,
219                    upstream_oauth_provider_id,
220                    user_id,
221                    subject,
222                    human_account_name,
223                    created_at
224                ) VALUES ($1, $2, NULL, $3, $4, $5)
225            "#,
226            Uuid::from(id),
227            Uuid::from(upstream_oauth_provider.id),
228            &subject,
229            human_account_name.as_deref(),
230            created_at,
231        )
232        .traced()
233        .execute(&mut *self.conn)
234        .await?;
235
236        Ok(UpstreamOAuthLink {
237            id,
238            provider_id: upstream_oauth_provider.id,
239            user_id: None,
240            subject,
241            human_account_name,
242            created_at,
243        })
244    }
245
246    #[tracing::instrument(
247        name = "db.upstream_oauth_link.associate_to_user",
248        skip_all,
249        fields(
250            db.query.text,
251            %upstream_oauth_link.id,
252            %upstream_oauth_link.subject,
253            %user.id,
254            %user.username,
255        ),
256        err,
257    )]
258    async fn associate_to_user(
259        &mut self,
260        upstream_oauth_link: &UpstreamOAuthLink,
261        user: &User,
262    ) -> Result<(), Self::Error> {
263        sqlx::query!(
264            r#"
265                UPDATE upstream_oauth_links
266                SET user_id = $1
267                WHERE upstream_oauth_link_id = $2
268            "#,
269            Uuid::from(user.id),
270            Uuid::from(upstream_oauth_link.id),
271        )
272        .traced()
273        .execute(&mut *self.conn)
274        .await?;
275
276        Ok(())
277    }
278
279    #[tracing::instrument(
280        name = "db.upstream_oauth_link.list",
281        skip_all,
282        fields(
283            db.query.text,
284        ),
285        err,
286    )]
287    async fn list(
288        &mut self,
289        filter: UpstreamOAuthLinkFilter<'_>,
290        pagination: Pagination,
291    ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
292        let (sql, arguments) = Query::select()
293            .expr_as(
294                Expr::col((
295                    UpstreamOAuthLinks::Table,
296                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
297                )),
298                LinkLookupIden::UpstreamOauthLinkId,
299            )
300            .expr_as(
301                Expr::col((
302                    UpstreamOAuthLinks::Table,
303                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
304                )),
305                LinkLookupIden::UpstreamOauthProviderId,
306            )
307            .expr_as(
308                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
309                LinkLookupIden::UserId,
310            )
311            .expr_as(
312                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
313                LinkLookupIden::Subject,
314            )
315            .expr_as(
316                Expr::col((
317                    UpstreamOAuthLinks::Table,
318                    UpstreamOAuthLinks::HumanAccountName,
319                )),
320                LinkLookupIden::HumanAccountName,
321            )
322            .expr_as(
323                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
324                LinkLookupIden::CreatedAt,
325            )
326            .from(UpstreamOAuthLinks::Table)
327            .apply_filter(filter)
328            .generate_pagination(
329                (
330                    UpstreamOAuthLinks::Table,
331                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
332                ),
333                pagination,
334            )
335            .build_sqlx(PostgresQueryBuilder);
336
337        let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
338            .traced()
339            .fetch_all(&mut *self.conn)
340            .await?;
341
342        let page = pagination.process(edges).map(UpstreamOAuthLink::from);
343
344        Ok(page)
345    }
346
347    #[tracing::instrument(
348        name = "db.upstream_oauth_link.count",
349        skip_all,
350        fields(
351            db.query.text,
352        ),
353        err,
354    )]
355    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
356        let (sql, arguments) = Query::select()
357            .expr(
358                Expr::col((
359                    UpstreamOAuthLinks::Table,
360                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
361                ))
362                .count(),
363            )
364            .from(UpstreamOAuthLinks::Table)
365            .apply_filter(filter)
366            .build_sqlx(PostgresQueryBuilder);
367
368        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
369            .traced()
370            .fetch_one(&mut *self.conn)
371            .await?;
372
373        count
374            .try_into()
375            .map_err(DatabaseError::to_invalid_operation)
376    }
377}