mas_storage_pg/compat/
access_token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use mas_data_model::{CompatAccessToken, CompatSession};
10use mas_storage::{Clock, compat::CompatAccessTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, tracing::ExecuteExt};
17
18/// An implementation of [`CompatAccessTokenRepository`] for a PostgreSQL
19/// connection
20pub struct PgCompatAccessTokenRepository<'c> {
21    conn: &'c mut PgConnection,
22}
23
24impl<'c> PgCompatAccessTokenRepository<'c> {
25    /// Create a new [`PgCompatAccessTokenRepository`] from an active PostgreSQL
26    /// connection
27    pub fn new(conn: &'c mut PgConnection) -> Self {
28        Self { conn }
29    }
30}
31
/// Row shape returned by the `SELECT` queries on `compat_access_tokens`.
///
/// Field names are expected to match the selected column names, because this
/// struct is the target of `sqlx::query_as!`. Rename a field only together
/// with the corresponding SQL.
struct CompatAccessTokenLookup {
    /// Primary key; converted back into a ULID by the `From` impl below.
    compat_access_token_id: Uuid,
    access_token: String,
    created_at: DateTime<Utc>,
    /// `None` when the token was stored without an expiry.
    expires_at: Option<DateTime<Utc>>,
    /// ID of the owning compat session; converted back into a ULID.
    compat_session_id: Uuid,
}
39
40impl From<CompatAccessTokenLookup> for CompatAccessToken {
41    fn from(value: CompatAccessTokenLookup) -> Self {
42        Self {
43            id: value.compat_access_token_id.into(),
44            session_id: value.compat_session_id.into(),
45            token: value.access_token,
46            created_at: value.created_at,
47            expires_at: value.expires_at,
48        }
49    }
50}
51
52#[async_trait]
53impl CompatAccessTokenRepository for PgCompatAccessTokenRepository<'_> {
54    type Error = DatabaseError;
55
56    #[tracing::instrument(
57        name = "db.compat_access_token.lookup",
58        skip_all,
59        fields(
60            db.query.text,
61            compat_session.id = %id,
62        ),
63        err,
64    )]
65    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
66        let res = sqlx::query_as!(
67            CompatAccessTokenLookup,
68            r#"
69                SELECT compat_access_token_id
70                     , access_token
71                     , created_at
72                     , expires_at
73                     , compat_session_id
74
75                FROM compat_access_tokens
76
77                WHERE compat_access_token_id = $1
78            "#,
79            Uuid::from(id),
80        )
81        .traced()
82        .fetch_optional(&mut *self.conn)
83        .await?;
84
85        let Some(res) = res else { return Ok(None) };
86
87        Ok(Some(res.into()))
88    }
89
90    #[tracing::instrument(
91        name = "db.compat_access_token.find_by_token",
92        skip_all,
93        fields(
94            db.query.text,
95        ),
96        err,
97    )]
98    async fn find_by_token(
99        &mut self,
100        access_token: &str,
101    ) -> Result<Option<CompatAccessToken>, Self::Error> {
102        let res = sqlx::query_as!(
103            CompatAccessTokenLookup,
104            r#"
105                SELECT compat_access_token_id
106                     , access_token
107                     , created_at
108                     , expires_at
109                     , compat_session_id
110
111                FROM compat_access_tokens
112
113                WHERE access_token = $1
114            "#,
115            access_token,
116        )
117        .traced()
118        .fetch_optional(&mut *self.conn)
119        .await?;
120
121        let Some(res) = res else { return Ok(None) };
122
123        Ok(Some(res.into()))
124    }
125
126    #[tracing::instrument(
127        name = "db.compat_access_token.add",
128        skip_all,
129        fields(
130            db.query.text,
131            compat_access_token.id,
132            %compat_session.id,
133            user.id = %compat_session.user_id,
134        ),
135        err,
136    )]
137    async fn add(
138        &mut self,
139        rng: &mut (dyn RngCore + Send),
140        clock: &dyn Clock,
141        compat_session: &CompatSession,
142        token: String,
143        expires_after: Option<Duration>,
144    ) -> Result<CompatAccessToken, Self::Error> {
145        let created_at = clock.now();
146        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
147        tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
148
149        let expires_at = expires_after.map(|expires_after| created_at + expires_after);
150
151        sqlx::query!(
152            r#"
153                INSERT INTO compat_access_tokens
154                    (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
155                VALUES ($1, $2, $3, $4, $5)
156            "#,
157            Uuid::from(id),
158            Uuid::from(compat_session.id),
159            token,
160            created_at,
161            expires_at,
162        )
163        .traced()
164        .execute(&mut *self.conn)
165        .await?;
166
167        Ok(CompatAccessToken {
168            id,
169            session_id: compat_session.id,
170            token,
171            created_at,
172            expires_at,
173        })
174    }
175
176    #[tracing::instrument(
177        name = "db.compat_access_token.expire",
178        skip_all,
179        fields(
180            db.query.text,
181            %compat_access_token.id,
182            compat_session.id = %compat_access_token.session_id,
183        ),
184        err,
185    )]
186    async fn expire(
187        &mut self,
188        clock: &dyn Clock,
189        mut compat_access_token: CompatAccessToken,
190    ) -> Result<CompatAccessToken, Self::Error> {
191        let expires_at = clock.now();
192        let res = sqlx::query!(
193            r#"
194                UPDATE compat_access_tokens
195                SET expires_at = $2
196                WHERE compat_access_token_id = $1
197            "#,
198            Uuid::from(compat_access_token.id),
199            expires_at,
200        )
201        .traced()
202        .execute(&mut *self.conn)
203        .await?;
204
205        DatabaseError::ensure_affected_rows(&res, 1)?;
206
207        compat_access_token.expires_at = Some(expires_at);
208        Ok(compat_access_token)
209    }
210}