1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{
13 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14 },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError, DatabaseInconsistencyError,
27 filter::{Filter, StatementExt},
28 iden::UpstreamOAuthProviders,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
33pub struct PgUpstreamOAuthProviderRepository<'c> {
36 conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40 pub fn new(conn: &'c mut PgConnection) -> Self {
43 Self { conn }
44 }
45}
46
/// Raw row shape of the `upstream_oauth_providers` table, as read by the
/// `lookup`, `list` and `all_enabled` queries.
///
/// Enum-like and URL columns are kept as plain strings here and are parsed
/// into their typed forms by the `TryFrom<ProviderLookup>` conversion into
/// `UpstreamOAuthProvider`. `#[enum_def]` generates `ProviderLookupIden`,
/// which the `list` query uses as column aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Parsed into a scope value by the `TryFrom` conversion
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled (see the `Filter` impl)
    disabled_at: Option<DateTime<Utc>>,
    // Stored as a JSON column and decoded by sqlx
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // Nullable JSON list of (key, value) pairs; treated as empty when NULL
    additional_parameters: Option<Json<Vec<(String, String)>>>,
}
74
75impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
76 type Error = DatabaseInconsistencyError;
77
78 #[allow(clippy::too_many_lines)]
79 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
80 let id = value.upstream_oauth_provider_id.into();
81 let scope = value.scope.parse().map_err(|e| {
82 DatabaseInconsistencyError::on("upstream_oauth_providers")
83 .column("scope")
84 .row(id)
85 .source(e)
86 })?;
87 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
88 DatabaseInconsistencyError::on("upstream_oauth_providers")
89 .column("token_endpoint_auth_method")
90 .row(id)
91 .source(e)
92 })?;
93 let token_endpoint_signing_alg = value
94 .token_endpoint_signing_alg
95 .map(|x| x.parse())
96 .transpose()
97 .map_err(|e| {
98 DatabaseInconsistencyError::on("upstream_oauth_providers")
99 .column("token_endpoint_signing_alg")
100 .row(id)
101 .source(e)
102 })?;
103 let id_token_signed_response_alg =
104 value.id_token_signed_response_alg.parse().map_err(|e| {
105 DatabaseInconsistencyError::on("upstream_oauth_providers")
106 .column("id_token_signed_response_alg")
107 .row(id)
108 .source(e)
109 })?;
110
111 let userinfo_signed_response_alg = value
112 .userinfo_signed_response_alg
113 .map(|x| x.parse())
114 .transpose()
115 .map_err(|e| {
116 DatabaseInconsistencyError::on("upstream_oauth_providers")
117 .column("userinfo_signed_response_alg")
118 .row(id)
119 .source(e)
120 })?;
121
122 let authorization_endpoint_override = value
123 .authorization_endpoint_override
124 .map(|x| x.parse())
125 .transpose()
126 .map_err(|e| {
127 DatabaseInconsistencyError::on("upstream_oauth_providers")
128 .column("authorization_endpoint_override")
129 .row(id)
130 .source(e)
131 })?;
132
133 let token_endpoint_override = value
134 .token_endpoint_override
135 .map(|x| x.parse())
136 .transpose()
137 .map_err(|e| {
138 DatabaseInconsistencyError::on("upstream_oauth_providers")
139 .column("token_endpoint_override")
140 .row(id)
141 .source(e)
142 })?;
143
144 let userinfo_endpoint_override = value
145 .userinfo_endpoint_override
146 .map(|x| x.parse())
147 .transpose()
148 .map_err(|e| {
149 DatabaseInconsistencyError::on("upstream_oauth_providers")
150 .column("userinfo_endpoint_override")
151 .row(id)
152 .source(e)
153 })?;
154
155 let jwks_uri_override = value
156 .jwks_uri_override
157 .map(|x| x.parse())
158 .transpose()
159 .map_err(|e| {
160 DatabaseInconsistencyError::on("upstream_oauth_providers")
161 .column("jwks_uri_override")
162 .row(id)
163 .source(e)
164 })?;
165
166 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
167 DatabaseInconsistencyError::on("upstream_oauth_providers")
168 .column("discovery_mode")
169 .row(id)
170 .source(e)
171 })?;
172
173 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
174 DatabaseInconsistencyError::on("upstream_oauth_providers")
175 .column("pkce_mode")
176 .row(id)
177 .source(e)
178 })?;
179
180 let response_mode = value
181 .response_mode
182 .map(|x| x.parse())
183 .transpose()
184 .map_err(|e| {
185 DatabaseInconsistencyError::on("upstream_oauth_providers")
186 .column("response_mode")
187 .row(id)
188 .source(e)
189 })?;
190
191 let additional_authorization_parameters = value
192 .additional_parameters
193 .map(|Json(x)| x)
194 .unwrap_or_default();
195
196 Ok(UpstreamOAuthProvider {
197 id,
198 issuer: value.issuer,
199 human_name: value.human_name,
200 brand_name: value.brand_name,
201 scope,
202 client_id: value.client_id,
203 encrypted_client_secret: value.encrypted_client_secret,
204 token_endpoint_auth_method,
205 token_endpoint_signing_alg,
206 id_token_signed_response_alg,
207 fetch_userinfo: value.fetch_userinfo,
208 userinfo_signed_response_alg,
209 created_at: value.created_at,
210 disabled_at: value.disabled_at,
211 claims_imports: value.claims_imports.0,
212 authorization_endpoint_override,
213 token_endpoint_override,
214 userinfo_endpoint_override,
215 jwks_uri_override,
216 discovery_mode,
217 pkce_mode,
218 response_mode,
219 additional_authorization_parameters,
220 })
221 }
222}
223
224impl Filter for UpstreamOAuthProviderFilter<'_> {
225 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
226 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
227 Expr::col((
228 UpstreamOAuthProviders::Table,
229 UpstreamOAuthProviders::DisabledAt,
230 ))
231 .is_null()
232 .eq(enabled)
233 }))
234 }
235}
236
237#[async_trait]
238impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
239 type Error = DatabaseError;
240
241 #[tracing::instrument(
242 name = "db.upstream_oauth_provider.lookup",
243 skip_all,
244 fields(
245 db.query.text,
246 upstream_oauth_provider.id = %id,
247 ),
248 err,
249 )]
250 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
251 let res = sqlx::query_as!(
252 ProviderLookup,
253 r#"
254 SELECT
255 upstream_oauth_provider_id,
256 issuer,
257 human_name,
258 brand_name,
259 scope,
260 client_id,
261 encrypted_client_secret,
262 token_endpoint_signing_alg,
263 token_endpoint_auth_method,
264 id_token_signed_response_alg,
265 fetch_userinfo,
266 userinfo_signed_response_alg,
267 created_at,
268 disabled_at,
269 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
270 jwks_uri_override,
271 authorization_endpoint_override,
272 token_endpoint_override,
273 userinfo_endpoint_override,
274 discovery_mode,
275 pkce_mode,
276 response_mode,
277 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
278 FROM upstream_oauth_providers
279 WHERE upstream_oauth_provider_id = $1
280 "#,
281 Uuid::from(id),
282 )
283 .traced()
284 .fetch_optional(&mut *self.conn)
285 .await?;
286
287 let res = res
288 .map(UpstreamOAuthProvider::try_from)
289 .transpose()
290 .map_err(DatabaseError::from)?;
291
292 Ok(res)
293 }
294
295 #[tracing::instrument(
296 name = "db.upstream_oauth_provider.add",
297 skip_all,
298 fields(
299 db.query.text,
300 upstream_oauth_provider.id,
301 upstream_oauth_provider.issuer = params.issuer,
302 upstream_oauth_provider.client_id = %params.client_id,
303 ),
304 err,
305 )]
306 async fn add(
307 &mut self,
308 rng: &mut (dyn RngCore + Send),
309 clock: &dyn Clock,
310 params: UpstreamOAuthProviderParams,
311 ) -> Result<UpstreamOAuthProvider, Self::Error> {
312 let created_at = clock.now();
313 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
314 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
315
316 sqlx::query!(
317 r#"
318 INSERT INTO upstream_oauth_providers (
319 upstream_oauth_provider_id,
320 issuer,
321 human_name,
322 brand_name,
323 scope,
324 token_endpoint_auth_method,
325 token_endpoint_signing_alg,
326 id_token_signed_response_alg,
327 fetch_userinfo,
328 userinfo_signed_response_alg,
329 client_id,
330 encrypted_client_secret,
331 claims_imports,
332 authorization_endpoint_override,
333 token_endpoint_override,
334 userinfo_endpoint_override,
335 jwks_uri_override,
336 discovery_mode,
337 pkce_mode,
338 response_mode,
339 created_at
340 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
341 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
342 "#,
343 Uuid::from(id),
344 params.issuer.as_deref(),
345 params.human_name.as_deref(),
346 params.brand_name.as_deref(),
347 params.scope.to_string(),
348 params.token_endpoint_auth_method.to_string(),
349 params
350 .token_endpoint_signing_alg
351 .as_ref()
352 .map(ToString::to_string),
353 params.id_token_signed_response_alg.to_string(),
354 params.fetch_userinfo,
355 params
356 .userinfo_signed_response_alg
357 .as_ref()
358 .map(ToString::to_string),
359 ¶ms.client_id,
360 params.encrypted_client_secret.as_deref(),
361 Json(¶ms.claims_imports) as _,
362 params
363 .authorization_endpoint_override
364 .as_ref()
365 .map(ToString::to_string),
366 params
367 .token_endpoint_override
368 .as_ref()
369 .map(ToString::to_string),
370 params
371 .userinfo_endpoint_override
372 .as_ref()
373 .map(ToString::to_string),
374 params.jwks_uri_override.as_ref().map(ToString::to_string),
375 params.discovery_mode.as_str(),
376 params.pkce_mode.as_str(),
377 params.response_mode.as_ref().map(ToString::to_string),
378 created_at,
379 )
380 .traced()
381 .execute(&mut *self.conn)
382 .await?;
383
384 Ok(UpstreamOAuthProvider {
385 id,
386 issuer: params.issuer,
387 human_name: params.human_name,
388 brand_name: params.brand_name,
389 scope: params.scope,
390 client_id: params.client_id,
391 encrypted_client_secret: params.encrypted_client_secret,
392 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
393 token_endpoint_auth_method: params.token_endpoint_auth_method,
394 id_token_signed_response_alg: params.id_token_signed_response_alg,
395 fetch_userinfo: params.fetch_userinfo,
396 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
397 created_at,
398 disabled_at: None,
399 claims_imports: params.claims_imports,
400 authorization_endpoint_override: params.authorization_endpoint_override,
401 token_endpoint_override: params.token_endpoint_override,
402 userinfo_endpoint_override: params.userinfo_endpoint_override,
403 jwks_uri_override: params.jwks_uri_override,
404 discovery_mode: params.discovery_mode,
405 pkce_mode: params.pkce_mode,
406 response_mode: params.response_mode,
407 additional_authorization_parameters: params.additional_authorization_parameters,
408 })
409 }
410
411 #[tracing::instrument(
412 name = "db.upstream_oauth_provider.delete_by_id",
413 skip_all,
414 fields(
415 db.query.text,
416 upstream_oauth_provider.id = %id,
417 ),
418 err,
419 )]
420 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
421 {
424 let span = info_span!(
425 "db.oauth2_client.delete_by_id.authorization_sessions",
426 upstream_oauth_provider.id = %id,
427 { DB_QUERY_TEXT } = tracing::field::Empty,
428 );
429 sqlx::query!(
430 r#"
431 DELETE FROM upstream_oauth_authorization_sessions
432 WHERE upstream_oauth_provider_id = $1
433 "#,
434 Uuid::from(id),
435 )
436 .record(&span)
437 .execute(&mut *self.conn)
438 .instrument(span)
439 .await?;
440 }
441
442 {
445 let span = info_span!(
446 "db.oauth2_client.delete_by_id.links",
447 upstream_oauth_provider.id = %id,
448 { DB_QUERY_TEXT } = tracing::field::Empty,
449 );
450 sqlx::query!(
451 r#"
452 DELETE FROM upstream_oauth_links
453 WHERE upstream_oauth_provider_id = $1
454 "#,
455 Uuid::from(id),
456 )
457 .record(&span)
458 .execute(&mut *self.conn)
459 .instrument(span)
460 .await?;
461 }
462
463 let res = sqlx::query!(
464 r#"
465 DELETE FROM upstream_oauth_providers
466 WHERE upstream_oauth_provider_id = $1
467 "#,
468 Uuid::from(id),
469 )
470 .traced()
471 .execute(&mut *self.conn)
472 .await?;
473
474 DatabaseError::ensure_affected_rows(&res, 1)
475 }
476
477 #[tracing::instrument(
478 name = "db.upstream_oauth_provider.add",
479 skip_all,
480 fields(
481 db.query.text,
482 upstream_oauth_provider.id = %id,
483 upstream_oauth_provider.issuer = params.issuer,
484 upstream_oauth_provider.client_id = %params.client_id,
485 ),
486 err,
487 )]
488 async fn upsert(
489 &mut self,
490 clock: &dyn Clock,
491 id: Ulid,
492 params: UpstreamOAuthProviderParams,
493 ) -> Result<UpstreamOAuthProvider, Self::Error> {
494 let created_at = clock.now();
495
496 let created_at = sqlx::query_scalar!(
497 r#"
498 INSERT INTO upstream_oauth_providers (
499 upstream_oauth_provider_id,
500 issuer,
501 human_name,
502 brand_name,
503 scope,
504 token_endpoint_auth_method,
505 token_endpoint_signing_alg,
506 id_token_signed_response_alg,
507 fetch_userinfo,
508 userinfo_signed_response_alg,
509 client_id,
510 encrypted_client_secret,
511 claims_imports,
512 authorization_endpoint_override,
513 token_endpoint_override,
514 userinfo_endpoint_override,
515 jwks_uri_override,
516 discovery_mode,
517 pkce_mode,
518 response_mode,
519 additional_parameters,
520 ui_order,
521 created_at
522 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
523 $12, $13, $14, $15, $16, $17, $18, $19, $20,
524 $21, $22, $23)
525 ON CONFLICT (upstream_oauth_provider_id)
526 DO UPDATE
527 SET
528 issuer = EXCLUDED.issuer,
529 human_name = EXCLUDED.human_name,
530 brand_name = EXCLUDED.brand_name,
531 scope = EXCLUDED.scope,
532 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
533 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
534 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
535 fetch_userinfo = EXCLUDED.fetch_userinfo,
536 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
537 disabled_at = NULL,
538 client_id = EXCLUDED.client_id,
539 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
540 claims_imports = EXCLUDED.claims_imports,
541 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
542 token_endpoint_override = EXCLUDED.token_endpoint_override,
543 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
544 jwks_uri_override = EXCLUDED.jwks_uri_override,
545 discovery_mode = EXCLUDED.discovery_mode,
546 pkce_mode = EXCLUDED.pkce_mode,
547 response_mode = EXCLUDED.response_mode,
548 additional_parameters = EXCLUDED.additional_parameters,
549 ui_order = EXCLUDED.ui_order
550 RETURNING created_at
551 "#,
552 Uuid::from(id),
553 params.issuer.as_deref(),
554 params.human_name.as_deref(),
555 params.brand_name.as_deref(),
556 params.scope.to_string(),
557 params.token_endpoint_auth_method.to_string(),
558 params
559 .token_endpoint_signing_alg
560 .as_ref()
561 .map(ToString::to_string),
562 params.id_token_signed_response_alg.to_string(),
563 params.fetch_userinfo,
564 params
565 .userinfo_signed_response_alg
566 .as_ref()
567 .map(ToString::to_string),
568 ¶ms.client_id,
569 params.encrypted_client_secret.as_deref(),
570 Json(¶ms.claims_imports) as _,
571 params
572 .authorization_endpoint_override
573 .as_ref()
574 .map(ToString::to_string),
575 params
576 .token_endpoint_override
577 .as_ref()
578 .map(ToString::to_string),
579 params
580 .userinfo_endpoint_override
581 .as_ref()
582 .map(ToString::to_string),
583 params.jwks_uri_override.as_ref().map(ToString::to_string),
584 params.discovery_mode.as_str(),
585 params.pkce_mode.as_str(),
586 params.response_mode.as_ref().map(ToString::to_string),
587 Json(¶ms.additional_authorization_parameters) as _,
588 params.ui_order,
589 created_at,
590 )
591 .traced()
592 .fetch_one(&mut *self.conn)
593 .await?;
594
595 Ok(UpstreamOAuthProvider {
596 id,
597 issuer: params.issuer,
598 human_name: params.human_name,
599 brand_name: params.brand_name,
600 scope: params.scope,
601 client_id: params.client_id,
602 encrypted_client_secret: params.encrypted_client_secret,
603 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
604 token_endpoint_auth_method: params.token_endpoint_auth_method,
605 id_token_signed_response_alg: params.id_token_signed_response_alg,
606 fetch_userinfo: params.fetch_userinfo,
607 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
608 created_at,
609 disabled_at: None,
610 claims_imports: params.claims_imports,
611 authorization_endpoint_override: params.authorization_endpoint_override,
612 token_endpoint_override: params.token_endpoint_override,
613 userinfo_endpoint_override: params.userinfo_endpoint_override,
614 jwks_uri_override: params.jwks_uri_override,
615 discovery_mode: params.discovery_mode,
616 pkce_mode: params.pkce_mode,
617 response_mode: params.response_mode,
618 additional_authorization_parameters: params.additional_authorization_parameters,
619 })
620 }
621
622 #[tracing::instrument(
623 name = "db.upstream_oauth_provider.disable",
624 skip_all,
625 fields(
626 db.query.text,
627 %upstream_oauth_provider.id,
628 ),
629 err,
630 )]
631 async fn disable(
632 &mut self,
633 clock: &dyn Clock,
634 mut upstream_oauth_provider: UpstreamOAuthProvider,
635 ) -> Result<UpstreamOAuthProvider, Self::Error> {
636 let disabled_at = clock.now();
637 let res = sqlx::query!(
638 r#"
639 UPDATE upstream_oauth_providers
640 SET disabled_at = $2
641 WHERE upstream_oauth_provider_id = $1
642 "#,
643 Uuid::from(upstream_oauth_provider.id),
644 disabled_at,
645 )
646 .traced()
647 .execute(&mut *self.conn)
648 .await?;
649
650 DatabaseError::ensure_affected_rows(&res, 1)?;
651
652 upstream_oauth_provider.disabled_at = Some(disabled_at);
653
654 Ok(upstream_oauth_provider)
655 }
656
657 #[tracing::instrument(
658 name = "db.upstream_oauth_provider.list",
659 skip_all,
660 fields(
661 db.query.text,
662 ),
663 err,
664 )]
665 async fn list(
666 &mut self,
667 filter: UpstreamOAuthProviderFilter<'_>,
668 pagination: Pagination,
669 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
670 let (sql, arguments) = Query::select()
671 .expr_as(
672 Expr::col((
673 UpstreamOAuthProviders::Table,
674 UpstreamOAuthProviders::UpstreamOAuthProviderId,
675 )),
676 ProviderLookupIden::UpstreamOauthProviderId,
677 )
678 .expr_as(
679 Expr::col((
680 UpstreamOAuthProviders::Table,
681 UpstreamOAuthProviders::Issuer,
682 )),
683 ProviderLookupIden::Issuer,
684 )
685 .expr_as(
686 Expr::col((
687 UpstreamOAuthProviders::Table,
688 UpstreamOAuthProviders::HumanName,
689 )),
690 ProviderLookupIden::HumanName,
691 )
692 .expr_as(
693 Expr::col((
694 UpstreamOAuthProviders::Table,
695 UpstreamOAuthProviders::BrandName,
696 )),
697 ProviderLookupIden::BrandName,
698 )
699 .expr_as(
700 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
701 ProviderLookupIden::Scope,
702 )
703 .expr_as(
704 Expr::col((
705 UpstreamOAuthProviders::Table,
706 UpstreamOAuthProviders::ClientId,
707 )),
708 ProviderLookupIden::ClientId,
709 )
710 .expr_as(
711 Expr::col((
712 UpstreamOAuthProviders::Table,
713 UpstreamOAuthProviders::EncryptedClientSecret,
714 )),
715 ProviderLookupIden::EncryptedClientSecret,
716 )
717 .expr_as(
718 Expr::col((
719 UpstreamOAuthProviders::Table,
720 UpstreamOAuthProviders::TokenEndpointSigningAlg,
721 )),
722 ProviderLookupIden::TokenEndpointSigningAlg,
723 )
724 .expr_as(
725 Expr::col((
726 UpstreamOAuthProviders::Table,
727 UpstreamOAuthProviders::TokenEndpointAuthMethod,
728 )),
729 ProviderLookupIden::TokenEndpointAuthMethod,
730 )
731 .expr_as(
732 Expr::col((
733 UpstreamOAuthProviders::Table,
734 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
735 )),
736 ProviderLookupIden::IdTokenSignedResponseAlg,
737 )
738 .expr_as(
739 Expr::col((
740 UpstreamOAuthProviders::Table,
741 UpstreamOAuthProviders::FetchUserinfo,
742 )),
743 ProviderLookupIden::FetchUserinfo,
744 )
745 .expr_as(
746 Expr::col((
747 UpstreamOAuthProviders::Table,
748 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
749 )),
750 ProviderLookupIden::UserinfoSignedResponseAlg,
751 )
752 .expr_as(
753 Expr::col((
754 UpstreamOAuthProviders::Table,
755 UpstreamOAuthProviders::CreatedAt,
756 )),
757 ProviderLookupIden::CreatedAt,
758 )
759 .expr_as(
760 Expr::col((
761 UpstreamOAuthProviders::Table,
762 UpstreamOAuthProviders::DisabledAt,
763 )),
764 ProviderLookupIden::DisabledAt,
765 )
766 .expr_as(
767 Expr::col((
768 UpstreamOAuthProviders::Table,
769 UpstreamOAuthProviders::ClaimsImports,
770 )),
771 ProviderLookupIden::ClaimsImports,
772 )
773 .expr_as(
774 Expr::col((
775 UpstreamOAuthProviders::Table,
776 UpstreamOAuthProviders::JwksUriOverride,
777 )),
778 ProviderLookupIden::JwksUriOverride,
779 )
780 .expr_as(
781 Expr::col((
782 UpstreamOAuthProviders::Table,
783 UpstreamOAuthProviders::TokenEndpointOverride,
784 )),
785 ProviderLookupIden::TokenEndpointOverride,
786 )
787 .expr_as(
788 Expr::col((
789 UpstreamOAuthProviders::Table,
790 UpstreamOAuthProviders::AuthorizationEndpointOverride,
791 )),
792 ProviderLookupIden::AuthorizationEndpointOverride,
793 )
794 .expr_as(
795 Expr::col((
796 UpstreamOAuthProviders::Table,
797 UpstreamOAuthProviders::UserinfoEndpointOverride,
798 )),
799 ProviderLookupIden::UserinfoEndpointOverride,
800 )
801 .expr_as(
802 Expr::col((
803 UpstreamOAuthProviders::Table,
804 UpstreamOAuthProviders::DiscoveryMode,
805 )),
806 ProviderLookupIden::DiscoveryMode,
807 )
808 .expr_as(
809 Expr::col((
810 UpstreamOAuthProviders::Table,
811 UpstreamOAuthProviders::PkceMode,
812 )),
813 ProviderLookupIden::PkceMode,
814 )
815 .expr_as(
816 Expr::col((
817 UpstreamOAuthProviders::Table,
818 UpstreamOAuthProviders::ResponseMode,
819 )),
820 ProviderLookupIden::ResponseMode,
821 )
822 .expr_as(
823 Expr::col((
824 UpstreamOAuthProviders::Table,
825 UpstreamOAuthProviders::AdditionalParameters,
826 )),
827 ProviderLookupIden::AdditionalParameters,
828 )
829 .from(UpstreamOAuthProviders::Table)
830 .apply_filter(filter)
831 .generate_pagination(
832 (
833 UpstreamOAuthProviders::Table,
834 UpstreamOAuthProviders::UpstreamOAuthProviderId,
835 ),
836 pagination,
837 )
838 .build_sqlx(PostgresQueryBuilder);
839
840 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
841 .traced()
842 .fetch_all(&mut *self.conn)
843 .await?;
844
845 let page = pagination
846 .process(edges)
847 .try_map(UpstreamOAuthProvider::try_from)?;
848
849 return Ok(page);
850 }
851
852 #[tracing::instrument(
853 name = "db.upstream_oauth_provider.count",
854 skip_all,
855 fields(
856 db.query.text,
857 ),
858 err,
859 )]
860 async fn count(
861 &mut self,
862 filter: UpstreamOAuthProviderFilter<'_>,
863 ) -> Result<usize, Self::Error> {
864 let (sql, arguments) = Query::select()
865 .expr(
866 Expr::col((
867 UpstreamOAuthProviders::Table,
868 UpstreamOAuthProviders::UpstreamOAuthProviderId,
869 ))
870 .count(),
871 )
872 .from(UpstreamOAuthProviders::Table)
873 .apply_filter(filter)
874 .build_sqlx(PostgresQueryBuilder);
875
876 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
877 .traced()
878 .fetch_one(&mut *self.conn)
879 .await?;
880
881 count
882 .try_into()
883 .map_err(DatabaseError::to_invalid_operation)
884 }
885
886 #[tracing::instrument(
887 name = "db.upstream_oauth_provider.all_enabled",
888 skip_all,
889 fields(
890 db.query.text,
891 ),
892 err,
893 )]
894 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
895 let res = sqlx::query_as!(
896 ProviderLookup,
897 r#"
898 SELECT
899 upstream_oauth_provider_id,
900 issuer,
901 human_name,
902 brand_name,
903 scope,
904 client_id,
905 encrypted_client_secret,
906 token_endpoint_signing_alg,
907 token_endpoint_auth_method,
908 id_token_signed_response_alg,
909 fetch_userinfo,
910 userinfo_signed_response_alg,
911 created_at,
912 disabled_at,
913 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
914 jwks_uri_override,
915 authorization_endpoint_override,
916 token_endpoint_override,
917 userinfo_endpoint_override,
918 discovery_mode,
919 pkce_mode,
920 response_mode,
921 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
922 FROM upstream_oauth_providers
923 WHERE disabled_at IS NULL
924 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
925 "#,
926 )
927 .traced()
928 .fetch_all(&mut *self.conn)
929 .await?;
930
931 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
932 Ok(res?)
933 }
934}