1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session};
13use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
14use oauth2_types::scope::Scope;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
21
22pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
25 conn: &'c mut PgConnection,
26}
27
28impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
29 pub fn new(conn: &'c mut PgConnection) -> Self {
32 Self { conn }
33 }
34}
35
36struct OAuth2DeviceGrantLookup {
37 oauth2_device_code_grant_id: Uuid,
38 oauth2_client_id: Uuid,
39 scope: String,
40 device_code: String,
41 user_code: String,
42 created_at: DateTime<Utc>,
43 expires_at: DateTime<Utc>,
44 fulfilled_at: Option<DateTime<Utc>>,
45 rejected_at: Option<DateTime<Utc>>,
46 exchanged_at: Option<DateTime<Utc>>,
47 user_session_id: Option<Uuid>,
48 oauth2_session_id: Option<Uuid>,
49 ip_address: Option<IpAddr>,
50 user_agent: Option<String>,
51}
52
53impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
54 type Error = DatabaseInconsistencyError;
55
56 fn try_from(
57 OAuth2DeviceGrantLookup {
58 oauth2_device_code_grant_id,
59 oauth2_client_id,
60 scope,
61 device_code,
62 user_code,
63 created_at,
64 expires_at,
65 fulfilled_at,
66 rejected_at,
67 exchanged_at,
68 user_session_id,
69 oauth2_session_id,
70 ip_address,
71 user_agent,
72 }: OAuth2DeviceGrantLookup,
73 ) -> Result<Self, Self::Error> {
74 let id = Ulid::from(oauth2_device_code_grant_id);
75 let client_id = Ulid::from(oauth2_client_id);
76
77 let scope: Scope = scope.parse().map_err(|e| {
78 DatabaseInconsistencyError::on("oauth2_authorization_grants")
79 .column("scope")
80 .row(id)
81 .source(e)
82 })?;
83
84 let state = match (
85 fulfilled_at,
86 rejected_at,
87 exchanged_at,
88 user_session_id,
89 oauth2_session_id,
90 ) {
91 (None, None, None, None, None) => DeviceCodeGrantState::Pending,
92
93 (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
94 DeviceCodeGrantState::Fulfilled {
95 browser_session_id: Ulid::from(user_session_id),
96 fulfilled_at,
97 }
98 }
99
100 (None, Some(rejected_at), None, Some(user_session_id), None) => {
101 DeviceCodeGrantState::Rejected {
102 browser_session_id: Ulid::from(user_session_id),
103 rejected_at,
104 }
105 }
106
107 (
108 Some(fulfilled_at),
109 None,
110 Some(exchanged_at),
111 Some(user_session_id),
112 Some(oauth2_session_id),
113 ) => DeviceCodeGrantState::Exchanged {
114 browser_session_id: Ulid::from(user_session_id),
115 session_id: Ulid::from(oauth2_session_id),
116 fulfilled_at,
117 exchanged_at,
118 },
119
120 _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
121 };
122
123 Ok(DeviceCodeGrant {
124 id,
125 state,
126 client_id,
127 scope,
128 user_code,
129 device_code,
130 created_at,
131 expires_at,
132 ip_address,
133 user_agent,
134 })
135 }
136}
137
138#[async_trait]
139impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
140 type Error = DatabaseError;
141
142 #[tracing::instrument(
143 name = "db.oauth2_device_code_grant.add",
144 skip_all,
145 fields(
146 db.query.text,
147 oauth2_device_code.id,
148 oauth2_device_code.scope = %params.scope,
149 oauth2_client.id = %params.client.id,
150 ),
151 err,
152 )]
153 async fn add(
154 &mut self,
155 rng: &mut (dyn RngCore + Send),
156 clock: &dyn Clock,
157 params: OAuth2DeviceCodeGrantParams<'_>,
158 ) -> Result<DeviceCodeGrant, Self::Error> {
159 let now = clock.now();
160 let id = Ulid::from_datetime_with_source(now.into(), rng);
161 tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
162
163 let created_at = now;
164 let expires_at = now + params.expires_in;
165 let client_id = params.client.id;
166
167 sqlx::query!(
168 r#"
169 INSERT INTO "oauth2_device_code_grant"
170 ( oauth2_device_code_grant_id
171 , oauth2_client_id
172 , scope
173 , device_code
174 , user_code
175 , created_at
176 , expires_at
177 , ip_address
178 , user_agent
179 )
180 VALUES
181 ($1, $2, $3, $4, $5, $6, $7, $8, $9)
182 "#,
183 Uuid::from(id),
184 Uuid::from(client_id),
185 params.scope.to_string(),
186 ¶ms.device_code,
187 ¶ms.user_code,
188 created_at,
189 expires_at,
190 params.ip_address as Option<IpAddr>,
191 params.user_agent.as_deref(),
192 )
193 .traced()
194 .execute(&mut *self.conn)
195 .await?;
196
197 Ok(DeviceCodeGrant {
198 id,
199 state: DeviceCodeGrantState::Pending,
200 client_id,
201 scope: params.scope,
202 user_code: params.user_code,
203 device_code: params.device_code,
204 created_at,
205 expires_at,
206 ip_address: params.ip_address,
207 user_agent: params.user_agent,
208 })
209 }
210
211 #[tracing::instrument(
212 name = "db.oauth2_device_code_grant.lookup",
213 skip_all,
214 fields(
215 db.query.text,
216 oauth2_device_code.id = %id,
217 ),
218 err,
219 )]
220 async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
221 let res = sqlx::query_as!(
222 OAuth2DeviceGrantLookup,
223 r#"
224 SELECT oauth2_device_code_grant_id
225 , oauth2_client_id
226 , scope
227 , device_code
228 , user_code
229 , created_at
230 , expires_at
231 , fulfilled_at
232 , rejected_at
233 , exchanged_at
234 , user_session_id
235 , oauth2_session_id
236 , ip_address as "ip_address: IpAddr"
237 , user_agent
238 FROM
239 oauth2_device_code_grant
240
241 WHERE oauth2_device_code_grant_id = $1
242 "#,
243 Uuid::from(id),
244 )
245 .traced()
246 .fetch_optional(&mut *self.conn)
247 .await?;
248
249 let Some(res) = res else { return Ok(None) };
250
251 Ok(Some(res.try_into()?))
252 }
253
254 #[tracing::instrument(
255 name = "db.oauth2_device_code_grant.find_by_user_code",
256 skip_all,
257 fields(
258 db.query.text,
259 oauth2_device_code.user_code = %user_code,
260 ),
261 err,
262 )]
263 async fn find_by_user_code(
264 &mut self,
265 user_code: &str,
266 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
267 let res = sqlx::query_as!(
268 OAuth2DeviceGrantLookup,
269 r#"
270 SELECT oauth2_device_code_grant_id
271 , oauth2_client_id
272 , scope
273 , device_code
274 , user_code
275 , created_at
276 , expires_at
277 , fulfilled_at
278 , rejected_at
279 , exchanged_at
280 , user_session_id
281 , oauth2_session_id
282 , ip_address as "ip_address: IpAddr"
283 , user_agent
284 FROM
285 oauth2_device_code_grant
286
287 WHERE user_code = $1
288 "#,
289 user_code,
290 )
291 .traced()
292 .fetch_optional(&mut *self.conn)
293 .await?;
294
295 let Some(res) = res else { return Ok(None) };
296
297 Ok(Some(res.try_into()?))
298 }
299
300 #[tracing::instrument(
301 name = "db.oauth2_device_code_grant.find_by_device_code",
302 skip_all,
303 fields(
304 db.query.text,
305 oauth2_device_code.device_code = %device_code,
306 ),
307 err,
308 )]
309 async fn find_by_device_code(
310 &mut self,
311 device_code: &str,
312 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
313 let res = sqlx::query_as!(
314 OAuth2DeviceGrantLookup,
315 r#"
316 SELECT oauth2_device_code_grant_id
317 , oauth2_client_id
318 , scope
319 , device_code
320 , user_code
321 , created_at
322 , expires_at
323 , fulfilled_at
324 , rejected_at
325 , exchanged_at
326 , user_session_id
327 , oauth2_session_id
328 , ip_address as "ip_address: IpAddr"
329 , user_agent
330 FROM
331 oauth2_device_code_grant
332
333 WHERE device_code = $1
334 "#,
335 device_code,
336 )
337 .traced()
338 .fetch_optional(&mut *self.conn)
339 .await?;
340
341 let Some(res) = res else { return Ok(None) };
342
343 Ok(Some(res.try_into()?))
344 }
345
346 #[tracing::instrument(
347 name = "db.oauth2_device_code_grant.fulfill",
348 skip_all,
349 fields(
350 db.query.text,
351 oauth2_device_code.id = %device_code_grant.id,
352 oauth2_client.id = %device_code_grant.client_id,
353 browser_session.id = %browser_session.id,
354 user.id = %browser_session.user.id,
355 ),
356 err,
357 )]
358 async fn fulfill(
359 &mut self,
360 clock: &dyn Clock,
361 device_code_grant: DeviceCodeGrant,
362 browser_session: &BrowserSession,
363 ) -> Result<DeviceCodeGrant, Self::Error> {
364 let fulfilled_at = clock.now();
365 let device_code_grant = device_code_grant
366 .fulfill(browser_session, fulfilled_at)
367 .map_err(DatabaseError::to_invalid_operation)?;
368
369 let res = sqlx::query!(
370 r#"
371 UPDATE oauth2_device_code_grant
372 SET fulfilled_at = $1
373 , user_session_id = $2
374 WHERE oauth2_device_code_grant_id = $3
375 "#,
376 fulfilled_at,
377 Uuid::from(browser_session.id),
378 Uuid::from(device_code_grant.id),
379 )
380 .traced()
381 .execute(&mut *self.conn)
382 .await?;
383
384 DatabaseError::ensure_affected_rows(&res, 1)?;
385
386 Ok(device_code_grant)
387 }
388
389 #[tracing::instrument(
390 name = "db.oauth2_device_code_grant.reject",
391 skip_all,
392 fields(
393 db.query.text,
394 oauth2_device_code.id = %device_code_grant.id,
395 oauth2_client.id = %device_code_grant.client_id,
396 browser_session.id = %browser_session.id,
397 user.id = %browser_session.user.id,
398 ),
399 err,
400 )]
401 async fn reject(
402 &mut self,
403 clock: &dyn Clock,
404 device_code_grant: DeviceCodeGrant,
405 browser_session: &BrowserSession,
406 ) -> Result<DeviceCodeGrant, Self::Error> {
407 let fulfilled_at = clock.now();
408 let device_code_grant = device_code_grant
409 .reject(browser_session, fulfilled_at)
410 .map_err(DatabaseError::to_invalid_operation)?;
411
412 let res = sqlx::query!(
413 r#"
414 UPDATE oauth2_device_code_grant
415 SET rejected_at = $1
416 , user_session_id = $2
417 WHERE oauth2_device_code_grant_id = $3
418 "#,
419 fulfilled_at,
420 Uuid::from(browser_session.id),
421 Uuid::from(device_code_grant.id),
422 )
423 .traced()
424 .execute(&mut *self.conn)
425 .await?;
426
427 DatabaseError::ensure_affected_rows(&res, 1)?;
428
429 Ok(device_code_grant)
430 }
431
432 #[tracing::instrument(
433 name = "db.oauth2_device_code_grant.exchange",
434 skip_all,
435 fields(
436 db.query.text,
437 oauth2_device_code.id = %device_code_grant.id,
438 oauth2_client.id = %device_code_grant.client_id,
439 oauth2_session.id = %session.id,
440 ),
441 err,
442 )]
443 async fn exchange(
444 &mut self,
445 clock: &dyn Clock,
446 device_code_grant: DeviceCodeGrant,
447 session: &Session,
448 ) -> Result<DeviceCodeGrant, Self::Error> {
449 let exchanged_at = clock.now();
450 let device_code_grant = device_code_grant
451 .exchange(session, exchanged_at)
452 .map_err(DatabaseError::to_invalid_operation)?;
453
454 let res = sqlx::query!(
455 r#"
456 UPDATE oauth2_device_code_grant
457 SET exchanged_at = $1
458 , oauth2_session_id = $2
459 WHERE oauth2_device_code_grant_id = $3
460 "#,
461 exchanged_at,
462 Uuid::from(session.id),
463 Uuid::from(device_code_grant.id),
464 )
465 .traced()
466 .execute(&mut *self.conn)
467 .await?;
468
469 DatabaseError::ensure_affected_rows(&res, 1)?;
470
471 Ok(device_code_grant)
472 }
473
474 #[tracing::instrument(
475 name = "db.oauth2_device_code_grant.cleanup",
476 skip_all,
477 fields(
478 db.query.text,
479 since = since.map(tracing::field::display),
480 until = %until,
481 limit = limit,
482 ),
483 err,
484 )]
485 async fn cleanup(
486 &mut self,
487 since: Option<Ulid>,
488 until: Ulid,
489 limit: usize,
490 ) -> Result<(usize, Option<Ulid>), Self::Error> {
491 let res = sqlx::query_scalar!(
496 r#"
497 WITH to_delete AS (
498 SELECT oauth2_device_code_grant_id
499 FROM oauth2_device_code_grant
500 WHERE ($1::uuid IS NULL OR oauth2_device_code_grant_id > $1)
501 AND oauth2_device_code_grant_id <= $2
502 ORDER BY oauth2_device_code_grant_id
503 LIMIT $3
504 )
505 DELETE FROM oauth2_device_code_grant
506 USING to_delete
507 WHERE oauth2_device_code_grant.oauth2_device_code_grant_id = to_delete.oauth2_device_code_grant_id
508 RETURNING oauth2_device_code_grant.oauth2_device_code_grant_id
509 "#,
510 since.map(Uuid::from),
511 Uuid::from(until),
512 i64::try_from(limit).unwrap_or(i64::MAX)
513 )
514 .traced()
515 .fetch_all(&mut *self.conn)
516 .await?;
517
518 let count = res.len();
519 let max_id = res.into_iter().max();
520
521 Ok((count, max_id.map(Ulid::from)))
522 }
523}