mas_storage_pg/oauth2/
client.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    str::FromStr,
10    string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, JwksOrJwksUri, User};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{Clock, oauth2::OAuth2ClientRepository};
18use oauth2_types::{
19    oidc::ApplicationType,
20    requests::GrantType,
21    scope::{Scope, ScopeToken},
22};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sqlx::PgConnection;
26use tracing::{Instrument, info_span};
27use ulid::Ulid;
28use url::Url;
29use uuid::Uuid;
30
31use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
32
33/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
34pub struct PgOAuth2ClientRepository<'c> {
35    conn: &'c mut PgConnection,
36}
37
38impl<'c> PgOAuth2ClientRepository<'c> {
39    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
40    /// connection
41    pub fn new(conn: &'c mut PgConnection) -> Self {
42        Self { conn }
43    }
44}
45
46#[allow(clippy::struct_excessive_bools)]
47#[derive(Debug)]
48struct OAuth2ClientLookup {
49    oauth2_client_id: Uuid,
50    metadata_digest: Option<String>,
51    encrypted_client_secret: Option<String>,
52    application_type: Option<String>,
53    redirect_uris: Vec<String>,
54    grant_type_authorization_code: bool,
55    grant_type_refresh_token: bool,
56    grant_type_client_credentials: bool,
57    grant_type_device_code: bool,
58    client_name: Option<String>,
59    logo_uri: Option<String>,
60    client_uri: Option<String>,
61    policy_uri: Option<String>,
62    tos_uri: Option<String>,
63    jwks_uri: Option<String>,
64    jwks: Option<serde_json::Value>,
65    id_token_signed_response_alg: Option<String>,
66    userinfo_signed_response_alg: Option<String>,
67    token_endpoint_auth_method: Option<String>,
68    token_endpoint_auth_signing_alg: Option<String>,
69    initiate_login_uri: Option<String>,
70}
71
72impl TryInto<Client> for OAuth2ClientLookup {
73    type Error = DatabaseInconsistencyError;
74
75    #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
76    fn try_into(self) -> Result<Client, Self::Error> {
77        let id = Ulid::from(self.oauth2_client_id);
78
79        let redirect_uris: Result<Vec<Url>, _> =
80            self.redirect_uris.iter().map(|s| s.parse()).collect();
81        let redirect_uris = redirect_uris.map_err(|e| {
82            DatabaseInconsistencyError::on("oauth2_clients")
83                .column("redirect_uris")
84                .row(id)
85                .source(e)
86        })?;
87
88        let application_type = self
89            .application_type
90            .map(|s| s.parse())
91            .transpose()
92            .map_err(|e| {
93                DatabaseInconsistencyError::on("oauth2_clients")
94                    .column("application_type")
95                    .row(id)
96                    .source(e)
97            })?;
98
99        let mut grant_types = Vec::new();
100        if self.grant_type_authorization_code {
101            grant_types.push(GrantType::AuthorizationCode);
102        }
103        if self.grant_type_refresh_token {
104            grant_types.push(GrantType::RefreshToken);
105        }
106        if self.grant_type_client_credentials {
107            grant_types.push(GrantType::ClientCredentials);
108        }
109        if self.grant_type_device_code {
110            grant_types.push(GrantType::DeviceCode);
111        }
112
113        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
114            DatabaseInconsistencyError::on("oauth2_clients")
115                .column("logo_uri")
116                .row(id)
117                .source(e)
118        })?;
119
120        let client_uri = self
121            .client_uri
122            .map(|s| s.parse())
123            .transpose()
124            .map_err(|e| {
125                DatabaseInconsistencyError::on("oauth2_clients")
126                    .column("client_uri")
127                    .row(id)
128                    .source(e)
129            })?;
130
131        let policy_uri = self
132            .policy_uri
133            .map(|s| s.parse())
134            .transpose()
135            .map_err(|e| {
136                DatabaseInconsistencyError::on("oauth2_clients")
137                    .column("policy_uri")
138                    .row(id)
139                    .source(e)
140            })?;
141
142        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
143            DatabaseInconsistencyError::on("oauth2_clients")
144                .column("tos_uri")
145                .row(id)
146                .source(e)
147        })?;
148
149        let id_token_signed_response_alg = self
150            .id_token_signed_response_alg
151            .map(|s| s.parse())
152            .transpose()
153            .map_err(|e| {
154                DatabaseInconsistencyError::on("oauth2_clients")
155                    .column("id_token_signed_response_alg")
156                    .row(id)
157                    .source(e)
158            })?;
159
160        let userinfo_signed_response_alg = self
161            .userinfo_signed_response_alg
162            .map(|s| s.parse())
163            .transpose()
164            .map_err(|e| {
165                DatabaseInconsistencyError::on("oauth2_clients")
166                    .column("userinfo_signed_response_alg")
167                    .row(id)
168                    .source(e)
169            })?;
170
171        let token_endpoint_auth_method = self
172            .token_endpoint_auth_method
173            .map(|s| s.parse())
174            .transpose()
175            .map_err(|e| {
176                DatabaseInconsistencyError::on("oauth2_clients")
177                    .column("token_endpoint_auth_method")
178                    .row(id)
179                    .source(e)
180            })?;
181
182        let token_endpoint_auth_signing_alg = self
183            .token_endpoint_auth_signing_alg
184            .map(|s| s.parse())
185            .transpose()
186            .map_err(|e| {
187                DatabaseInconsistencyError::on("oauth2_clients")
188                    .column("token_endpoint_auth_signing_alg")
189                    .row(id)
190                    .source(e)
191            })?;
192
193        let initiate_login_uri = self
194            .initiate_login_uri
195            .map(|s| s.parse())
196            .transpose()
197            .map_err(|e| {
198                DatabaseInconsistencyError::on("oauth2_clients")
199                    .column("initiate_login_uri")
200                    .row(id)
201                    .source(e)
202            })?;
203
204        let jwks = match (self.jwks, self.jwks_uri) {
205            (None, None) => None,
206            (Some(jwks), None) => {
207                let jwks = serde_json::from_value(jwks).map_err(|e| {
208                    DatabaseInconsistencyError::on("oauth2_clients")
209                        .column("jwks")
210                        .row(id)
211                        .source(e)
212                })?;
213                Some(JwksOrJwksUri::Jwks(jwks))
214            }
215            (None, Some(jwks_uri)) => {
216                let jwks_uri = jwks_uri.parse().map_err(|e| {
217                    DatabaseInconsistencyError::on("oauth2_clients")
218                        .column("jwks_uri")
219                        .row(id)
220                        .source(e)
221                })?;
222
223                Some(JwksOrJwksUri::JwksUri(jwks_uri))
224            }
225            _ => {
226                return Err(DatabaseInconsistencyError::on("oauth2_clients")
227                    .column("jwks(_uri)")
228                    .row(id));
229            }
230        };
231
232        Ok(Client {
233            id,
234            client_id: id.to_string(),
235            metadata_digest: self.metadata_digest,
236            encrypted_client_secret: self.encrypted_client_secret,
237            application_type,
238            redirect_uris,
239            grant_types,
240            client_name: self.client_name,
241            logo_uri,
242            client_uri,
243            policy_uri,
244            tos_uri,
245            jwks,
246            id_token_signed_response_alg,
247            userinfo_signed_response_alg,
248            token_endpoint_auth_method,
249            token_endpoint_auth_signing_alg,
250            initiate_login_uri,
251        })
252    }
253}
254
255#[async_trait]
256impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
257    type Error = DatabaseError;
258
259    #[tracing::instrument(
260        name = "db.oauth2_client.lookup",
261        skip_all,
262        fields(
263            db.query.text,
264            oauth2_client.id = %id,
265        ),
266        err,
267    )]
268    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
269        let res = sqlx::query_as!(
270            OAuth2ClientLookup,
271            r#"
272                SELECT oauth2_client_id
273                     , metadata_digest
274                     , encrypted_client_secret
275                     , application_type
276                     , redirect_uris
277                     , grant_type_authorization_code
278                     , grant_type_refresh_token
279                     , grant_type_client_credentials
280                     , grant_type_device_code
281                     , client_name
282                     , logo_uri
283                     , client_uri
284                     , policy_uri
285                     , tos_uri
286                     , jwks_uri
287                     , jwks
288                     , id_token_signed_response_alg
289                     , userinfo_signed_response_alg
290                     , token_endpoint_auth_method
291                     , token_endpoint_auth_signing_alg
292                     , initiate_login_uri
293                FROM oauth2_clients c
294
295                WHERE oauth2_client_id = $1
296            "#,
297            Uuid::from(id),
298        )
299        .traced()
300        .fetch_optional(&mut *self.conn)
301        .await?;
302
303        let Some(res) = res else { return Ok(None) };
304
305        Ok(Some(res.try_into()?))
306    }
307
308    #[tracing::instrument(
309        name = "db.oauth2_client.find_by_metadata_digest",
310        skip_all,
311        fields(
312            db.query.text,
313        ),
314        err,
315    )]
316    async fn find_by_metadata_digest(
317        &mut self,
318        digest: &str,
319    ) -> Result<Option<Client>, Self::Error> {
320        let res = sqlx::query_as!(
321            OAuth2ClientLookup,
322            r#"
323                SELECT oauth2_client_id
324                    , metadata_digest
325                    , encrypted_client_secret
326                    , application_type
327                    , redirect_uris
328                    , grant_type_authorization_code
329                    , grant_type_refresh_token
330                    , grant_type_client_credentials
331                    , grant_type_device_code
332                    , client_name
333                    , logo_uri
334                    , client_uri
335                    , policy_uri
336                    , tos_uri
337                    , jwks_uri
338                    , jwks
339                    , id_token_signed_response_alg
340                    , userinfo_signed_response_alg
341                    , token_endpoint_auth_method
342                    , token_endpoint_auth_signing_alg
343                    , initiate_login_uri
344                FROM oauth2_clients
345                WHERE metadata_digest = $1
346            "#,
347            digest,
348        )
349        .traced()
350        .fetch_optional(&mut *self.conn)
351        .await?;
352
353        let Some(res) = res else { return Ok(None) };
354
355        Ok(Some(res.try_into()?))
356    }
357
358    #[tracing::instrument(
359        name = "db.oauth2_client.load_batch",
360        skip_all,
361        fields(
362            db.query.text,
363        ),
364        err,
365    )]
366    async fn load_batch(
367        &mut self,
368        ids: BTreeSet<Ulid>,
369    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
370        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
371        let res = sqlx::query_as!(
372            OAuth2ClientLookup,
373            r#"
374                SELECT oauth2_client_id
375                     , metadata_digest
376                     , encrypted_client_secret
377                     , application_type
378                     , redirect_uris
379                     , grant_type_authorization_code
380                     , grant_type_refresh_token
381                     , grant_type_client_credentials
382                     , grant_type_device_code
383                     , client_name
384                     , logo_uri
385                     , client_uri
386                     , policy_uri
387                     , tos_uri
388                     , jwks_uri
389                     , jwks
390                     , id_token_signed_response_alg
391                     , userinfo_signed_response_alg
392                     , token_endpoint_auth_method
393                     , token_endpoint_auth_signing_alg
394                     , initiate_login_uri
395                FROM oauth2_clients c
396
397                WHERE oauth2_client_id = ANY($1::uuid[])
398            "#,
399            &ids,
400        )
401        .traced()
402        .fetch_all(&mut *self.conn)
403        .await?;
404
405        res.into_iter()
406            .map(|r| {
407                r.try_into()
408                    .map(|c: Client| (c.id, c))
409                    .map_err(DatabaseError::from)
410            })
411            .collect()
412    }
413
414    #[tracing::instrument(
415        name = "db.oauth2_client.add",
416        skip_all,
417        fields(
418            db.query.text,
419            client.id,
420            client.name = client_name
421        ),
422        err,
423    )]
424    #[allow(clippy::too_many_lines)]
425    async fn add(
426        &mut self,
427        rng: &mut (dyn RngCore + Send),
428        clock: &dyn Clock,
429        redirect_uris: Vec<Url>,
430        metadata_digest: Option<String>,
431        encrypted_client_secret: Option<String>,
432        application_type: Option<ApplicationType>,
433        grant_types: Vec<GrantType>,
434        client_name: Option<String>,
435        logo_uri: Option<Url>,
436        client_uri: Option<Url>,
437        policy_uri: Option<Url>,
438        tos_uri: Option<Url>,
439        jwks_uri: Option<Url>,
440        jwks: Option<PublicJsonWebKeySet>,
441        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
442        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
443        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
444        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
445        initiate_login_uri: Option<Url>,
446    ) -> Result<Client, Self::Error> {
447        let now = clock.now();
448        let id = Ulid::from_datetime_with_source(now.into(), rng);
449        tracing::Span::current().record("client.id", tracing::field::display(id));
450
451        let jwks_json = jwks
452            .as_ref()
453            .map(serde_json::to_value)
454            .transpose()
455            .map_err(DatabaseError::to_invalid_operation)?;
456
457        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
458
459        sqlx::query!(
460            r#"
461                INSERT INTO oauth2_clients
462                    ( oauth2_client_id
463                    , metadata_digest
464                    , encrypted_client_secret
465                    , application_type
466                    , redirect_uris
467                    , grant_type_authorization_code
468                    , grant_type_refresh_token
469                    , grant_type_client_credentials
470                    , grant_type_device_code
471                    , client_name
472                    , logo_uri
473                    , client_uri
474                    , policy_uri
475                    , tos_uri
476                    , jwks_uri
477                    , jwks
478                    , id_token_signed_response_alg
479                    , userinfo_signed_response_alg
480                    , token_endpoint_auth_method
481                    , token_endpoint_auth_signing_alg
482                    , initiate_login_uri
483                    , is_static
484                    )
485                VALUES
486                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
487                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
488            "#,
489            Uuid::from(id),
490            metadata_digest,
491            encrypted_client_secret,
492            application_type.as_ref().map(ToString::to_string),
493            &redirect_uris_array,
494            grant_types.contains(&GrantType::AuthorizationCode),
495            grant_types.contains(&GrantType::RefreshToken),
496            grant_types.contains(&GrantType::ClientCredentials),
497            grant_types.contains(&GrantType::DeviceCode),
498            client_name,
499            logo_uri.as_ref().map(Url::as_str),
500            client_uri.as_ref().map(Url::as_str),
501            policy_uri.as_ref().map(Url::as_str),
502            tos_uri.as_ref().map(Url::as_str),
503            jwks_uri.as_ref().map(Url::as_str),
504            jwks_json,
505            id_token_signed_response_alg
506                .as_ref()
507                .map(ToString::to_string),
508            userinfo_signed_response_alg
509                .as_ref()
510                .map(ToString::to_string),
511            token_endpoint_auth_method.as_ref().map(ToString::to_string),
512            token_endpoint_auth_signing_alg
513                .as_ref()
514                .map(ToString::to_string),
515            initiate_login_uri.as_ref().map(Url::as_str),
516        )
517        .traced()
518        .execute(&mut *self.conn)
519        .await?;
520
521        let jwks = match (jwks, jwks_uri) {
522            (None, None) => None,
523            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
524            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
525            _ => return Err(DatabaseError::invalid_operation()),
526        };
527
528        Ok(Client {
529            id,
530            client_id: id.to_string(),
531            metadata_digest: None,
532            encrypted_client_secret,
533            application_type,
534            redirect_uris,
535            grant_types,
536            client_name,
537            logo_uri,
538            client_uri,
539            policy_uri,
540            tos_uri,
541            jwks,
542            id_token_signed_response_alg,
543            userinfo_signed_response_alg,
544            token_endpoint_auth_method,
545            token_endpoint_auth_signing_alg,
546            initiate_login_uri,
547        })
548    }
549
550    #[tracing::instrument(
551        name = "db.oauth2_client.upsert_static",
552        skip_all,
553        fields(
554            db.query.text,
555            client.id = %client_id,
556        ),
557        err,
558    )]
559    async fn upsert_static(
560        &mut self,
561        client_id: Ulid,
562        client_auth_method: OAuthClientAuthenticationMethod,
563        encrypted_client_secret: Option<String>,
564        jwks: Option<PublicJsonWebKeySet>,
565        jwks_uri: Option<Url>,
566        redirect_uris: Vec<Url>,
567    ) -> Result<Client, Self::Error> {
568        let jwks_json = jwks
569            .as_ref()
570            .map(serde_json::to_value)
571            .transpose()
572            .map_err(DatabaseError::to_invalid_operation)?;
573
574        let client_auth_method = client_auth_method.to_string();
575        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
576
577        sqlx::query!(
578            r#"
579                INSERT INTO oauth2_clients
580                    ( oauth2_client_id
581                    , encrypted_client_secret
582                    , redirect_uris
583                    , grant_type_authorization_code
584                    , grant_type_refresh_token
585                    , grant_type_client_credentials
586                    , grant_type_device_code
587                    , token_endpoint_auth_method
588                    , jwks
589                    , jwks_uri
590                    , is_static
591                    )
592                VALUES
593                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
594                ON CONFLICT (oauth2_client_id)
595                DO
596                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
597                             , redirect_uris = EXCLUDED.redirect_uris
598                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
599                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
600                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
601                             , grant_type_device_code = EXCLUDED.grant_type_device_code
602                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
603                             , jwks = EXCLUDED.jwks
604                             , jwks_uri = EXCLUDED.jwks_uri
605                             , is_static = TRUE
606            "#,
607            Uuid::from(client_id),
608            encrypted_client_secret,
609            &redirect_uris_array,
610            true,
611            true,
612            true,
613            true,
614            client_auth_method,
615            jwks_json,
616            jwks_uri.as_ref().map(Url::as_str),
617        )
618        .traced()
619        .execute(&mut *self.conn)
620        .await?;
621
622        let jwks = match (jwks, jwks_uri) {
623            (None, None) => None,
624            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
625            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
626            _ => return Err(DatabaseError::invalid_operation()),
627        };
628
629        Ok(Client {
630            id: client_id,
631            client_id: client_id.to_string(),
632            metadata_digest: None,
633            encrypted_client_secret,
634            application_type: None,
635            redirect_uris,
636            grant_types: vec![
637                GrantType::AuthorizationCode,
638                GrantType::RefreshToken,
639                GrantType::ClientCredentials,
640            ],
641            client_name: None,
642            logo_uri: None,
643            client_uri: None,
644            policy_uri: None,
645            tos_uri: None,
646            jwks,
647            id_token_signed_response_alg: None,
648            userinfo_signed_response_alg: None,
649            token_endpoint_auth_method: None,
650            token_endpoint_auth_signing_alg: None,
651            initiate_login_uri: None,
652        })
653    }
654
655    #[tracing::instrument(
656        name = "db.oauth2_client.all_static",
657        skip_all,
658        fields(
659            db.query.text,
660        ),
661        err,
662    )]
663    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
664        let res = sqlx::query_as!(
665            OAuth2ClientLookup,
666            r#"
667                SELECT oauth2_client_id
668                     , metadata_digest
669                     , encrypted_client_secret
670                     , application_type
671                     , redirect_uris
672                     , grant_type_authorization_code
673                     , grant_type_refresh_token
674                     , grant_type_client_credentials
675                     , grant_type_device_code
676                     , client_name
677                     , logo_uri
678                     , client_uri
679                     , policy_uri
680                     , tos_uri
681                     , jwks_uri
682                     , jwks
683                     , id_token_signed_response_alg
684                     , userinfo_signed_response_alg
685                     , token_endpoint_auth_method
686                     , token_endpoint_auth_signing_alg
687                     , initiate_login_uri
688                FROM oauth2_clients c
689                WHERE is_static = TRUE
690            "#,
691        )
692        .traced()
693        .fetch_all(&mut *self.conn)
694        .await?;
695
696        res.into_iter()
697            .map(|r| r.try_into().map_err(DatabaseError::from))
698            .collect()
699    }
700
701    #[tracing::instrument(
702        name = "db.oauth2_client.get_consent_for_user",
703        skip_all,
704        fields(
705            db.query.text,
706            %user.id,
707            %client.id,
708        ),
709        err,
710    )]
711    async fn get_consent_for_user(
712        &mut self,
713        client: &Client,
714        user: &User,
715    ) -> Result<Scope, Self::Error> {
716        let scope_tokens: Vec<String> = sqlx::query_scalar!(
717            r#"
718                SELECT scope_token
719                FROM oauth2_consents
720                WHERE user_id = $1 AND oauth2_client_id = $2
721            "#,
722            Uuid::from(user.id),
723            Uuid::from(client.id),
724        )
725        .fetch_all(&mut *self.conn)
726        .await?;
727
728        let scope: Result<Scope, _> = scope_tokens
729            .into_iter()
730            .map(|s| ScopeToken::from_str(&s))
731            .collect();
732
733        let scope = scope.map_err(|e| {
734            DatabaseInconsistencyError::on("oauth2_consents")
735                .column("scope_token")
736                .source(e)
737        })?;
738
739        Ok(scope)
740    }
741
742    #[tracing::instrument(
743        name = "db.oauth2_client.give_consent_for_user",
744        skip_all,
745        fields(
746            db.query.text,
747            %user.id,
748            %client.id,
749            %scope,
750        ),
751        err,
752    )]
753    async fn give_consent_for_user(
754        &mut self,
755        rng: &mut (dyn RngCore + Send),
756        clock: &dyn Clock,
757        client: &Client,
758        user: &User,
759        scope: &Scope,
760    ) -> Result<(), Self::Error> {
761        let now = clock.now();
762        let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
763            .iter()
764            .map(|token| {
765                (
766                    token.to_string(),
767                    Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
768                )
769            })
770            .unzip();
771
772        sqlx::query!(
773            r#"
774                INSERT INTO oauth2_consents
775                    (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
776                SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
777                ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
778            "#,
779            &ids,
780            Uuid::from(user.id),
781            Uuid::from(client.id),
782            &tokens,
783            now,
784        )
785        .traced()
786        .execute(&mut *self.conn)
787        .await?;
788
789        Ok(())
790    }
791
792    #[tracing::instrument(
793        name = "db.oauth2_client.delete_by_id",
794        skip_all,
795        fields(
796            db.query.text,
797            client.id = %id,
798        ),
799        err,
800    )]
801    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
802        // Delete the authorization grants
803        {
804            let span = info_span!(
805                "db.oauth2_client.delete_by_id.authorization_grants",
806                { DB_QUERY_TEXT } = tracing::field::Empty,
807            );
808
809            sqlx::query!(
810                r#"
811                    DELETE FROM oauth2_authorization_grants
812                    WHERE oauth2_client_id = $1
813                "#,
814                Uuid::from(id),
815            )
816            .record(&span)
817            .execute(&mut *self.conn)
818            .instrument(span)
819            .await?;
820        }
821
822        // Delete the user consents
823        {
824            let span = info_span!(
825                "db.oauth2_client.delete_by_id.consents",
826                { DB_QUERY_TEXT } = tracing::field::Empty,
827            );
828
829            sqlx::query!(
830                r#"
831                    DELETE FROM oauth2_consents
832                    WHERE oauth2_client_id = $1
833                "#,
834                Uuid::from(id),
835            )
836            .record(&span)
837            .execute(&mut *self.conn)
838            .instrument(span)
839            .await?;
840        }
841
842        // Delete the OAuth 2 sessions related data
843        {
844            let span = info_span!(
845                "db.oauth2_client.delete_by_id.access_tokens",
846                { DB_QUERY_TEXT } = tracing::field::Empty,
847            );
848
849            sqlx::query!(
850                r#"
851                    DELETE FROM oauth2_access_tokens
852                    WHERE oauth2_session_id IN (
853                        SELECT oauth2_session_id
854                        FROM oauth2_sessions
855                        WHERE oauth2_client_id = $1
856                    )
857                "#,
858                Uuid::from(id),
859            )
860            .record(&span)
861            .execute(&mut *self.conn)
862            .instrument(span)
863            .await?;
864        }
865
866        {
867            let span = info_span!(
868                "db.oauth2_client.delete_by_id.refresh_tokens",
869                { DB_QUERY_TEXT } = tracing::field::Empty,
870            );
871
872            sqlx::query!(
873                r#"
874                    DELETE FROM oauth2_refresh_tokens
875                    WHERE oauth2_session_id IN (
876                        SELECT oauth2_session_id
877                        FROM oauth2_sessions
878                        WHERE oauth2_client_id = $1
879                    )
880                "#,
881                Uuid::from(id),
882            )
883            .record(&span)
884            .execute(&mut *self.conn)
885            .instrument(span)
886            .await?;
887        }
888
889        {
890            let span = info_span!(
891                "db.oauth2_client.delete_by_id.sessions",
892                { DB_QUERY_TEXT } = tracing::field::Empty,
893            );
894
895            sqlx::query!(
896                r#"
897                    DELETE FROM oauth2_sessions
898                    WHERE oauth2_client_id = $1
899                "#,
900                Uuid::from(id),
901            )
902            .record(&span)
903            .execute(&mut *self.conn)
904            .instrument(span)
905            .await?;
906        }
907
908        // Now delete the client itself
909        let res = sqlx::query!(
910            r#"
911                DELETE FROM oauth2_clients
912                WHERE oauth2_client_id = $1
913            "#,
914            Uuid::from(id),
915        )
916        .traced()
917        .execute(&mut *self.conn)
918        .await?;
919
920        DatabaseError::ensure_affected_rows(&res, 1)
921    }
922}