mas_storage_pg/oauth2/authorization_grant.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::num::NonZeroU32;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
13};
14use mas_iana::oauth::PkceCodeChallengeMethod;
15use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
16use oauth2_types::{requests::ResponseMode, scope::Scope};
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use url::Url;
21use uuid::Uuid;
22
23use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
24
25/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
26/// connection
27pub struct PgOAuth2AuthorizationGrantRepository<'c> {
28    conn: &'c mut PgConnection,
29}
30
31impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
32    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
33    /// PostgreSQL connection
34    pub fn new(conn: &'c mut PgConnection) -> Self {
35        Self { conn }
36    }
37}
38
39#[allow(clippy::struct_excessive_bools)]
40struct GrantLookup {
41    oauth2_authorization_grant_id: Uuid,
42    created_at: DateTime<Utc>,
43    cancelled_at: Option<DateTime<Utc>>,
44    fulfilled_at: Option<DateTime<Utc>>,
45    exchanged_at: Option<DateTime<Utc>>,
46    scope: String,
47    state: Option<String>,
48    nonce: Option<String>,
49    redirect_uri: String,
50    response_mode: String,
51    max_age: Option<i32>,
52    response_type_code: bool,
53    response_type_id_token: bool,
54    authorization_code: Option<String>,
55    code_challenge: Option<String>,
56    code_challenge_method: Option<String>,
57    requires_consent: bool,
58    login_hint: Option<String>,
59    oauth2_client_id: Uuid,
60    oauth2_session_id: Option<Uuid>,
61}
62
63impl TryFrom<GrantLookup> for AuthorizationGrant {
64    type Error = DatabaseInconsistencyError;
65
66    #[allow(clippy::too_many_lines)]
67    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
68        let id = value.oauth2_authorization_grant_id.into();
69        let scope: Scope = value.scope.parse().map_err(|e| {
70            DatabaseInconsistencyError::on("oauth2_authorization_grants")
71                .column("scope")
72                .row(id)
73                .source(e)
74        })?;
75
76        let stage = match (
77            value.fulfilled_at,
78            value.exchanged_at,
79            value.cancelled_at,
80            value.oauth2_session_id,
81        ) {
82            (None, None, None, None) => AuthorizationGrantStage::Pending,
83            (Some(fulfilled_at), None, None, Some(session_id)) => {
84                AuthorizationGrantStage::Fulfilled {
85                    session_id: session_id.into(),
86                    fulfilled_at,
87                }
88            }
89            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
90                AuthorizationGrantStage::Exchanged {
91                    session_id: session_id.into(),
92                    fulfilled_at,
93                    exchanged_at,
94                }
95            }
96            (None, None, Some(cancelled_at), None) => {
97                AuthorizationGrantStage::Cancelled { cancelled_at }
98            }
99            _ => {
100                return Err(
101                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
102                        .column("stage")
103                        .row(id),
104                );
105            }
106        };
107
108        let pkce = match (value.code_challenge, value.code_challenge_method) {
109            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
110                Some(Pkce {
111                    challenge_method: PkceCodeChallengeMethod::Plain,
112                    challenge,
113                })
114            }
115            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
116                challenge_method: PkceCodeChallengeMethod::S256,
117                challenge,
118            }),
119            (None, None) => None,
120            _ => {
121                return Err(
122                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
123                        .column("code_challenge_method")
124                        .row(id),
125                );
126            }
127        };
128
129        let code: Option<AuthorizationCode> =
130            match (value.response_type_code, value.authorization_code, pkce) {
131                (false, None, None) => None,
132                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
133                _ => {
134                    return Err(
135                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
136                            .column("authorization_code")
137                            .row(id),
138                    );
139                }
140            };
141
142        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
143            DatabaseInconsistencyError::on("oauth2_authorization_grants")
144                .column("redirect_uri")
145                .row(id)
146                .source(e)
147        })?;
148
149        let response_mode = value.response_mode.parse().map_err(|e| {
150            DatabaseInconsistencyError::on("oauth2_authorization_grants")
151                .column("response_mode")
152                .row(id)
153                .source(e)
154        })?;
155
156        let max_age = value
157            .max_age
158            .map(u32::try_from)
159            .transpose()
160            .map_err(|e| {
161                DatabaseInconsistencyError::on("oauth2_authorization_grants")
162                    .column("max_age")
163                    .row(id)
164                    .source(e)
165            })?
166            .map(NonZeroU32::try_from)
167            .transpose()
168            .map_err(|e| {
169                DatabaseInconsistencyError::on("oauth2_authorization_grants")
170                    .column("max_age")
171                    .row(id)
172                    .source(e)
173            })?;
174
175        Ok(AuthorizationGrant {
176            id,
177            stage,
178            client_id: value.oauth2_client_id.into(),
179            code,
180            scope,
181            state: value.state,
182            nonce: value.nonce,
183            max_age,
184            response_mode,
185            redirect_uri,
186            created_at: value.created_at,
187            response_type_id_token: value.response_type_id_token,
188            requires_consent: value.requires_consent,
189            login_hint: value.login_hint,
190        })
191    }
192}
193
194#[async_trait]
195impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
196    type Error = DatabaseError;
197
198    #[tracing::instrument(
199        name = "db.oauth2_authorization_grant.add",
200        skip_all,
201        fields(
202            db.query.text,
203            grant.id,
204            grant.scope = %scope,
205            %client.id,
206        ),
207        err,
208    )]
209    async fn add(
210        &mut self,
211        rng: &mut (dyn RngCore + Send),
212        clock: &dyn Clock,
213        client: &Client,
214        redirect_uri: Url,
215        scope: Scope,
216        code: Option<AuthorizationCode>,
217        state: Option<String>,
218        nonce: Option<String>,
219        max_age: Option<NonZeroU32>,
220        response_mode: ResponseMode,
221        response_type_id_token: bool,
222        requires_consent: bool,
223        login_hint: Option<String>,
224    ) -> Result<AuthorizationGrant, Self::Error> {
225        let code_challenge = code
226            .as_ref()
227            .and_then(|c| c.pkce.as_ref())
228            .map(|p| &p.challenge);
229        let code_challenge_method = code
230            .as_ref()
231            .and_then(|c| c.pkce.as_ref())
232            .map(|p| p.challenge_method.to_string());
233        // TODO: this conversion is a bit ugly
234        let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
235        let code_str = code.as_ref().map(|c| &c.code);
236
237        let created_at = clock.now();
238        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
239        tracing::Span::current().record("grant.id", tracing::field::display(id));
240
241        sqlx::query!(
242            r#"
243                INSERT INTO oauth2_authorization_grants (
244                     oauth2_authorization_grant_id,
245                     oauth2_client_id,
246                     redirect_uri,
247                     scope,
248                     state,
249                     nonce,
250                     max_age,
251                     response_mode,
252                     code_challenge,
253                     code_challenge_method,
254                     response_type_code,
255                     response_type_id_token,
256                     authorization_code,
257                     requires_consent,
258                     login_hint,
259                     created_at
260                )
261                VALUES
262                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
263            "#,
264            Uuid::from(id),
265            Uuid::from(client.id),
266            redirect_uri.to_string(),
267            scope.to_string(),
268            state,
269            nonce,
270            max_age_i32,
271            response_mode.to_string(),
272            code_challenge,
273            code_challenge_method,
274            code.is_some(),
275            response_type_id_token,
276            code_str,
277            requires_consent,
278            login_hint,
279            created_at,
280        )
281        .traced()
282        .execute(&mut *self.conn)
283        .await?;
284
285        Ok(AuthorizationGrant {
286            id,
287            stage: AuthorizationGrantStage::Pending,
288            code,
289            redirect_uri,
290            client_id: client.id,
291            scope,
292            state,
293            nonce,
294            max_age,
295            response_mode,
296            created_at,
297            response_type_id_token,
298            requires_consent,
299            login_hint,
300        })
301    }
302
303    #[tracing::instrument(
304        name = "db.oauth2_authorization_grant.lookup",
305        skip_all,
306        fields(
307            db.query.text,
308            grant.id = %id,
309        ),
310        err,
311    )]
312    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
313        let res = sqlx::query_as!(
314            GrantLookup,
315            r#"
316                SELECT oauth2_authorization_grant_id
317                     , created_at
318                     , cancelled_at
319                     , fulfilled_at
320                     , exchanged_at
321                     , scope
322                     , state
323                     , redirect_uri
324                     , response_mode
325                     , nonce
326                     , max_age
327                     , oauth2_client_id
328                     , authorization_code
329                     , response_type_code
330                     , response_type_id_token
331                     , code_challenge
332                     , code_challenge_method
333                     , requires_consent
334                     , login_hint
335                     , oauth2_session_id
336                FROM
337                    oauth2_authorization_grants
338
339                WHERE oauth2_authorization_grant_id = $1
340            "#,
341            Uuid::from(id),
342        )
343        .traced()
344        .fetch_optional(&mut *self.conn)
345        .await?;
346
347        let Some(res) = res else { return Ok(None) };
348
349        Ok(Some(res.try_into()?))
350    }
351
352    #[tracing::instrument(
353        name = "db.oauth2_authorization_grant.find_by_code",
354        skip_all,
355        fields(
356            db.query.text,
357        ),
358        err,
359    )]
360    async fn find_by_code(
361        &mut self,
362        code: &str,
363    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
364        let res = sqlx::query_as!(
365            GrantLookup,
366            r#"
367                SELECT oauth2_authorization_grant_id
368                     , created_at
369                     , cancelled_at
370                     , fulfilled_at
371                     , exchanged_at
372                     , scope
373                     , state
374                     , redirect_uri
375                     , response_mode
376                     , nonce
377                     , max_age
378                     , oauth2_client_id
379                     , authorization_code
380                     , response_type_code
381                     , response_type_id_token
382                     , code_challenge
383                     , code_challenge_method
384                     , requires_consent
385                     , login_hint
386                     , oauth2_session_id
387                FROM
388                    oauth2_authorization_grants
389
390                WHERE authorization_code = $1
391            "#,
392            code,
393        )
394        .traced()
395        .fetch_optional(&mut *self.conn)
396        .await?;
397
398        let Some(res) = res else { return Ok(None) };
399
400        Ok(Some(res.try_into()?))
401    }
402
403    #[tracing::instrument(
404        name = "db.oauth2_authorization_grant.fulfill",
405        skip_all,
406        fields(
407            db.query.text,
408            %grant.id,
409            client.id = %grant.client_id,
410            %session.id,
411        ),
412        err,
413    )]
414    async fn fulfill(
415        &mut self,
416        clock: &dyn Clock,
417        session: &Session,
418        grant: AuthorizationGrant,
419    ) -> Result<AuthorizationGrant, Self::Error> {
420        let fulfilled_at = clock.now();
421        let res = sqlx::query!(
422            r#"
423                UPDATE oauth2_authorization_grants
424                SET fulfilled_at = $2
425                  , oauth2_session_id = $3
426                WHERE oauth2_authorization_grant_id = $1
427            "#,
428            Uuid::from(grant.id),
429            fulfilled_at,
430            Uuid::from(session.id),
431        )
432        .traced()
433        .execute(&mut *self.conn)
434        .await?;
435
436        DatabaseError::ensure_affected_rows(&res, 1)?;
437
438        // XXX: check affected rows & new methods
439        let grant = grant
440            .fulfill(fulfilled_at, session)
441            .map_err(DatabaseError::to_invalid_operation)?;
442
443        Ok(grant)
444    }
445
446    #[tracing::instrument(
447        name = "db.oauth2_authorization_grant.exchange",
448        skip_all,
449        fields(
450            db.query.text,
451            %grant.id,
452            client.id = %grant.client_id,
453        ),
454        err,
455    )]
456    async fn exchange(
457        &mut self,
458        clock: &dyn Clock,
459        grant: AuthorizationGrant,
460    ) -> Result<AuthorizationGrant, Self::Error> {
461        let exchanged_at = clock.now();
462        let res = sqlx::query!(
463            r#"
464                UPDATE oauth2_authorization_grants
465                SET exchanged_at = $2
466                WHERE oauth2_authorization_grant_id = $1
467            "#,
468            Uuid::from(grant.id),
469            exchanged_at,
470        )
471        .traced()
472        .execute(&mut *self.conn)
473        .await?;
474
475        DatabaseError::ensure_affected_rows(&res, 1)?;
476
477        let grant = grant
478            .exchange(exchanged_at)
479            .map_err(DatabaseError::to_invalid_operation)?;
480
481        Ok(grant)
482    }
483
484    #[tracing::instrument(
485        name = "db.oauth2_authorization_grant.give_consent",
486        skip_all,
487        fields(
488            db.query.text,
489            %grant.id,
490            client.id = %grant.client_id,
491        ),
492        err,
493    )]
494    async fn give_consent(
495        &mut self,
496        mut grant: AuthorizationGrant,
497    ) -> Result<AuthorizationGrant, Self::Error> {
498        sqlx::query!(
499            r#"
500                UPDATE oauth2_authorization_grants AS og
501                SET
502                    requires_consent = 'f'
503                WHERE
504                    og.oauth2_authorization_grant_id = $1
505            "#,
506            Uuid::from(grant.id),
507        )
508        .traced()
509        .execute(&mut *self.conn)
510        .await?;
511
512        grant.requires_consent = false;
513
514        Ok(grant)
515    }
516}