mas_storage_pg/oauth2/
device_code_grant.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session};
13use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
14use oauth2_types::scope::Scope;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
21
22/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL
23/// connection
24pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
25    conn: &'c mut PgConnection,
26}
27
28impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
29    /// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active
30    /// PostgreSQL connection
31    pub fn new(conn: &'c mut PgConnection) -> Self {
32        Self { conn }
33    }
34}
35
/// Raw row of the `oauth2_device_code_grant` table, as returned by the
/// `SELECT` queries in this file.
///
/// Field names must match the selected column names because the struct is
/// filled by `sqlx::query_as!`. It is turned into a [`DeviceCodeGrant`] via
/// its `TryFrom` implementation, which validates the lifecycle columns.
struct OAuth2DeviceGrantLookup {
    // Primary key; becomes the grant's ULID.
    oauth2_device_code_grant_id: Uuid,
    // Client that requested the grant; becomes the grant's `client_id`.
    oauth2_client_id: Uuid,
    // Scope as stored by `add` (via `Scope::to_string`); parsed back into a
    // `Scope` on conversion.
    scope: String,
    device_code: String,
    user_code: String,
    created_at: DateTime<Utc>,
    expires_at: DateTime<Utc>,
    // Lifecycle timestamps, set respectively by `fulfill`, `reject` and
    // `exchange`; all `None` while the grant is pending.
    fulfilled_at: Option<DateTime<Utc>>,
    rejected_at: Option<DateTime<Utc>>,
    exchanged_at: Option<DateTime<Utc>>,
    // Browser session that fulfilled or rejected the grant.
    user_session_id: Option<Uuid>,
    // OAuth 2.0 session recorded when the grant was exchanged.
    oauth2_session_id: Option<Uuid>,
    // Optional request metadata recorded when the grant was created by `add`.
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
}
52
53impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
54    type Error = DatabaseInconsistencyError;
55
56    fn try_from(
57        OAuth2DeviceGrantLookup {
58            oauth2_device_code_grant_id,
59            oauth2_client_id,
60            scope,
61            device_code,
62            user_code,
63            created_at,
64            expires_at,
65            fulfilled_at,
66            rejected_at,
67            exchanged_at,
68            user_session_id,
69            oauth2_session_id,
70            ip_address,
71            user_agent,
72        }: OAuth2DeviceGrantLookup,
73    ) -> Result<Self, Self::Error> {
74        let id = Ulid::from(oauth2_device_code_grant_id);
75        let client_id = Ulid::from(oauth2_client_id);
76
77        let scope: Scope = scope.parse().map_err(|e| {
78            DatabaseInconsistencyError::on("oauth2_authorization_grants")
79                .column("scope")
80                .row(id)
81                .source(e)
82        })?;
83
84        let state = match (
85            fulfilled_at,
86            rejected_at,
87            exchanged_at,
88            user_session_id,
89            oauth2_session_id,
90        ) {
91            (None, None, None, None, None) => DeviceCodeGrantState::Pending,
92
93            (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
94                DeviceCodeGrantState::Fulfilled {
95                    browser_session_id: Ulid::from(user_session_id),
96                    fulfilled_at,
97                }
98            }
99
100            (None, Some(rejected_at), None, Some(user_session_id), None) => {
101                DeviceCodeGrantState::Rejected {
102                    browser_session_id: Ulid::from(user_session_id),
103                    rejected_at,
104                }
105            }
106
107            (
108                Some(fulfilled_at),
109                None,
110                Some(exchanged_at),
111                Some(user_session_id),
112                Some(oauth2_session_id),
113            ) => DeviceCodeGrantState::Exchanged {
114                browser_session_id: Ulid::from(user_session_id),
115                session_id: Ulid::from(oauth2_session_id),
116                fulfilled_at,
117                exchanged_at,
118            },
119
120            _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
121        };
122
123        Ok(DeviceCodeGrant {
124            id,
125            state,
126            client_id,
127            scope,
128            user_code,
129            device_code,
130            created_at,
131            expires_at,
132            ip_address,
133            user_agent,
134        })
135    }
136}
137
138#[async_trait]
139impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
140    type Error = DatabaseError;
141
142    #[tracing::instrument(
143        name = "db.oauth2_device_code_grant.add",
144        skip_all,
145        fields(
146            db.query.text,
147            oauth2_device_code.id,
148            oauth2_device_code.scope = %params.scope,
149            oauth2_client.id = %params.client.id,
150        ),
151        err,
152    )]
153    async fn add(
154        &mut self,
155        rng: &mut (dyn RngCore + Send),
156        clock: &dyn Clock,
157        params: OAuth2DeviceCodeGrantParams<'_>,
158    ) -> Result<DeviceCodeGrant, Self::Error> {
159        let now = clock.now();
160        let id = Ulid::from_datetime_with_source(now.into(), rng);
161        tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
162
163        let created_at = now;
164        let expires_at = now + params.expires_in;
165        let client_id = params.client.id;
166
167        sqlx::query!(
168            r#"
169                INSERT INTO "oauth2_device_code_grant"
170                    ( oauth2_device_code_grant_id
171                    , oauth2_client_id
172                    , scope
173                    , device_code
174                    , user_code
175                    , created_at
176                    , expires_at
177                    , ip_address
178                    , user_agent
179                    )
180                VALUES
181                    ($1, $2, $3, $4, $5, $6, $7, $8, $9)
182            "#,
183            Uuid::from(id),
184            Uuid::from(client_id),
185            params.scope.to_string(),
186            &params.device_code,
187            &params.user_code,
188            created_at,
189            expires_at,
190            params.ip_address as Option<IpAddr>,
191            params.user_agent.as_deref(),
192        )
193        .traced()
194        .execute(&mut *self.conn)
195        .await?;
196
197        Ok(DeviceCodeGrant {
198            id,
199            state: DeviceCodeGrantState::Pending,
200            client_id,
201            scope: params.scope,
202            user_code: params.user_code,
203            device_code: params.device_code,
204            created_at,
205            expires_at,
206            ip_address: params.ip_address,
207            user_agent: params.user_agent,
208        })
209    }
210
211    #[tracing::instrument(
212        name = "db.oauth2_device_code_grant.lookup",
213        skip_all,
214        fields(
215            db.query.text,
216            oauth2_device_code.id = %id,
217        ),
218        err,
219    )]
220    async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
221        let res = sqlx::query_as!(
222            OAuth2DeviceGrantLookup,
223            r#"
224                SELECT oauth2_device_code_grant_id
225                     , oauth2_client_id
226                     , scope
227                     , device_code
228                     , user_code
229                     , created_at
230                     , expires_at
231                     , fulfilled_at
232                     , rejected_at
233                     , exchanged_at
234                     , user_session_id
235                     , oauth2_session_id
236                     , ip_address as "ip_address: IpAddr"
237                     , user_agent
238                FROM
239                    oauth2_device_code_grant
240
241                WHERE oauth2_device_code_grant_id = $1
242            "#,
243            Uuid::from(id),
244        )
245        .traced()
246        .fetch_optional(&mut *self.conn)
247        .await?;
248
249        let Some(res) = res else { return Ok(None) };
250
251        Ok(Some(res.try_into()?))
252    }
253
254    #[tracing::instrument(
255        name = "db.oauth2_device_code_grant.find_by_user_code",
256        skip_all,
257        fields(
258            db.query.text,
259            oauth2_device_code.user_code = %user_code,
260        ),
261        err,
262    )]
263    async fn find_by_user_code(
264        &mut self,
265        user_code: &str,
266    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
267        let res = sqlx::query_as!(
268            OAuth2DeviceGrantLookup,
269            r#"
270                SELECT oauth2_device_code_grant_id
271                     , oauth2_client_id
272                     , scope
273                     , device_code
274                     , user_code
275                     , created_at
276                     , expires_at
277                     , fulfilled_at
278                     , rejected_at
279                     , exchanged_at
280                     , user_session_id
281                     , oauth2_session_id
282                     , ip_address as "ip_address: IpAddr"
283                     , user_agent
284                FROM
285                    oauth2_device_code_grant
286
287                WHERE user_code = $1
288            "#,
289            user_code,
290        )
291        .traced()
292        .fetch_optional(&mut *self.conn)
293        .await?;
294
295        let Some(res) = res else { return Ok(None) };
296
297        Ok(Some(res.try_into()?))
298    }
299
300    #[tracing::instrument(
301        name = "db.oauth2_device_code_grant.find_by_device_code",
302        skip_all,
303        fields(
304            db.query.text,
305            oauth2_device_code.device_code = %device_code,
306        ),
307        err,
308    )]
309    async fn find_by_device_code(
310        &mut self,
311        device_code: &str,
312    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
313        let res = sqlx::query_as!(
314            OAuth2DeviceGrantLookup,
315            r#"
316                SELECT oauth2_device_code_grant_id
317                     , oauth2_client_id
318                     , scope
319                     , device_code
320                     , user_code
321                     , created_at
322                     , expires_at
323                     , fulfilled_at
324                     , rejected_at
325                     , exchanged_at
326                     , user_session_id
327                     , oauth2_session_id
328                     , ip_address as "ip_address: IpAddr"
329                     , user_agent
330                FROM
331                    oauth2_device_code_grant
332
333                WHERE device_code = $1
334            "#,
335            device_code,
336        )
337        .traced()
338        .fetch_optional(&mut *self.conn)
339        .await?;
340
341        let Some(res) = res else { return Ok(None) };
342
343        Ok(Some(res.try_into()?))
344    }
345
346    #[tracing::instrument(
347        name = "db.oauth2_device_code_grant.fulfill",
348        skip_all,
349        fields(
350            db.query.text,
351            oauth2_device_code.id = %device_code_grant.id,
352            oauth2_client.id = %device_code_grant.client_id,
353            browser_session.id = %browser_session.id,
354            user.id = %browser_session.user.id,
355        ),
356        err,
357    )]
358    async fn fulfill(
359        &mut self,
360        clock: &dyn Clock,
361        device_code_grant: DeviceCodeGrant,
362        browser_session: &BrowserSession,
363    ) -> Result<DeviceCodeGrant, Self::Error> {
364        let fulfilled_at = clock.now();
365        let device_code_grant = device_code_grant
366            .fulfill(browser_session, fulfilled_at)
367            .map_err(DatabaseError::to_invalid_operation)?;
368
369        let res = sqlx::query!(
370            r#"
371                UPDATE oauth2_device_code_grant
372                SET fulfilled_at = $1
373                  , user_session_id = $2
374                WHERE oauth2_device_code_grant_id = $3
375            "#,
376            fulfilled_at,
377            Uuid::from(browser_session.id),
378            Uuid::from(device_code_grant.id),
379        )
380        .traced()
381        .execute(&mut *self.conn)
382        .await?;
383
384        DatabaseError::ensure_affected_rows(&res, 1)?;
385
386        Ok(device_code_grant)
387    }
388
389    #[tracing::instrument(
390        name = "db.oauth2_device_code_grant.reject",
391        skip_all,
392        fields(
393            db.query.text,
394            oauth2_device_code.id = %device_code_grant.id,
395            oauth2_client.id = %device_code_grant.client_id,
396            browser_session.id = %browser_session.id,
397            user.id = %browser_session.user.id,
398        ),
399        err,
400    )]
401    async fn reject(
402        &mut self,
403        clock: &dyn Clock,
404        device_code_grant: DeviceCodeGrant,
405        browser_session: &BrowserSession,
406    ) -> Result<DeviceCodeGrant, Self::Error> {
407        let fulfilled_at = clock.now();
408        let device_code_grant = device_code_grant
409            .reject(browser_session, fulfilled_at)
410            .map_err(DatabaseError::to_invalid_operation)?;
411
412        let res = sqlx::query!(
413            r#"
414                UPDATE oauth2_device_code_grant
415                SET rejected_at = $1
416                  , user_session_id = $2
417                WHERE oauth2_device_code_grant_id = $3
418            "#,
419            fulfilled_at,
420            Uuid::from(browser_session.id),
421            Uuid::from(device_code_grant.id),
422        )
423        .traced()
424        .execute(&mut *self.conn)
425        .await?;
426
427        DatabaseError::ensure_affected_rows(&res, 1)?;
428
429        Ok(device_code_grant)
430    }
431
432    #[tracing::instrument(
433        name = "db.oauth2_device_code_grant.exchange",
434        skip_all,
435        fields(
436            db.query.text,
437            oauth2_device_code.id = %device_code_grant.id,
438            oauth2_client.id = %device_code_grant.client_id,
439            oauth2_session.id = %session.id,
440        ),
441        err,
442    )]
443    async fn exchange(
444        &mut self,
445        clock: &dyn Clock,
446        device_code_grant: DeviceCodeGrant,
447        session: &Session,
448    ) -> Result<DeviceCodeGrant, Self::Error> {
449        let exchanged_at = clock.now();
450        let device_code_grant = device_code_grant
451            .exchange(session, exchanged_at)
452            .map_err(DatabaseError::to_invalid_operation)?;
453
454        let res = sqlx::query!(
455            r#"
456                UPDATE oauth2_device_code_grant
457                SET exchanged_at = $1
458                  , oauth2_session_id = $2
459                WHERE oauth2_device_code_grant_id = $3
460            "#,
461            exchanged_at,
462            Uuid::from(session.id),
463            Uuid::from(device_code_grant.id),
464        )
465        .traced()
466        .execute(&mut *self.conn)
467        .await?;
468
469        DatabaseError::ensure_affected_rows(&res, 1)?;
470
471        Ok(device_code_grant)
472    }
473
474    #[tracing::instrument(
475        name = "db.oauth2_device_code_grant.cleanup",
476        skip_all,
477        fields(
478            db.query.text,
479            since = since.map(tracing::field::display),
480            until = %until,
481            limit = limit,
482        ),
483        err,
484    )]
485    async fn cleanup(
486        &mut self,
487        since: Option<Ulid>,
488        until: Ulid,
489        limit: usize,
490    ) -> Result<(usize, Option<Ulid>), Self::Error> {
491        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
492        // deleted rows and do a MAX on the `oauth2_device_code_grant_id`.
493        // Instead, we do the aggregation on the client side, which is a little
494        // less efficient, but good enough.
495        let res = sqlx::query_scalar!(
496            r#"
497                WITH to_delete AS (
498                    SELECT oauth2_device_code_grant_id
499                    FROM oauth2_device_code_grant
500                    WHERE ($1::uuid IS NULL OR oauth2_device_code_grant_id > $1)
501                    AND oauth2_device_code_grant_id <= $2
502                    ORDER BY oauth2_device_code_grant_id
503                    LIMIT $3
504                )
505                DELETE FROM oauth2_device_code_grant
506                USING to_delete
507                WHERE oauth2_device_code_grant.oauth2_device_code_grant_id = to_delete.oauth2_device_code_grant_id
508                RETURNING oauth2_device_code_grant.oauth2_device_code_grant_id
509            "#,
510            since.map(Uuid::from),
511            Uuid::from(until),
512            i64::try_from(limit).unwrap_or(i64::MAX)
513        )
514        .traced()
515        .fetch_all(&mut *self.conn)
516        .await?;
517
518        let count = res.len();
519        let max_id = res.into_iter().max();
520
521        Ok((count, max_id.map(Ulid::from)))
522    }
523}