mas_storage_pg/upstream_oauth2/
provider.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{
13        UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14    },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError, DatabaseInconsistencyError,
27    filter::{Filter, StatementExt},
28    iden::UpstreamOAuthProviders,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
34/// connection
35pub struct PgUpstreamOAuthProviderRepository<'c> {
36    conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40    /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active
41    /// PostgreSQL connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// Raw row shape of a single `upstream_oauth_providers` record.
///
/// Filled by the `sqlx::query_as!` call in `lookup`, and through
/// [`sqlx::FromRow`] by the dynamically built query in `list`. `#[enum_def]`
/// generates `ProviderLookupIden`, whose variants `list` uses as column
/// aliases, so field names must match the SQL column names exactly.
/// Text-typed fields are parsed into domain types by the
/// `TryFrom<ProviderLookup> for UpstreamOAuthProvider` impl.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    // Primary key; converted into the provider's `Ulid` id in `TryFrom`.
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Stored as text and parsed in `TryFrom`.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    // Optional algorithm name, parsed in `TryFrom` when present.
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled; set by `disable`.
    disabled_at: Option<DateTime<Utc>>,
    // JSON column decoded straight into the claims-import configuration.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    // Optional endpoint overrides, stored as text and parsed in `TryFrom`.
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of (name, value) pairs; `TryFrom` maps NULL to an empty list.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
}
74
75impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
76    type Error = DatabaseInconsistencyError;
77
78    #[allow(clippy::too_many_lines)]
79    fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
80        let id = value.upstream_oauth_provider_id.into();
81        let scope = value.scope.parse().map_err(|e| {
82            DatabaseInconsistencyError::on("upstream_oauth_providers")
83                .column("scope")
84                .row(id)
85                .source(e)
86        })?;
87        let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
88            DatabaseInconsistencyError::on("upstream_oauth_providers")
89                .column("token_endpoint_auth_method")
90                .row(id)
91                .source(e)
92        })?;
93        let token_endpoint_signing_alg = value
94            .token_endpoint_signing_alg
95            .map(|x| x.parse())
96            .transpose()
97            .map_err(|e| {
98                DatabaseInconsistencyError::on("upstream_oauth_providers")
99                    .column("token_endpoint_signing_alg")
100                    .row(id)
101                    .source(e)
102            })?;
103        let id_token_signed_response_alg =
104            value.id_token_signed_response_alg.parse().map_err(|e| {
105                DatabaseInconsistencyError::on("upstream_oauth_providers")
106                    .column("id_token_signed_response_alg")
107                    .row(id)
108                    .source(e)
109            })?;
110
111        let userinfo_signed_response_alg = value
112            .userinfo_signed_response_alg
113            .map(|x| x.parse())
114            .transpose()
115            .map_err(|e| {
116                DatabaseInconsistencyError::on("upstream_oauth_providers")
117                    .column("userinfo_signed_response_alg")
118                    .row(id)
119                    .source(e)
120            })?;
121
122        let authorization_endpoint_override = value
123            .authorization_endpoint_override
124            .map(|x| x.parse())
125            .transpose()
126            .map_err(|e| {
127                DatabaseInconsistencyError::on("upstream_oauth_providers")
128                    .column("authorization_endpoint_override")
129                    .row(id)
130                    .source(e)
131            })?;
132
133        let token_endpoint_override = value
134            .token_endpoint_override
135            .map(|x| x.parse())
136            .transpose()
137            .map_err(|e| {
138                DatabaseInconsistencyError::on("upstream_oauth_providers")
139                    .column("token_endpoint_override")
140                    .row(id)
141                    .source(e)
142            })?;
143
144        let userinfo_endpoint_override = value
145            .userinfo_endpoint_override
146            .map(|x| x.parse())
147            .transpose()
148            .map_err(|e| {
149                DatabaseInconsistencyError::on("upstream_oauth_providers")
150                    .column("userinfo_endpoint_override")
151                    .row(id)
152                    .source(e)
153            })?;
154
155        let jwks_uri_override = value
156            .jwks_uri_override
157            .map(|x| x.parse())
158            .transpose()
159            .map_err(|e| {
160                DatabaseInconsistencyError::on("upstream_oauth_providers")
161                    .column("jwks_uri_override")
162                    .row(id)
163                    .source(e)
164            })?;
165
166        let discovery_mode = value.discovery_mode.parse().map_err(|e| {
167            DatabaseInconsistencyError::on("upstream_oauth_providers")
168                .column("discovery_mode")
169                .row(id)
170                .source(e)
171        })?;
172
173        let pkce_mode = value.pkce_mode.parse().map_err(|e| {
174            DatabaseInconsistencyError::on("upstream_oauth_providers")
175                .column("pkce_mode")
176                .row(id)
177                .source(e)
178        })?;
179
180        let response_mode = value
181            .response_mode
182            .map(|x| x.parse())
183            .transpose()
184            .map_err(|e| {
185                DatabaseInconsistencyError::on("upstream_oauth_providers")
186                    .column("response_mode")
187                    .row(id)
188                    .source(e)
189            })?;
190
191        let additional_authorization_parameters = value
192            .additional_parameters
193            .map(|Json(x)| x)
194            .unwrap_or_default();
195
196        Ok(UpstreamOAuthProvider {
197            id,
198            issuer: value.issuer,
199            human_name: value.human_name,
200            brand_name: value.brand_name,
201            scope,
202            client_id: value.client_id,
203            encrypted_client_secret: value.encrypted_client_secret,
204            token_endpoint_auth_method,
205            token_endpoint_signing_alg,
206            id_token_signed_response_alg,
207            fetch_userinfo: value.fetch_userinfo,
208            userinfo_signed_response_alg,
209            created_at: value.created_at,
210            disabled_at: value.disabled_at,
211            claims_imports: value.claims_imports.0,
212            authorization_endpoint_override,
213            token_endpoint_override,
214            userinfo_endpoint_override,
215            jwks_uri_override,
216            discovery_mode,
217            pkce_mode,
218            response_mode,
219            additional_authorization_parameters,
220        })
221    }
222}
223
224impl Filter for UpstreamOAuthProviderFilter<'_> {
225    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
226        sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
227            Expr::col((
228                UpstreamOAuthProviders::Table,
229                UpstreamOAuthProviders::DisabledAt,
230            ))
231            .is_null()
232            .eq(enabled)
233        }))
234    }
235}
236
237#[async_trait]
238impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
239    type Error = DatabaseError;
240
241    #[tracing::instrument(
242        name = "db.upstream_oauth_provider.lookup",
243        skip_all,
244        fields(
245            db.query.text,
246            upstream_oauth_provider.id = %id,
247        ),
248        err,
249    )]
250    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
251        let res = sqlx::query_as!(
252            ProviderLookup,
253            r#"
254                SELECT
255                    upstream_oauth_provider_id,
256                    issuer,
257                    human_name,
258                    brand_name,
259                    scope,
260                    client_id,
261                    encrypted_client_secret,
262                    token_endpoint_signing_alg,
263                    token_endpoint_auth_method,
264                    id_token_signed_response_alg,
265                    fetch_userinfo,
266                    userinfo_signed_response_alg,
267                    created_at,
268                    disabled_at,
269                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
270                    jwks_uri_override,
271                    authorization_endpoint_override,
272                    token_endpoint_override,
273                    userinfo_endpoint_override,
274                    discovery_mode,
275                    pkce_mode,
276                    response_mode,
277                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
278                FROM upstream_oauth_providers
279                WHERE upstream_oauth_provider_id = $1
280            "#,
281            Uuid::from(id),
282        )
283        .traced()
284        .fetch_optional(&mut *self.conn)
285        .await?;
286
287        let res = res
288            .map(UpstreamOAuthProvider::try_from)
289            .transpose()
290            .map_err(DatabaseError::from)?;
291
292        Ok(res)
293    }
294
295    #[tracing::instrument(
296        name = "db.upstream_oauth_provider.add",
297        skip_all,
298        fields(
299            db.query.text,
300            upstream_oauth_provider.id,
301            upstream_oauth_provider.issuer = params.issuer,
302            upstream_oauth_provider.client_id = %params.client_id,
303        ),
304        err,
305    )]
306    async fn add(
307        &mut self,
308        rng: &mut (dyn RngCore + Send),
309        clock: &dyn Clock,
310        params: UpstreamOAuthProviderParams,
311    ) -> Result<UpstreamOAuthProvider, Self::Error> {
312        let created_at = clock.now();
313        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
314        tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
315
316        sqlx::query!(
317            r#"
318            INSERT INTO upstream_oauth_providers (
319                upstream_oauth_provider_id,
320                issuer,
321                human_name,
322                brand_name,
323                scope,
324                token_endpoint_auth_method,
325                token_endpoint_signing_alg,
326                id_token_signed_response_alg,
327                fetch_userinfo,
328                userinfo_signed_response_alg,
329                client_id,
330                encrypted_client_secret,
331                claims_imports,
332                authorization_endpoint_override,
333                token_endpoint_override,
334                userinfo_endpoint_override,
335                jwks_uri_override,
336                discovery_mode,
337                pkce_mode,
338                response_mode,
339                created_at
340            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
341                      $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
342        "#,
343            Uuid::from(id),
344            params.issuer.as_deref(),
345            params.human_name.as_deref(),
346            params.brand_name.as_deref(),
347            params.scope.to_string(),
348            params.token_endpoint_auth_method.to_string(),
349            params
350                .token_endpoint_signing_alg
351                .as_ref()
352                .map(ToString::to_string),
353            params.id_token_signed_response_alg.to_string(),
354            params.fetch_userinfo,
355            params
356                .userinfo_signed_response_alg
357                .as_ref()
358                .map(ToString::to_string),
359            &params.client_id,
360            params.encrypted_client_secret.as_deref(),
361            Json(&params.claims_imports) as _,
362            params
363                .authorization_endpoint_override
364                .as_ref()
365                .map(ToString::to_string),
366            params
367                .token_endpoint_override
368                .as_ref()
369                .map(ToString::to_string),
370            params
371                .userinfo_endpoint_override
372                .as_ref()
373                .map(ToString::to_string),
374            params.jwks_uri_override.as_ref().map(ToString::to_string),
375            params.discovery_mode.as_str(),
376            params.pkce_mode.as_str(),
377            params.response_mode.as_ref().map(ToString::to_string),
378            created_at,
379        )
380        .traced()
381        .execute(&mut *self.conn)
382        .await?;
383
384        Ok(UpstreamOAuthProvider {
385            id,
386            issuer: params.issuer,
387            human_name: params.human_name,
388            brand_name: params.brand_name,
389            scope: params.scope,
390            client_id: params.client_id,
391            encrypted_client_secret: params.encrypted_client_secret,
392            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
393            token_endpoint_auth_method: params.token_endpoint_auth_method,
394            id_token_signed_response_alg: params.id_token_signed_response_alg,
395            fetch_userinfo: params.fetch_userinfo,
396            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
397            created_at,
398            disabled_at: None,
399            claims_imports: params.claims_imports,
400            authorization_endpoint_override: params.authorization_endpoint_override,
401            token_endpoint_override: params.token_endpoint_override,
402            userinfo_endpoint_override: params.userinfo_endpoint_override,
403            jwks_uri_override: params.jwks_uri_override,
404            discovery_mode: params.discovery_mode,
405            pkce_mode: params.pkce_mode,
406            response_mode: params.response_mode,
407            additional_authorization_parameters: params.additional_authorization_parameters,
408        })
409    }
410
411    #[tracing::instrument(
412        name = "db.upstream_oauth_provider.delete_by_id",
413        skip_all,
414        fields(
415            db.query.text,
416            upstream_oauth_provider.id = %id,
417        ),
418        err,
419    )]
420    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
421        // Delete the authorization sessions first, as they have a foreign key
422        // constraint on the links and the providers.
423        {
424            let span = info_span!(
425                "db.oauth2_client.delete_by_id.authorization_sessions",
426                upstream_oauth_provider.id = %id,
427                { DB_QUERY_TEXT } = tracing::field::Empty,
428            );
429            sqlx::query!(
430                r#"
431                    DELETE FROM upstream_oauth_authorization_sessions
432                    WHERE upstream_oauth_provider_id = $1
433                "#,
434                Uuid::from(id),
435            )
436            .record(&span)
437            .execute(&mut *self.conn)
438            .instrument(span)
439            .await?;
440        }
441
442        // Delete the links next, as they have a foreign key constraint on the
443        // providers.
444        {
445            let span = info_span!(
446                "db.oauth2_client.delete_by_id.links",
447                upstream_oauth_provider.id = %id,
448                { DB_QUERY_TEXT } = tracing::field::Empty,
449            );
450            sqlx::query!(
451                r#"
452                    DELETE FROM upstream_oauth_links
453                    WHERE upstream_oauth_provider_id = $1
454                "#,
455                Uuid::from(id),
456            )
457            .record(&span)
458            .execute(&mut *self.conn)
459            .instrument(span)
460            .await?;
461        }
462
463        let res = sqlx::query!(
464            r#"
465                DELETE FROM upstream_oauth_providers
466                WHERE upstream_oauth_provider_id = $1
467            "#,
468            Uuid::from(id),
469        )
470        .traced()
471        .execute(&mut *self.conn)
472        .await?;
473
474        DatabaseError::ensure_affected_rows(&res, 1)
475    }
476
477    #[tracing::instrument(
478        name = "db.upstream_oauth_provider.add",
479        skip_all,
480        fields(
481            db.query.text,
482            upstream_oauth_provider.id = %id,
483            upstream_oauth_provider.issuer = params.issuer,
484            upstream_oauth_provider.client_id = %params.client_id,
485        ),
486        err,
487    )]
488    async fn upsert(
489        &mut self,
490        clock: &dyn Clock,
491        id: Ulid,
492        params: UpstreamOAuthProviderParams,
493    ) -> Result<UpstreamOAuthProvider, Self::Error> {
494        let created_at = clock.now();
495
496        let created_at = sqlx::query_scalar!(
497            r#"
498                INSERT INTO upstream_oauth_providers (
499                    upstream_oauth_provider_id,
500                    issuer,
501                    human_name,
502                    brand_name,
503                    scope,
504                    token_endpoint_auth_method,
505                    token_endpoint_signing_alg,
506                    id_token_signed_response_alg,
507                    fetch_userinfo,
508                    userinfo_signed_response_alg,
509                    client_id,
510                    encrypted_client_secret,
511                    claims_imports,
512                    authorization_endpoint_override,
513                    token_endpoint_override,
514                    userinfo_endpoint_override,
515                    jwks_uri_override,
516                    discovery_mode,
517                    pkce_mode,
518                    response_mode,
519                    additional_parameters,
520                    ui_order,
521                    created_at
522                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
523                          $12, $13, $14, $15, $16, $17, $18, $19, $20,
524                          $21, $22, $23)
525                ON CONFLICT (upstream_oauth_provider_id)
526                    DO UPDATE
527                    SET
528                        issuer = EXCLUDED.issuer,
529                        human_name = EXCLUDED.human_name,
530                        brand_name = EXCLUDED.brand_name,
531                        scope = EXCLUDED.scope,
532                        token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
533                        token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
534                        id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
535                        fetch_userinfo = EXCLUDED.fetch_userinfo,
536                        userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
537                        disabled_at = NULL,
538                        client_id = EXCLUDED.client_id,
539                        encrypted_client_secret = EXCLUDED.encrypted_client_secret,
540                        claims_imports = EXCLUDED.claims_imports,
541                        authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
542                        token_endpoint_override = EXCLUDED.token_endpoint_override,
543                        userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
544                        jwks_uri_override = EXCLUDED.jwks_uri_override,
545                        discovery_mode = EXCLUDED.discovery_mode,
546                        pkce_mode = EXCLUDED.pkce_mode,
547                        response_mode = EXCLUDED.response_mode,
548                        additional_parameters = EXCLUDED.additional_parameters,
549                        ui_order = EXCLUDED.ui_order
550                RETURNING created_at
551            "#,
552            Uuid::from(id),
553            params.issuer.as_deref(),
554            params.human_name.as_deref(),
555            params.brand_name.as_deref(),
556            params.scope.to_string(),
557            params.token_endpoint_auth_method.to_string(),
558            params
559                .token_endpoint_signing_alg
560                .as_ref()
561                .map(ToString::to_string),
562            params.id_token_signed_response_alg.to_string(),
563            params.fetch_userinfo,
564            params
565                .userinfo_signed_response_alg
566                .as_ref()
567                .map(ToString::to_string),
568            &params.client_id,
569            params.encrypted_client_secret.as_deref(),
570            Json(&params.claims_imports) as _,
571            params
572                .authorization_endpoint_override
573                .as_ref()
574                .map(ToString::to_string),
575            params
576                .token_endpoint_override
577                .as_ref()
578                .map(ToString::to_string),
579            params
580                .userinfo_endpoint_override
581                .as_ref()
582                .map(ToString::to_string),
583            params.jwks_uri_override.as_ref().map(ToString::to_string),
584            params.discovery_mode.as_str(),
585            params.pkce_mode.as_str(),
586            params.response_mode.as_ref().map(ToString::to_string),
587            Json(&params.additional_authorization_parameters) as _,
588            params.ui_order,
589            created_at,
590        )
591        .traced()
592        .fetch_one(&mut *self.conn)
593        .await?;
594
595        Ok(UpstreamOAuthProvider {
596            id,
597            issuer: params.issuer,
598            human_name: params.human_name,
599            brand_name: params.brand_name,
600            scope: params.scope,
601            client_id: params.client_id,
602            encrypted_client_secret: params.encrypted_client_secret,
603            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
604            token_endpoint_auth_method: params.token_endpoint_auth_method,
605            id_token_signed_response_alg: params.id_token_signed_response_alg,
606            fetch_userinfo: params.fetch_userinfo,
607            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
608            created_at,
609            disabled_at: None,
610            claims_imports: params.claims_imports,
611            authorization_endpoint_override: params.authorization_endpoint_override,
612            token_endpoint_override: params.token_endpoint_override,
613            userinfo_endpoint_override: params.userinfo_endpoint_override,
614            jwks_uri_override: params.jwks_uri_override,
615            discovery_mode: params.discovery_mode,
616            pkce_mode: params.pkce_mode,
617            response_mode: params.response_mode,
618            additional_authorization_parameters: params.additional_authorization_parameters,
619        })
620    }
621
622    #[tracing::instrument(
623        name = "db.upstream_oauth_provider.disable",
624        skip_all,
625        fields(
626            db.query.text,
627            %upstream_oauth_provider.id,
628        ),
629        err,
630    )]
631    async fn disable(
632        &mut self,
633        clock: &dyn Clock,
634        mut upstream_oauth_provider: UpstreamOAuthProvider,
635    ) -> Result<UpstreamOAuthProvider, Self::Error> {
636        let disabled_at = clock.now();
637        let res = sqlx::query!(
638            r#"
639                UPDATE upstream_oauth_providers
640                SET disabled_at = $2
641                WHERE upstream_oauth_provider_id = $1
642            "#,
643            Uuid::from(upstream_oauth_provider.id),
644            disabled_at,
645        )
646        .traced()
647        .execute(&mut *self.conn)
648        .await?;
649
650        DatabaseError::ensure_affected_rows(&res, 1)?;
651
652        upstream_oauth_provider.disabled_at = Some(disabled_at);
653
654        Ok(upstream_oauth_provider)
655    }
656
657    #[tracing::instrument(
658        name = "db.upstream_oauth_provider.list",
659        skip_all,
660        fields(
661            db.query.text,
662        ),
663        err,
664    )]
665    async fn list(
666        &mut self,
667        filter: UpstreamOAuthProviderFilter<'_>,
668        pagination: Pagination,
669    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
670        let (sql, arguments) = Query::select()
671            .expr_as(
672                Expr::col((
673                    UpstreamOAuthProviders::Table,
674                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
675                )),
676                ProviderLookupIden::UpstreamOauthProviderId,
677            )
678            .expr_as(
679                Expr::col((
680                    UpstreamOAuthProviders::Table,
681                    UpstreamOAuthProviders::Issuer,
682                )),
683                ProviderLookupIden::Issuer,
684            )
685            .expr_as(
686                Expr::col((
687                    UpstreamOAuthProviders::Table,
688                    UpstreamOAuthProviders::HumanName,
689                )),
690                ProviderLookupIden::HumanName,
691            )
692            .expr_as(
693                Expr::col((
694                    UpstreamOAuthProviders::Table,
695                    UpstreamOAuthProviders::BrandName,
696                )),
697                ProviderLookupIden::BrandName,
698            )
699            .expr_as(
700                Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
701                ProviderLookupIden::Scope,
702            )
703            .expr_as(
704                Expr::col((
705                    UpstreamOAuthProviders::Table,
706                    UpstreamOAuthProviders::ClientId,
707                )),
708                ProviderLookupIden::ClientId,
709            )
710            .expr_as(
711                Expr::col((
712                    UpstreamOAuthProviders::Table,
713                    UpstreamOAuthProviders::EncryptedClientSecret,
714                )),
715                ProviderLookupIden::EncryptedClientSecret,
716            )
717            .expr_as(
718                Expr::col((
719                    UpstreamOAuthProviders::Table,
720                    UpstreamOAuthProviders::TokenEndpointSigningAlg,
721                )),
722                ProviderLookupIden::TokenEndpointSigningAlg,
723            )
724            .expr_as(
725                Expr::col((
726                    UpstreamOAuthProviders::Table,
727                    UpstreamOAuthProviders::TokenEndpointAuthMethod,
728                )),
729                ProviderLookupIden::TokenEndpointAuthMethod,
730            )
731            .expr_as(
732                Expr::col((
733                    UpstreamOAuthProviders::Table,
734                    UpstreamOAuthProviders::IdTokenSignedResponseAlg,
735                )),
736                ProviderLookupIden::IdTokenSignedResponseAlg,
737            )
738            .expr_as(
739                Expr::col((
740                    UpstreamOAuthProviders::Table,
741                    UpstreamOAuthProviders::FetchUserinfo,
742                )),
743                ProviderLookupIden::FetchUserinfo,
744            )
745            .expr_as(
746                Expr::col((
747                    UpstreamOAuthProviders::Table,
748                    UpstreamOAuthProviders::UserinfoSignedResponseAlg,
749                )),
750                ProviderLookupIden::UserinfoSignedResponseAlg,
751            )
752            .expr_as(
753                Expr::col((
754                    UpstreamOAuthProviders::Table,
755                    UpstreamOAuthProviders::CreatedAt,
756                )),
757                ProviderLookupIden::CreatedAt,
758            )
759            .expr_as(
760                Expr::col((
761                    UpstreamOAuthProviders::Table,
762                    UpstreamOAuthProviders::DisabledAt,
763                )),
764                ProviderLookupIden::DisabledAt,
765            )
766            .expr_as(
767                Expr::col((
768                    UpstreamOAuthProviders::Table,
769                    UpstreamOAuthProviders::ClaimsImports,
770                )),
771                ProviderLookupIden::ClaimsImports,
772            )
773            .expr_as(
774                Expr::col((
775                    UpstreamOAuthProviders::Table,
776                    UpstreamOAuthProviders::JwksUriOverride,
777                )),
778                ProviderLookupIden::JwksUriOverride,
779            )
780            .expr_as(
781                Expr::col((
782                    UpstreamOAuthProviders::Table,
783                    UpstreamOAuthProviders::TokenEndpointOverride,
784                )),
785                ProviderLookupIden::TokenEndpointOverride,
786            )
787            .expr_as(
788                Expr::col((
789                    UpstreamOAuthProviders::Table,
790                    UpstreamOAuthProviders::AuthorizationEndpointOverride,
791                )),
792                ProviderLookupIden::AuthorizationEndpointOverride,
793            )
794            .expr_as(
795                Expr::col((
796                    UpstreamOAuthProviders::Table,
797                    UpstreamOAuthProviders::UserinfoEndpointOverride,
798                )),
799                ProviderLookupIden::UserinfoEndpointOverride,
800            )
801            .expr_as(
802                Expr::col((
803                    UpstreamOAuthProviders::Table,
804                    UpstreamOAuthProviders::DiscoveryMode,
805                )),
806                ProviderLookupIden::DiscoveryMode,
807            )
808            .expr_as(
809                Expr::col((
810                    UpstreamOAuthProviders::Table,
811                    UpstreamOAuthProviders::PkceMode,
812                )),
813                ProviderLookupIden::PkceMode,
814            )
815            .expr_as(
816                Expr::col((
817                    UpstreamOAuthProviders::Table,
818                    UpstreamOAuthProviders::ResponseMode,
819                )),
820                ProviderLookupIden::ResponseMode,
821            )
822            .expr_as(
823                Expr::col((
824                    UpstreamOAuthProviders::Table,
825                    UpstreamOAuthProviders::AdditionalParameters,
826                )),
827                ProviderLookupIden::AdditionalParameters,
828            )
829            .from(UpstreamOAuthProviders::Table)
830            .apply_filter(filter)
831            .generate_pagination(
832                (
833                    UpstreamOAuthProviders::Table,
834                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
835                ),
836                pagination,
837            )
838            .build_sqlx(PostgresQueryBuilder);
839
840        let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
841            .traced()
842            .fetch_all(&mut *self.conn)
843            .await?;
844
845        let page = pagination
846            .process(edges)
847            .try_map(UpstreamOAuthProvider::try_from)?;
848
849        return Ok(page);
850    }
851
852    #[tracing::instrument(
853        name = "db.upstream_oauth_provider.count",
854        skip_all,
855        fields(
856            db.query.text,
857        ),
858        err,
859    )]
860    async fn count(
861        &mut self,
862        filter: UpstreamOAuthProviderFilter<'_>,
863    ) -> Result<usize, Self::Error> {
864        let (sql, arguments) = Query::select()
865            .expr(
866                Expr::col((
867                    UpstreamOAuthProviders::Table,
868                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
869                ))
870                .count(),
871            )
872            .from(UpstreamOAuthProviders::Table)
873            .apply_filter(filter)
874            .build_sqlx(PostgresQueryBuilder);
875
876        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
877            .traced()
878            .fetch_one(&mut *self.conn)
879            .await?;
880
881        count
882            .try_into()
883            .map_err(DatabaseError::to_invalid_operation)
884    }
885
886    #[tracing::instrument(
887        name = "db.upstream_oauth_provider.all_enabled",
888        skip_all,
889        fields(
890            db.query.text,
891        ),
892        err,
893    )]
894    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
895        let res = sqlx::query_as!(
896            ProviderLookup,
897            r#"
898                SELECT
899                    upstream_oauth_provider_id,
900                    issuer,
901                    human_name,
902                    brand_name,
903                    scope,
904                    client_id,
905                    encrypted_client_secret,
906                    token_endpoint_signing_alg,
907                    token_endpoint_auth_method,
908                    id_token_signed_response_alg,
909                    fetch_userinfo,
910                    userinfo_signed_response_alg,
911                    created_at,
912                    disabled_at,
913                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
914                    jwks_uri_override,
915                    authorization_endpoint_override,
916                    token_endpoint_override,
917                    userinfo_endpoint_override,
918                    discovery_mode,
919                    pkce_mode,
920                    response_mode,
921                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
922                FROM upstream_oauth_providers
923                WHERE disabled_at IS NULL
924                ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
925            "#,
926        )
927        .traced()
928        .fetch_all(&mut *self.conn)
929        .await?;
930
931        let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
932        Ok(res?)
933    }
934}