mas_storage_pg/oauth2/
refresh_token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
10use mas_storage::{Clock, oauth2::OAuth2RefreshTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
17
/// An implementation of [`OAuth2RefreshTokenRepository`] for a PostgreSQL
/// connection.
///
/// The repository mutably borrows the connection for its whole lifetime `'c`,
/// and every query it issues is executed on that connection.
pub struct PgOAuth2RefreshTokenRepository<'c> {
    /// The connection on which all the repository queries are executed
    conn: &'c mut PgConnection,
}
23
24impl<'c> PgOAuth2RefreshTokenRepository<'c> {
25    /// Create a new [`PgOAuth2RefreshTokenRepository`] from an active
26    /// PostgreSQL connection
27    pub fn new(conn: &'c mut PgConnection) -> Self {
28        Self { conn }
29    }
30}
31
/// A raw row of the `oauth2_refresh_tokens` table, as returned by the lookup
/// queries of [`PgOAuth2RefreshTokenRepository`].
///
/// Field names must match the column names selected by those queries, since
/// `sqlx::query_as!` maps columns to fields by name.
struct OAuth2RefreshTokenLookup {
    oauth2_refresh_token_id: Uuid,
    refresh_token: String,
    created_at: DateTime<Utc>,
    /// When the token was consumed (replaced by another token), if it was
    consumed_at: Option<DateTime<Utc>>,
    /// When the token was revoked, if it was
    revoked_at: Option<DateTime<Utc>>,
    /// The access token associated with this refresh token, if any
    oauth2_access_token_id: Option<Uuid>,
    oauth2_session_id: Uuid,
    /// The refresh token which replaced this one on consumption, if any
    next_oauth2_refresh_token_id: Option<Uuid>,
}
42
43impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
44    type Error = DatabaseInconsistencyError;
45
46    fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
47        let id = value.oauth2_refresh_token_id.into();
48        let state = match (
49            value.revoked_at,
50            value.consumed_at,
51            value.next_oauth2_refresh_token_id,
52        ) {
53            (None, None, None) => RefreshTokenState::Valid,
54            (Some(revoked_at), None, None) => RefreshTokenState::Revoked { revoked_at },
55            (None, Some(consumed_at), None) => RefreshTokenState::Consumed {
56                consumed_at,
57                next_refresh_token_id: None,
58            },
59            (None, Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
60                consumed_at,
61                next_refresh_token_id: Some(Ulid::from(id)),
62            },
63            _ => {
64                return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
65                    .column("next_oauth2_refresh_token_id")
66                    .row(id));
67            }
68        };
69
70        Ok(RefreshToken {
71            id,
72            state,
73            session_id: value.oauth2_session_id.into(),
74            refresh_token: value.refresh_token,
75            created_at: value.created_at,
76            access_token_id: value.oauth2_access_token_id.map(Ulid::from),
77        })
78    }
79}
80
81#[async_trait]
82impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
83    type Error = DatabaseError;
84
85    #[tracing::instrument(
86        name = "db.oauth2_refresh_token.lookup",
87        skip_all,
88        fields(
89            db.query.text,
90            refresh_token.id = %id,
91        ),
92        err,
93    )]
94    async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
95        let res = sqlx::query_as!(
96            OAuth2RefreshTokenLookup,
97            r#"
98                SELECT oauth2_refresh_token_id
99                     , refresh_token
100                     , created_at
101                     , consumed_at
102                     , revoked_at
103                     , oauth2_access_token_id
104                     , oauth2_session_id
105                     , next_oauth2_refresh_token_id
106                FROM oauth2_refresh_tokens
107
108                WHERE oauth2_refresh_token_id = $1
109            "#,
110            Uuid::from(id),
111        )
112        .traced()
113        .fetch_optional(&mut *self.conn)
114        .await?;
115
116        let Some(res) = res else { return Ok(None) };
117
118        Ok(Some(res.try_into()?))
119    }
120
121    #[tracing::instrument(
122        name = "db.oauth2_refresh_token.find_by_token",
123        skip_all,
124        fields(
125            db.query.text,
126        ),
127        err,
128    )]
129    async fn find_by_token(
130        &mut self,
131        refresh_token: &str,
132    ) -> Result<Option<RefreshToken>, Self::Error> {
133        let res = sqlx::query_as!(
134            OAuth2RefreshTokenLookup,
135            r#"
136                SELECT oauth2_refresh_token_id
137                     , refresh_token
138                     , created_at
139                     , consumed_at
140                     , revoked_at
141                     , oauth2_access_token_id
142                     , oauth2_session_id
143                     , next_oauth2_refresh_token_id
144                FROM oauth2_refresh_tokens
145
146                WHERE refresh_token = $1
147            "#,
148            refresh_token,
149        )
150        .traced()
151        .fetch_optional(&mut *self.conn)
152        .await?;
153
154        let Some(res) = res else { return Ok(None) };
155
156        Ok(Some(res.try_into()?))
157    }
158
159    #[tracing::instrument(
160        name = "db.oauth2_refresh_token.add",
161        skip_all,
162        fields(
163            db.query.text,
164            %session.id,
165            client.id = %session.client_id,
166            refresh_token.id,
167        ),
168        err,
169    )]
170    async fn add(
171        &mut self,
172        rng: &mut (dyn RngCore + Send),
173        clock: &dyn Clock,
174        session: &Session,
175        access_token: &AccessToken,
176        refresh_token: String,
177    ) -> Result<RefreshToken, Self::Error> {
178        let created_at = clock.now();
179        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
180        tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
181
182        sqlx::query!(
183            r#"
184                INSERT INTO oauth2_refresh_tokens
185                    (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
186                     refresh_token, created_at)
187                VALUES
188                    ($1, $2, $3, $4, $5)
189            "#,
190            Uuid::from(id),
191            Uuid::from(session.id),
192            Uuid::from(access_token.id),
193            refresh_token,
194            created_at,
195        )
196        .traced()
197        .execute(&mut *self.conn)
198        .await?;
199
200        Ok(RefreshToken {
201            id,
202            state: RefreshTokenState::default(),
203            session_id: session.id,
204            refresh_token,
205            access_token_id: Some(access_token.id),
206            created_at,
207        })
208    }
209
210    #[tracing::instrument(
211        name = "db.oauth2_refresh_token.consume",
212        skip_all,
213        fields(
214            db.query.text,
215            %refresh_token.id,
216            session.id = %refresh_token.session_id,
217        ),
218        err,
219    )]
220    async fn consume(
221        &mut self,
222        clock: &dyn Clock,
223        refresh_token: RefreshToken,
224        replaced_by: &RefreshToken,
225    ) -> Result<RefreshToken, Self::Error> {
226        let consumed_at = clock.now();
227        let res = sqlx::query!(
228            r#"
229                UPDATE oauth2_refresh_tokens
230                SET consumed_at = $2,
231                    next_oauth2_refresh_token_id = $3
232                WHERE oauth2_refresh_token_id = $1
233            "#,
234            Uuid::from(refresh_token.id),
235            consumed_at,
236            Uuid::from(replaced_by.id),
237        )
238        .traced()
239        .execute(&mut *self.conn)
240        .await?;
241
242        DatabaseError::ensure_affected_rows(&res, 1)?;
243
244        refresh_token
245            .consume(consumed_at, replaced_by)
246            .map_err(DatabaseError::to_invalid_operation)
247    }
248
249    #[tracing::instrument(
250        name = "db.oauth2_refresh_token.revoke",
251        skip_all,
252        fields(
253            db.query.text,
254            %refresh_token.id,
255            session.id = %refresh_token.session_id,
256        ),
257        err,
258    )]
259    async fn revoke(
260        &mut self,
261        clock: &dyn Clock,
262        refresh_token: RefreshToken,
263    ) -> Result<RefreshToken, Self::Error> {
264        let revoked_at = clock.now();
265        let res = sqlx::query!(
266            r#"
267                UPDATE oauth2_refresh_tokens
268                SET revoked_at = $2
269                WHERE oauth2_refresh_token_id = $1
270            "#,
271            Uuid::from(refresh_token.id),
272            revoked_at,
273        )
274        .traced()
275        .execute(&mut *self.conn)
276        .await?;
277
278        DatabaseError::ensure_affected_rows(&res, 1)?;
279
280        refresh_token
281            .revoke(revoked_at)
282            .map_err(DatabaseError::to_invalid_operation)
283    }
284}