mas_storage_pg/oauth2/
authorization_grant.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Clock, Pkce, Session,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_storage::oauth2::OAuth2AuthorizationGrantRepository;
15use oauth2_types::{requests::ResponseMode, scope::Scope};
16use rand::RngCore;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
23
24/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
25/// connection
26pub struct PgOAuth2AuthorizationGrantRepository<'c> {
27    conn: &'c mut PgConnection,
28}
29
30impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
31    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
32    /// PostgreSQL connection
33    pub fn new(conn: &'c mut PgConnection) -> Self {
34        Self { conn }
35    }
36}
37
38#[allow(clippy::struct_excessive_bools)]
39struct GrantLookup {
40    oauth2_authorization_grant_id: Uuid,
41    created_at: DateTime<Utc>,
42    cancelled_at: Option<DateTime<Utc>>,
43    fulfilled_at: Option<DateTime<Utc>>,
44    exchanged_at: Option<DateTime<Utc>>,
45    scope: String,
46    state: Option<String>,
47    nonce: Option<String>,
48    redirect_uri: String,
49    response_mode: String,
50    response_type_code: bool,
51    response_type_id_token: bool,
52    authorization_code: Option<String>,
53    code_challenge: Option<String>,
54    code_challenge_method: Option<String>,
55    login_hint: Option<String>,
56    locale: Option<String>,
57    oauth2_client_id: Uuid,
58    oauth2_session_id: Option<Uuid>,
59}
60
61impl TryFrom<GrantLookup> for AuthorizationGrant {
62    type Error = DatabaseInconsistencyError;
63
64    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
65        let id = value.oauth2_authorization_grant_id.into();
66        let scope: Scope = value.scope.parse().map_err(|e| {
67            DatabaseInconsistencyError::on("oauth2_authorization_grants")
68                .column("scope")
69                .row(id)
70                .source(e)
71        })?;
72
73        let stage = match (
74            value.fulfilled_at,
75            value.exchanged_at,
76            value.cancelled_at,
77            value.oauth2_session_id,
78        ) {
79            (None, None, None, None) => AuthorizationGrantStage::Pending,
80            (Some(fulfilled_at), None, None, Some(session_id)) => {
81                AuthorizationGrantStage::Fulfilled {
82                    session_id: session_id.into(),
83                    fulfilled_at,
84                }
85            }
86            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
87                AuthorizationGrantStage::Exchanged {
88                    session_id: session_id.into(),
89                    fulfilled_at,
90                    exchanged_at,
91                }
92            }
93            (None, None, Some(cancelled_at), None) => {
94                AuthorizationGrantStage::Cancelled { cancelled_at }
95            }
96            _ => {
97                return Err(
98                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
99                        .column("stage")
100                        .row(id),
101                );
102            }
103        };
104
105        let pkce = match (value.code_challenge, value.code_challenge_method) {
106            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
107                Some(Pkce {
108                    challenge_method: PkceCodeChallengeMethod::Plain,
109                    challenge,
110                })
111            }
112            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
113                challenge_method: PkceCodeChallengeMethod::S256,
114                challenge,
115            }),
116            (None, None) => None,
117            _ => {
118                return Err(
119                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
120                        .column("code_challenge_method")
121                        .row(id),
122                );
123            }
124        };
125
126        let code: Option<AuthorizationCode> =
127            match (value.response_type_code, value.authorization_code, pkce) {
128                (false, None, None) => None,
129                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
130                _ => {
131                    return Err(
132                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
133                            .column("authorization_code")
134                            .row(id),
135                    );
136                }
137            };
138
139        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
140            DatabaseInconsistencyError::on("oauth2_authorization_grants")
141                .column("redirect_uri")
142                .row(id)
143                .source(e)
144        })?;
145
146        let response_mode = value.response_mode.parse().map_err(|e| {
147            DatabaseInconsistencyError::on("oauth2_authorization_grants")
148                .column("response_mode")
149                .row(id)
150                .source(e)
151        })?;
152
153        Ok(AuthorizationGrant {
154            id,
155            stage,
156            client_id: value.oauth2_client_id.into(),
157            code,
158            scope,
159            state: value.state,
160            nonce: value.nonce,
161            response_mode,
162            redirect_uri,
163            created_at: value.created_at,
164            response_type_id_token: value.response_type_id_token,
165            login_hint: value.login_hint,
166            locale: value.locale,
167        })
168    }
169}
170
171#[async_trait]
172impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
173    type Error = DatabaseError;
174
175    #[tracing::instrument(
176        name = "db.oauth2_authorization_grant.add",
177        skip_all,
178        fields(
179            db.query.text,
180            grant.id,
181            grant.scope = %scope,
182            %client.id,
183        ),
184        err,
185    )]
186    async fn add(
187        &mut self,
188        rng: &mut (dyn RngCore + Send),
189        clock: &dyn Clock,
190        client: &Client,
191        redirect_uri: Url,
192        scope: Scope,
193        code: Option<AuthorizationCode>,
194        state: Option<String>,
195        nonce: Option<String>,
196        response_mode: ResponseMode,
197        response_type_id_token: bool,
198        login_hint: Option<String>,
199        locale: Option<String>,
200    ) -> Result<AuthorizationGrant, Self::Error> {
201        let code_challenge = code
202            .as_ref()
203            .and_then(|c| c.pkce.as_ref())
204            .map(|p| &p.challenge);
205        let code_challenge_method = code
206            .as_ref()
207            .and_then(|c| c.pkce.as_ref())
208            .map(|p| p.challenge_method.to_string());
209        let code_str = code.as_ref().map(|c| &c.code);
210
211        let created_at = clock.now();
212        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
213        tracing::Span::current().record("grant.id", tracing::field::display(id));
214
215        sqlx::query!(
216            r#"
217                INSERT INTO oauth2_authorization_grants (
218                     oauth2_authorization_grant_id,
219                     oauth2_client_id,
220                     redirect_uri,
221                     scope,
222                     state,
223                     nonce,
224                     response_mode,
225                     code_challenge,
226                     code_challenge_method,
227                     response_type_code,
228                     response_type_id_token,
229                     authorization_code,
230                     login_hint,
231                     locale,
232                     created_at
233                )
234                VALUES
235                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
236            "#,
237            Uuid::from(id),
238            Uuid::from(client.id),
239            redirect_uri.to_string(),
240            scope.to_string(),
241            state,
242            nonce,
243            response_mode.to_string(),
244            code_challenge,
245            code_challenge_method,
246            code.is_some(),
247            response_type_id_token,
248            code_str,
249            login_hint,
250            locale,
251            created_at,
252        )
253        .traced()
254        .execute(&mut *self.conn)
255        .await?;
256
257        Ok(AuthorizationGrant {
258            id,
259            stage: AuthorizationGrantStage::Pending,
260            code,
261            redirect_uri,
262            client_id: client.id,
263            scope,
264            state,
265            nonce,
266            response_mode,
267            created_at,
268            response_type_id_token,
269            login_hint,
270            locale,
271        })
272    }
273
274    #[tracing::instrument(
275        name = "db.oauth2_authorization_grant.lookup",
276        skip_all,
277        fields(
278            db.query.text,
279            grant.id = %id,
280        ),
281        err,
282    )]
283    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
284        let res = sqlx::query_as!(
285            GrantLookup,
286            r#"
287                SELECT oauth2_authorization_grant_id
288                     , created_at
289                     , cancelled_at
290                     , fulfilled_at
291                     , exchanged_at
292                     , scope
293                     , state
294                     , redirect_uri
295                     , response_mode
296                     , nonce
297                     , oauth2_client_id
298                     , authorization_code
299                     , response_type_code
300                     , response_type_id_token
301                     , code_challenge
302                     , code_challenge_method
303                     , login_hint
304                     , locale
305                     , oauth2_session_id
306                FROM
307                    oauth2_authorization_grants
308
309                WHERE oauth2_authorization_grant_id = $1
310            "#,
311            Uuid::from(id),
312        )
313        .traced()
314        .fetch_optional(&mut *self.conn)
315        .await?;
316
317        let Some(res) = res else { return Ok(None) };
318
319        Ok(Some(res.try_into()?))
320    }
321
322    #[tracing::instrument(
323        name = "db.oauth2_authorization_grant.find_by_code",
324        skip_all,
325        fields(
326            db.query.text,
327        ),
328        err,
329    )]
330    async fn find_by_code(
331        &mut self,
332        code: &str,
333    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
334        let res = sqlx::query_as!(
335            GrantLookup,
336            r#"
337                SELECT oauth2_authorization_grant_id
338                     , created_at
339                     , cancelled_at
340                     , fulfilled_at
341                     , exchanged_at
342                     , scope
343                     , state
344                     , redirect_uri
345                     , response_mode
346                     , nonce
347                     , oauth2_client_id
348                     , authorization_code
349                     , response_type_code
350                     , response_type_id_token
351                     , code_challenge
352                     , code_challenge_method
353                     , login_hint
354                     , locale
355                     , oauth2_session_id
356                FROM
357                    oauth2_authorization_grants
358
359                WHERE authorization_code = $1
360            "#,
361            code,
362        )
363        .traced()
364        .fetch_optional(&mut *self.conn)
365        .await?;
366
367        let Some(res) = res else { return Ok(None) };
368
369        Ok(Some(res.try_into()?))
370    }
371
372    #[tracing::instrument(
373        name = "db.oauth2_authorization_grant.fulfill",
374        skip_all,
375        fields(
376            db.query.text,
377            %grant.id,
378            client.id = %grant.client_id,
379            %session.id,
380        ),
381        err,
382    )]
383    async fn fulfill(
384        &mut self,
385        clock: &dyn Clock,
386        session: &Session,
387        grant: AuthorizationGrant,
388    ) -> Result<AuthorizationGrant, Self::Error> {
389        let fulfilled_at = clock.now();
390        let res = sqlx::query!(
391            r#"
392                UPDATE oauth2_authorization_grants
393                SET fulfilled_at = $2
394                  , oauth2_session_id = $3
395                WHERE oauth2_authorization_grant_id = $1
396            "#,
397            Uuid::from(grant.id),
398            fulfilled_at,
399            Uuid::from(session.id),
400        )
401        .traced()
402        .execute(&mut *self.conn)
403        .await?;
404
405        DatabaseError::ensure_affected_rows(&res, 1)?;
406
407        // XXX: check affected rows & new methods
408        let grant = grant
409            .fulfill(fulfilled_at, session)
410            .map_err(DatabaseError::to_invalid_operation)?;
411
412        Ok(grant)
413    }
414
415    #[tracing::instrument(
416        name = "db.oauth2_authorization_grant.exchange",
417        skip_all,
418        fields(
419            db.query.text,
420            %grant.id,
421            client.id = %grant.client_id,
422        ),
423        err,
424    )]
425    async fn exchange(
426        &mut self,
427        clock: &dyn Clock,
428        grant: AuthorizationGrant,
429    ) -> Result<AuthorizationGrant, Self::Error> {
430        let exchanged_at = clock.now();
431        let res = sqlx::query!(
432            r#"
433                UPDATE oauth2_authorization_grants
434                SET exchanged_at = $2
435                WHERE oauth2_authorization_grant_id = $1
436            "#,
437            Uuid::from(grant.id),
438            exchanged_at,
439        )
440        .traced()
441        .execute(&mut *self.conn)
442        .await?;
443
444        DatabaseError::ensure_affected_rows(&res, 1)?;
445
446        let grant = grant
447            .exchange(exchanged_at)
448            .map_err(DatabaseError::to_invalid_operation)?;
449
450        Ok(grant)
451    }
452
453    #[tracing::instrument(
454        name = "db.oauth2_authorization_grant.cleanup",
455        skip_all,
456        fields(
457            db.query.text,
458            since = since.map(tracing::field::display),
459            until = %until,
460            limit = limit,
461        ),
462        err,
463    )]
464    async fn cleanup(
465        &mut self,
466        since: Option<Ulid>,
467        until: Ulid,
468        limit: usize,
469    ) -> Result<(usize, Option<Ulid>), Self::Error> {
470        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
471        // deleted rows and do a MAX on the `oauth2_authorization_grant_id`.
472        // Instead, we do the aggregation on the client side, which is a little
473        // less efficient, but good enough.
474        let res = sqlx::query_scalar!(
475            r#"
476                WITH to_delete AS (
477                    SELECT oauth2_authorization_grant_id
478                    FROM oauth2_authorization_grants
479                    WHERE ($1::uuid IS NULL OR oauth2_authorization_grant_id > $1)
480                    AND oauth2_authorization_grant_id <= $2
481                    ORDER BY oauth2_authorization_grant_id
482                    LIMIT $3
483                )
484                DELETE FROM oauth2_authorization_grants
485                USING to_delete
486                WHERE oauth2_authorization_grants.oauth2_authorization_grant_id = to_delete.oauth2_authorization_grant_id
487                RETURNING oauth2_authorization_grants.oauth2_authorization_grant_id
488            "#,
489            since.map(Uuid::from),
490            Uuid::from(until),
491            i64::try_from(limit).unwrap_or(i64::MAX)
492        )
493        .traced()
494        .fetch_all(&mut *self.conn)
495        .await?;
496
497        let count = res.len();
498        let max_id = res.into_iter().max();
499
500        Ok((count, max_id.map(Ulid::from)))
501    }
502}