mas_storage_pg/compat/refresh_token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
11};
12use mas_storage::{Clock, compat::CompatRefreshTokenRepository};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, tracing::ExecuteExt};
19
20/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL
21/// connection
22pub struct PgCompatRefreshTokenRepository<'c> {
23    conn: &'c mut PgConnection,
24}
25
26impl<'c> PgCompatRefreshTokenRepository<'c> {
27    /// Create a new [`PgCompatRefreshTokenRepository`] from an active
28    /// PostgreSQL connection
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34struct CompatRefreshTokenLookup {
35    compat_refresh_token_id: Uuid,
36    refresh_token: String,
37    created_at: DateTime<Utc>,
38    consumed_at: Option<DateTime<Utc>>,
39    compat_access_token_id: Uuid,
40    compat_session_id: Uuid,
41}
42
43impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
44    fn from(value: CompatRefreshTokenLookup) -> Self {
45        let state = match value.consumed_at {
46            Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
47            None => CompatRefreshTokenState::Valid,
48        };
49
50        Self {
51            id: value.compat_refresh_token_id.into(),
52            state,
53            session_id: value.compat_session_id.into(),
54            token: value.refresh_token,
55            created_at: value.created_at,
56            access_token_id: value.compat_access_token_id.into(),
57        }
58    }
59}
60
61#[async_trait]
62impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
63    type Error = DatabaseError;
64
65    #[tracing::instrument(
66        name = "db.compat_refresh_token.lookup",
67        skip_all,
68        fields(
69            db.query.text,
70            compat_refresh_token.id = %id,
71        ),
72        err,
73    )]
74    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
75        let res = sqlx::query_as!(
76            CompatRefreshTokenLookup,
77            r#"
78                SELECT compat_refresh_token_id
79                     , refresh_token
80                     , created_at
81                     , consumed_at
82                     , compat_session_id
83                     , compat_access_token_id
84
85                FROM compat_refresh_tokens
86
87                WHERE compat_refresh_token_id = $1
88            "#,
89            Uuid::from(id),
90        )
91        .traced()
92        .fetch_optional(&mut *self.conn)
93        .await?;
94
95        let Some(res) = res else { return Ok(None) };
96
97        Ok(Some(res.into()))
98    }
99
100    #[tracing::instrument(
101        name = "db.compat_refresh_token.find_by_token",
102        skip_all,
103        fields(
104            db.query.text,
105        ),
106        err,
107    )]
108    async fn find_by_token(
109        &mut self,
110        refresh_token: &str,
111    ) -> Result<Option<CompatRefreshToken>, Self::Error> {
112        let res = sqlx::query_as!(
113            CompatRefreshTokenLookup,
114            r#"
115                SELECT compat_refresh_token_id
116                     , refresh_token
117                     , created_at
118                     , consumed_at
119                     , compat_session_id
120                     , compat_access_token_id
121
122                FROM compat_refresh_tokens
123
124                WHERE refresh_token = $1
125            "#,
126            refresh_token,
127        )
128        .traced()
129        .fetch_optional(&mut *self.conn)
130        .await?;
131
132        let Some(res) = res else { return Ok(None) };
133
134        Ok(Some(res.into()))
135    }
136
137    #[tracing::instrument(
138        name = "db.compat_refresh_token.add",
139        skip_all,
140        fields(
141            db.query.text,
142            compat_refresh_token.id,
143            %compat_session.id,
144            user.id = %compat_session.user_id,
145        ),
146        err,
147    )]
148    async fn add(
149        &mut self,
150        rng: &mut (dyn RngCore + Send),
151        clock: &dyn Clock,
152        compat_session: &CompatSession,
153        compat_access_token: &CompatAccessToken,
154        token: String,
155    ) -> Result<CompatRefreshToken, Self::Error> {
156        let created_at = clock.now();
157        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
158        tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
159
160        sqlx::query!(
161            r#"
162                INSERT INTO compat_refresh_tokens
163                    (compat_refresh_token_id, compat_session_id,
164                     compat_access_token_id, refresh_token, created_at)
165                VALUES ($1, $2, $3, $4, $5)
166            "#,
167            Uuid::from(id),
168            Uuid::from(compat_session.id),
169            Uuid::from(compat_access_token.id),
170            token,
171            created_at,
172        )
173        .traced()
174        .execute(&mut *self.conn)
175        .await?;
176
177        Ok(CompatRefreshToken {
178            id,
179            state: CompatRefreshTokenState::default(),
180            session_id: compat_session.id,
181            access_token_id: compat_access_token.id,
182            token,
183            created_at,
184        })
185    }
186
187    #[tracing::instrument(
188        name = "db.compat_refresh_token.consume",
189        skip_all,
190        fields(
191            db.query.text,
192            %compat_refresh_token.id,
193            compat_session.id = %compat_refresh_token.session_id,
194        ),
195        err,
196    )]
197    async fn consume(
198        &mut self,
199        clock: &dyn Clock,
200        compat_refresh_token: CompatRefreshToken,
201    ) -> Result<CompatRefreshToken, Self::Error> {
202        let consumed_at = clock.now();
203        let res = sqlx::query!(
204            r#"
205                UPDATE compat_refresh_tokens
206                SET consumed_at = $2
207                WHERE compat_session_id = $1
208                  AND consumed_at IS NULL
209            "#,
210            Uuid::from(compat_refresh_token.session_id),
211            consumed_at,
212        )
213        .traced()
214        .execute(&mut *self.conn)
215        .await?;
216
217        // This can affect multiple rows in case we've imported refresh tokens
218        // from Synapse. What we care about is that it at least affected one,
219        // which is what we're checking here
220        if res.rows_affected() == 0 {
221            return Err(DatabaseError::RowsAffected {
222                expected: 1,
223                actual: 0,
224            });
225        }
226
227        let compat_refresh_token = compat_refresh_token
228            .consume(consumed_at)
229            .map_err(DatabaseError::to_invalid_operation)?;
230
231        Ok(compat_refresh_token)
232    }
233}