1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{AccessToken, RefreshToken, RefreshTokenState, Session};
10use mas_storage::{Clock, oauth2::OAuth2RefreshTokenRepository};
11use rand::RngCore;
12use sqlx::PgConnection;
13use ulid::Ulid;
14use uuid::Uuid;
15
16use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
17
18pub struct PgOAuth2RefreshTokenRepository<'c> {
21 conn: &'c mut PgConnection,
22}
23
24impl<'c> PgOAuth2RefreshTokenRepository<'c> {
25 pub fn new(conn: &'c mut PgConnection) -> Self {
28 Self { conn }
29 }
30}
31
32struct OAuth2RefreshTokenLookup {
33 oauth2_refresh_token_id: Uuid,
34 refresh_token: String,
35 created_at: DateTime<Utc>,
36 consumed_at: Option<DateTime<Utc>>,
37 revoked_at: Option<DateTime<Utc>>,
38 oauth2_access_token_id: Option<Uuid>,
39 oauth2_session_id: Uuid,
40 next_oauth2_refresh_token_id: Option<Uuid>,
41}
42
43impl TryFrom<OAuth2RefreshTokenLookup> for RefreshToken {
44 type Error = DatabaseInconsistencyError;
45
46 fn try_from(value: OAuth2RefreshTokenLookup) -> Result<Self, Self::Error> {
47 let id = value.oauth2_refresh_token_id.into();
48 let state = match (
49 value.revoked_at,
50 value.consumed_at,
51 value.next_oauth2_refresh_token_id,
52 ) {
53 (None, None, None) => RefreshTokenState::Valid,
54 (Some(revoked_at), None, None) => RefreshTokenState::Revoked { revoked_at },
55 (None, Some(consumed_at), None) => RefreshTokenState::Consumed {
56 consumed_at,
57 next_refresh_token_id: None,
58 },
59 (None, Some(consumed_at), Some(id)) => RefreshTokenState::Consumed {
60 consumed_at,
61 next_refresh_token_id: Some(Ulid::from(id)),
62 },
63 _ => {
64 return Err(DatabaseInconsistencyError::on("oauth2_refresh_tokens")
65 .column("next_oauth2_refresh_token_id")
66 .row(id));
67 }
68 };
69
70 Ok(RefreshToken {
71 id,
72 state,
73 session_id: value.oauth2_session_id.into(),
74 refresh_token: value.refresh_token,
75 created_at: value.created_at,
76 access_token_id: value.oauth2_access_token_id.map(Ulid::from),
77 })
78 }
79}
80
81#[async_trait]
82impl OAuth2RefreshTokenRepository for PgOAuth2RefreshTokenRepository<'_> {
83 type Error = DatabaseError;
84
85 #[tracing::instrument(
86 name = "db.oauth2_refresh_token.lookup",
87 skip_all,
88 fields(
89 db.query.text,
90 refresh_token.id = %id,
91 ),
92 err,
93 )]
94 async fn lookup(&mut self, id: Ulid) -> Result<Option<RefreshToken>, Self::Error> {
95 let res = sqlx::query_as!(
96 OAuth2RefreshTokenLookup,
97 r#"
98 SELECT oauth2_refresh_token_id
99 , refresh_token
100 , created_at
101 , consumed_at
102 , revoked_at
103 , oauth2_access_token_id
104 , oauth2_session_id
105 , next_oauth2_refresh_token_id
106 FROM oauth2_refresh_tokens
107
108 WHERE oauth2_refresh_token_id = $1
109 "#,
110 Uuid::from(id),
111 )
112 .traced()
113 .fetch_optional(&mut *self.conn)
114 .await?;
115
116 let Some(res) = res else { return Ok(None) };
117
118 Ok(Some(res.try_into()?))
119 }
120
121 #[tracing::instrument(
122 name = "db.oauth2_refresh_token.find_by_token",
123 skip_all,
124 fields(
125 db.query.text,
126 ),
127 err,
128 )]
129 async fn find_by_token(
130 &mut self,
131 refresh_token: &str,
132 ) -> Result<Option<RefreshToken>, Self::Error> {
133 let res = sqlx::query_as!(
134 OAuth2RefreshTokenLookup,
135 r#"
136 SELECT oauth2_refresh_token_id
137 , refresh_token
138 , created_at
139 , consumed_at
140 , revoked_at
141 , oauth2_access_token_id
142 , oauth2_session_id
143 , next_oauth2_refresh_token_id
144 FROM oauth2_refresh_tokens
145
146 WHERE refresh_token = $1
147 "#,
148 refresh_token,
149 )
150 .traced()
151 .fetch_optional(&mut *self.conn)
152 .await?;
153
154 let Some(res) = res else { return Ok(None) };
155
156 Ok(Some(res.try_into()?))
157 }
158
159 #[tracing::instrument(
160 name = "db.oauth2_refresh_token.add",
161 skip_all,
162 fields(
163 db.query.text,
164 %session.id,
165 client.id = %session.client_id,
166 refresh_token.id,
167 ),
168 err,
169 )]
170 async fn add(
171 &mut self,
172 rng: &mut (dyn RngCore + Send),
173 clock: &dyn Clock,
174 session: &Session,
175 access_token: &AccessToken,
176 refresh_token: String,
177 ) -> Result<RefreshToken, Self::Error> {
178 let created_at = clock.now();
179 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
180 tracing::Span::current().record("refresh_token.id", tracing::field::display(id));
181
182 sqlx::query!(
183 r#"
184 INSERT INTO oauth2_refresh_tokens
185 (oauth2_refresh_token_id, oauth2_session_id, oauth2_access_token_id,
186 refresh_token, created_at)
187 VALUES
188 ($1, $2, $3, $4, $5)
189 "#,
190 Uuid::from(id),
191 Uuid::from(session.id),
192 Uuid::from(access_token.id),
193 refresh_token,
194 created_at,
195 )
196 .traced()
197 .execute(&mut *self.conn)
198 .await?;
199
200 Ok(RefreshToken {
201 id,
202 state: RefreshTokenState::default(),
203 session_id: session.id,
204 refresh_token,
205 access_token_id: Some(access_token.id),
206 created_at,
207 })
208 }
209
210 #[tracing::instrument(
211 name = "db.oauth2_refresh_token.consume",
212 skip_all,
213 fields(
214 db.query.text,
215 %refresh_token.id,
216 session.id = %refresh_token.session_id,
217 ),
218 err,
219 )]
220 async fn consume(
221 &mut self,
222 clock: &dyn Clock,
223 refresh_token: RefreshToken,
224 replaced_by: &RefreshToken,
225 ) -> Result<RefreshToken, Self::Error> {
226 let consumed_at = clock.now();
227 let res = sqlx::query!(
228 r#"
229 UPDATE oauth2_refresh_tokens
230 SET consumed_at = $2,
231 next_oauth2_refresh_token_id = $3
232 WHERE oauth2_refresh_token_id = $1
233 "#,
234 Uuid::from(refresh_token.id),
235 consumed_at,
236 Uuid::from(replaced_by.id),
237 )
238 .traced()
239 .execute(&mut *self.conn)
240 .await?;
241
242 DatabaseError::ensure_affected_rows(&res, 1)?;
243
244 refresh_token
245 .consume(consumed_at, replaced_by)
246 .map_err(DatabaseError::to_invalid_operation)
247 }
248
249 #[tracing::instrument(
250 name = "db.oauth2_refresh_token.revoke",
251 skip_all,
252 fields(
253 db.query.text,
254 %refresh_token.id,
255 session.id = %refresh_token.session_id,
256 ),
257 err,
258 )]
259 async fn revoke(
260 &mut self,
261 clock: &dyn Clock,
262 refresh_token: RefreshToken,
263 ) -> Result<RefreshToken, Self::Error> {
264 let revoked_at = clock.now();
265 let res = sqlx::query!(
266 r#"
267 UPDATE oauth2_refresh_tokens
268 SET revoked_at = $2
269 WHERE oauth2_refresh_token_id = $1
270 "#,
271 Uuid::from(refresh_token.id),
272 revoked_at,
273 )
274 .traced()
275 .execute(&mut *self.conn)
276 .await?;
277
278 DatabaseError::ensure_affected_rows(&res, 1)?;
279
280 refresh_token
281 .revoke(revoked_at)
282 .map_err(DatabaseError::to_invalid_operation)
283 }
284}