mas_storage_pg/oauth2/
access_token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use mas_data_model::{AccessToken, AccessTokenState, Session};
10use mas_storage::{Clock, oauth2::OAuth2AccessTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, tracing::ExecuteExt};
17
18/// An implementation of [`OAuth2AccessTokenRepository`] for a PostgreSQL
19/// connection
20pub struct PgOAuth2AccessTokenRepository<'c> {
21    conn: &'c mut PgConnection,
22}
23
24impl<'c> PgOAuth2AccessTokenRepository<'c> {
25    /// Create a new [`PgOAuth2AccessTokenRepository`] from an active PostgreSQL
26    /// connection
27    pub fn new(conn: &'c mut PgConnection) -> Self {
28        Self { conn }
29    }
30}
31
32struct OAuth2AccessTokenLookup {
33    oauth2_access_token_id: Uuid,
34    oauth2_session_id: Uuid,
35    access_token: String,
36    created_at: DateTime<Utc>,
37    expires_at: Option<DateTime<Utc>>,
38    revoked_at: Option<DateTime<Utc>>,
39    first_used_at: Option<DateTime<Utc>>,
40}
41
42impl From<OAuth2AccessTokenLookup> for AccessToken {
43    fn from(value: OAuth2AccessTokenLookup) -> Self {
44        let state = match value.revoked_at {
45            None => AccessTokenState::Valid,
46            Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
47        };
48
49        Self {
50            id: value.oauth2_access_token_id.into(),
51            state,
52            session_id: value.oauth2_session_id.into(),
53            access_token: value.access_token,
54            created_at: value.created_at,
55            expires_at: value.expires_at,
56            first_used_at: value.first_used_at,
57        }
58    }
59}
60
61#[async_trait]
62impl OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'_> {
63    type Error = DatabaseError;
64
65    async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
66        let res = sqlx::query_as!(
67            OAuth2AccessTokenLookup,
68            r#"
69                SELECT oauth2_access_token_id
70                     , access_token
71                     , created_at
72                     , expires_at
73                     , revoked_at
74                     , oauth2_session_id
75                     , first_used_at
76
77                FROM oauth2_access_tokens
78
79                WHERE oauth2_access_token_id = $1
80            "#,
81            Uuid::from(id),
82        )
83        .fetch_optional(&mut *self.conn)
84        .await?;
85
86        let Some(res) = res else { return Ok(None) };
87
88        Ok(Some(res.into()))
89    }
90
91    #[tracing::instrument(
92        name = "db.oauth2_access_token.find_by_token",
93        skip_all,
94        fields(
95            db.query.text,
96        ),
97        err,
98    )]
99    async fn find_by_token(
100        &mut self,
101        access_token: &str,
102    ) -> Result<Option<AccessToken>, Self::Error> {
103        let res = sqlx::query_as!(
104            OAuth2AccessTokenLookup,
105            r#"
106                SELECT oauth2_access_token_id
107                     , access_token
108                     , created_at
109                     , expires_at
110                     , revoked_at
111                     , oauth2_session_id
112                     , first_used_at
113
114                FROM oauth2_access_tokens
115
116                WHERE access_token = $1
117            "#,
118            access_token,
119        )
120        .fetch_optional(&mut *self.conn)
121        .await?;
122
123        let Some(res) = res else { return Ok(None) };
124
125        Ok(Some(res.into()))
126    }
127
128    #[tracing::instrument(
129        name = "db.oauth2_access_token.add",
130        skip_all,
131        fields(
132            db.query.text,
133            %session.id,
134            client.id = %session.client_id,
135            access_token.id,
136        ),
137        err,
138    )]
139    async fn add(
140        &mut self,
141        rng: &mut (dyn RngCore + Send),
142        clock: &dyn Clock,
143        session: &Session,
144        access_token: String,
145        expires_after: Option<Duration>,
146    ) -> Result<AccessToken, Self::Error> {
147        let created_at = clock.now();
148        let expires_at = expires_after.map(|d| created_at + d);
149        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
150
151        tracing::Span::current().record("access_token.id", tracing::field::display(id));
152
153        sqlx::query!(
154            r#"
155                INSERT INTO oauth2_access_tokens
156                    (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
157                VALUES
158                    ($1, $2, $3, $4, $5)
159            "#,
160            Uuid::from(id),
161            Uuid::from(session.id),
162            &access_token,
163            created_at,
164            expires_at,
165        )
166            .traced()
167        .execute(&mut *self.conn)
168        .await?;
169
170        Ok(AccessToken {
171            id,
172            state: AccessTokenState::default(),
173            access_token,
174            session_id: session.id,
175            created_at,
176            expires_at,
177            first_used_at: None,
178        })
179    }
180
181    #[tracing::instrument(
182        name = "db.oauth2_access_token.revoke",
183        skip_all,
184        fields(
185            db.query.text,
186            session.id = %access_token.session_id,
187            %access_token.id,
188        ),
189        err,
190    )]
191    async fn revoke(
192        &mut self,
193        clock: &dyn Clock,
194        access_token: AccessToken,
195    ) -> Result<AccessToken, Self::Error> {
196        let revoked_at = clock.now();
197        let res = sqlx::query!(
198            r#"
199                UPDATE oauth2_access_tokens
200                SET revoked_at = $2
201                WHERE oauth2_access_token_id = $1
202            "#,
203            Uuid::from(access_token.id),
204            revoked_at,
205        )
206        .traced()
207        .execute(&mut *self.conn)
208        .await?;
209
210        DatabaseError::ensure_affected_rows(&res, 1)?;
211
212        access_token
213            .revoke(revoked_at)
214            .map_err(DatabaseError::to_invalid_operation)
215    }
216
217    #[tracing::instrument(
218        name = "db.oauth2_access_token.mark_used",
219        skip_all,
220        fields(
221            db.query.text,
222            session.id = %access_token.session_id,
223            %access_token.id,
224        ),
225        err,
226    )]
227    async fn mark_used(
228        &mut self,
229        clock: &dyn Clock,
230        mut access_token: AccessToken,
231    ) -> Result<AccessToken, Self::Error> {
232        let now = clock.now();
233        let res = sqlx::query!(
234            r#"
235                UPDATE oauth2_access_tokens
236                SET first_used_at = $2
237                WHERE oauth2_access_token_id = $1
238            "#,
239            Uuid::from(access_token.id),
240            now,
241        )
242        .execute(&mut *self.conn)
243        .await?;
244
245        DatabaseError::ensure_affected_rows(&res, 1)?;
246
247        access_token.first_used_at = Some(now);
248
249        Ok(access_token)
250    }
251
252    #[tracing::instrument(
253        name = "db.oauth2_access_token.cleanup_revoked",
254        skip_all,
255        fields(
256            db.query.text,
257        ),
258        err,
259    )]
260    async fn cleanup_revoked(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
261        // Cleanup token that were revoked more than an hour ago
262        let threshold = clock.now() - Duration::microseconds(60 * 60 * 1000 * 1000);
263        let res = sqlx::query!(
264            r#"
265                DELETE FROM oauth2_access_tokens
266                WHERE revoked_at < $1
267            "#,
268            threshold,
269        )
270        .traced()
271        .execute(&mut *self.conn)
272        .await?;
273
274        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
275    }
276}