1use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use mas_data_model::{CompatAccessToken, CompatSession};
10use mas_storage::{Clock, compat::CompatAccessTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, tracing::ExecuteExt};
17
18pub struct PgCompatAccessTokenRepository<'c> {
21 conn: &'c mut PgConnection,
22}
23
24impl<'c> PgCompatAccessTokenRepository<'c> {
25 pub fn new(conn: &'c mut PgConnection) -> Self {
28 Self { conn }
29 }
30}
31
32struct CompatAccessTokenLookup {
33 compat_access_token_id: Uuid,
34 access_token: String,
35 created_at: DateTime<Utc>,
36 expires_at: Option<DateTime<Utc>>,
37 compat_session_id: Uuid,
38}
39
40impl From<CompatAccessTokenLookup> for CompatAccessToken {
41 fn from(value: CompatAccessTokenLookup) -> Self {
42 Self {
43 id: value.compat_access_token_id.into(),
44 session_id: value.compat_session_id.into(),
45 token: value.access_token,
46 created_at: value.created_at,
47 expires_at: value.expires_at,
48 }
49 }
50}
51
52#[async_trait]
53impl CompatAccessTokenRepository for PgCompatAccessTokenRepository<'_> {
54 type Error = DatabaseError;
55
56 #[tracing::instrument(
57 name = "db.compat_access_token.lookup",
58 skip_all,
59 fields(
60 db.query.text,
61 compat_session.id = %id,
62 ),
63 err,
64 )]
65 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatAccessToken>, Self::Error> {
66 let res = sqlx::query_as!(
67 CompatAccessTokenLookup,
68 r#"
69 SELECT compat_access_token_id
70 , access_token
71 , created_at
72 , expires_at
73 , compat_session_id
74
75 FROM compat_access_tokens
76
77 WHERE compat_access_token_id = $1
78 "#,
79 Uuid::from(id),
80 )
81 .traced()
82 .fetch_optional(&mut *self.conn)
83 .await?;
84
85 let Some(res) = res else { return Ok(None) };
86
87 Ok(Some(res.into()))
88 }
89
90 #[tracing::instrument(
91 name = "db.compat_access_token.find_by_token",
92 skip_all,
93 fields(
94 db.query.text,
95 ),
96 err,
97 )]
98 async fn find_by_token(
99 &mut self,
100 access_token: &str,
101 ) -> Result<Option<CompatAccessToken>, Self::Error> {
102 let res = sqlx::query_as!(
103 CompatAccessTokenLookup,
104 r#"
105 SELECT compat_access_token_id
106 , access_token
107 , created_at
108 , expires_at
109 , compat_session_id
110
111 FROM compat_access_tokens
112
113 WHERE access_token = $1
114 "#,
115 access_token,
116 )
117 .traced()
118 .fetch_optional(&mut *self.conn)
119 .await?;
120
121 let Some(res) = res else { return Ok(None) };
122
123 Ok(Some(res.into()))
124 }
125
126 #[tracing::instrument(
127 name = "db.compat_access_token.add",
128 skip_all,
129 fields(
130 db.query.text,
131 compat_access_token.id,
132 %compat_session.id,
133 user.id = %compat_session.user_id,
134 ),
135 err,
136 )]
137 async fn add(
138 &mut self,
139 rng: &mut (dyn RngCore + Send),
140 clock: &dyn Clock,
141 compat_session: &CompatSession,
142 token: String,
143 expires_after: Option<Duration>,
144 ) -> Result<CompatAccessToken, Self::Error> {
145 let created_at = clock.now();
146 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
147 tracing::Span::current().record("compat_access_token.id", tracing::field::display(id));
148
149 let expires_at = expires_after.map(|expires_after| created_at + expires_after);
150
151 sqlx::query!(
152 r#"
153 INSERT INTO compat_access_tokens
154 (compat_access_token_id, compat_session_id, access_token, created_at, expires_at)
155 VALUES ($1, $2, $3, $4, $5)
156 "#,
157 Uuid::from(id),
158 Uuid::from(compat_session.id),
159 token,
160 created_at,
161 expires_at,
162 )
163 .traced()
164 .execute(&mut *self.conn)
165 .await?;
166
167 Ok(CompatAccessToken {
168 id,
169 session_id: compat_session.id,
170 token,
171 created_at,
172 expires_at,
173 })
174 }
175
176 #[tracing::instrument(
177 name = "db.compat_access_token.expire",
178 skip_all,
179 fields(
180 db.query.text,
181 %compat_access_token.id,
182 compat_session.id = %compat_access_token.session_id,
183 ),
184 err,
185 )]
186 async fn expire(
187 &mut self,
188 clock: &dyn Clock,
189 mut compat_access_token: CompatAccessToken,
190 ) -> Result<CompatAccessToken, Self::Error> {
191 let expires_at = clock.now();
192 let res = sqlx::query!(
193 r#"
194 UPDATE compat_access_tokens
195 SET expires_at = $2
196 WHERE compat_access_token_id = $1
197 "#,
198 Uuid::from(compat_access_token.id),
199 expires_at,
200 )
201 .traced()
202 .execute(&mut *self.conn)
203 .await?;
204
205 DatabaseError::ensure_affected_rows(&res, 1)?;
206
207 compat_access_token.expires_at = Some(expires_at);
208 Ok(compat_access_token)
209 }
210}