1use async_trait::async_trait;
8use chrono::{DateTime, Duration, Utc};
9use mas_data_model::{AccessToken, AccessTokenState, Session};
10use mas_storage::{Clock, oauth2::OAuth2AccessTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, tracing::ExecuteExt};
17
18pub struct PgOAuth2AccessTokenRepository<'c> {
21 conn: &'c mut PgConnection,
22}
23
24impl<'c> PgOAuth2AccessTokenRepository<'c> {
25 pub fn new(conn: &'c mut PgConnection) -> Self {
28 Self { conn }
29 }
30}
31
32struct OAuth2AccessTokenLookup {
33 oauth2_access_token_id: Uuid,
34 oauth2_session_id: Uuid,
35 access_token: String,
36 created_at: DateTime<Utc>,
37 expires_at: Option<DateTime<Utc>>,
38 revoked_at: Option<DateTime<Utc>>,
39 first_used_at: Option<DateTime<Utc>>,
40}
41
42impl From<OAuth2AccessTokenLookup> for AccessToken {
43 fn from(value: OAuth2AccessTokenLookup) -> Self {
44 let state = match value.revoked_at {
45 None => AccessTokenState::Valid,
46 Some(revoked_at) => AccessTokenState::Revoked { revoked_at },
47 };
48
49 Self {
50 id: value.oauth2_access_token_id.into(),
51 state,
52 session_id: value.oauth2_session_id.into(),
53 access_token: value.access_token,
54 created_at: value.created_at,
55 expires_at: value.expires_at,
56 first_used_at: value.first_used_at,
57 }
58 }
59}
60
61#[async_trait]
62impl OAuth2AccessTokenRepository for PgOAuth2AccessTokenRepository<'_> {
63 type Error = DatabaseError;
64
65 async fn lookup(&mut self, id: Ulid) -> Result<Option<AccessToken>, Self::Error> {
66 let res = sqlx::query_as!(
67 OAuth2AccessTokenLookup,
68 r#"
69 SELECT oauth2_access_token_id
70 , access_token
71 , created_at
72 , expires_at
73 , revoked_at
74 , oauth2_session_id
75 , first_used_at
76
77 FROM oauth2_access_tokens
78
79 WHERE oauth2_access_token_id = $1
80 "#,
81 Uuid::from(id),
82 )
83 .fetch_optional(&mut *self.conn)
84 .await?;
85
86 let Some(res) = res else { return Ok(None) };
87
88 Ok(Some(res.into()))
89 }
90
91 #[tracing::instrument(
92 name = "db.oauth2_access_token.find_by_token",
93 skip_all,
94 fields(
95 db.query.text,
96 ),
97 err,
98 )]
99 async fn find_by_token(
100 &mut self,
101 access_token: &str,
102 ) -> Result<Option<AccessToken>, Self::Error> {
103 let res = sqlx::query_as!(
104 OAuth2AccessTokenLookup,
105 r#"
106 SELECT oauth2_access_token_id
107 , access_token
108 , created_at
109 , expires_at
110 , revoked_at
111 , oauth2_session_id
112 , first_used_at
113
114 FROM oauth2_access_tokens
115
116 WHERE access_token = $1
117 "#,
118 access_token,
119 )
120 .fetch_optional(&mut *self.conn)
121 .await?;
122
123 let Some(res) = res else { return Ok(None) };
124
125 Ok(Some(res.into()))
126 }
127
128 #[tracing::instrument(
129 name = "db.oauth2_access_token.add",
130 skip_all,
131 fields(
132 db.query.text,
133 %session.id,
134 client.id = %session.client_id,
135 access_token.id,
136 ),
137 err,
138 )]
139 async fn add(
140 &mut self,
141 rng: &mut (dyn RngCore + Send),
142 clock: &dyn Clock,
143 session: &Session,
144 access_token: String,
145 expires_after: Option<Duration>,
146 ) -> Result<AccessToken, Self::Error> {
147 let created_at = clock.now();
148 let expires_at = expires_after.map(|d| created_at + d);
149 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
150
151 tracing::Span::current().record("access_token.id", tracing::field::display(id));
152
153 sqlx::query!(
154 r#"
155 INSERT INTO oauth2_access_tokens
156 (oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
157 VALUES
158 ($1, $2, $3, $4, $5)
159 "#,
160 Uuid::from(id),
161 Uuid::from(session.id),
162 &access_token,
163 created_at,
164 expires_at,
165 )
166 .traced()
167 .execute(&mut *self.conn)
168 .await?;
169
170 Ok(AccessToken {
171 id,
172 state: AccessTokenState::default(),
173 access_token,
174 session_id: session.id,
175 created_at,
176 expires_at,
177 first_used_at: None,
178 })
179 }
180
181 #[tracing::instrument(
182 name = "db.oauth2_access_token.revoke",
183 skip_all,
184 fields(
185 db.query.text,
186 session.id = %access_token.session_id,
187 %access_token.id,
188 ),
189 err,
190 )]
191 async fn revoke(
192 &mut self,
193 clock: &dyn Clock,
194 access_token: AccessToken,
195 ) -> Result<AccessToken, Self::Error> {
196 let revoked_at = clock.now();
197 let res = sqlx::query!(
198 r#"
199 UPDATE oauth2_access_tokens
200 SET revoked_at = $2
201 WHERE oauth2_access_token_id = $1
202 "#,
203 Uuid::from(access_token.id),
204 revoked_at,
205 )
206 .traced()
207 .execute(&mut *self.conn)
208 .await?;
209
210 DatabaseError::ensure_affected_rows(&res, 1)?;
211
212 access_token
213 .revoke(revoked_at)
214 .map_err(DatabaseError::to_invalid_operation)
215 }
216
217 #[tracing::instrument(
218 name = "db.oauth2_access_token.mark_used",
219 skip_all,
220 fields(
221 db.query.text,
222 session.id = %access_token.session_id,
223 %access_token.id,
224 ),
225 err,
226 )]
227 async fn mark_used(
228 &mut self,
229 clock: &dyn Clock,
230 mut access_token: AccessToken,
231 ) -> Result<AccessToken, Self::Error> {
232 let now = clock.now();
233 let res = sqlx::query!(
234 r#"
235 UPDATE oauth2_access_tokens
236 SET first_used_at = $2
237 WHERE oauth2_access_token_id = $1
238 "#,
239 Uuid::from(access_token.id),
240 now,
241 )
242 .execute(&mut *self.conn)
243 .await?;
244
245 DatabaseError::ensure_affected_rows(&res, 1)?;
246
247 access_token.first_used_at = Some(now);
248
249 Ok(access_token)
250 }
251
252 #[tracing::instrument(
253 name = "db.oauth2_access_token.cleanup_revoked",
254 skip_all,
255 fields(
256 db.query.text,
257 ),
258 err,
259 )]
260 async fn cleanup_revoked(&mut self, clock: &dyn Clock) -> Result<usize, Self::Error> {
261 let threshold = clock.now() - Duration::microseconds(60 * 60 * 1000 * 1000);
263 let res = sqlx::query!(
264 r#"
265 DELETE FROM oauth2_access_tokens
266 WHERE revoked_at < $1
267 "#,
268 threshold,
269 )
270 .traced()
271 .execute(&mut *self.conn)
272 .await?;
273
274 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
275 }
276}