1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 BrowserSession, Clock, UpstreamOAuthAuthorizationSession,
12 UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
13};
14use mas_storage::{
15 Page, Pagination,
16 pagination::Node,
17 upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::{Filter, StatementExt},
29 iden::UpstreamOAuthAuthorizationSessions,
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
34impl Filter for UpstreamOAuthSessionFilter<'_> {
35 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
36 sea_query::Condition::all()
37 .add_option(self.provider().map(|provider| {
38 Expr::col((
39 UpstreamOAuthAuthorizationSessions::Table,
40 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
41 ))
42 .eq(Uuid::from(provider.id))
43 }))
44 .add_option(self.sub_claim().map(|sub| {
45 Expr::col((
46 UpstreamOAuthAuthorizationSessions::Table,
47 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
48 ))
49 .cast_json_field("sub")
50 .eq(sub)
51 }))
52 .add_option(self.sid_claim().map(|sid| {
53 Expr::col((
54 UpstreamOAuthAuthorizationSessions::Table,
55 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
56 ))
57 .cast_json_field("sid")
58 .eq(sid)
59 }))
60 }
61}
62
/// An implementation of [`UpstreamOAuthSessionRepository`] backed by a
/// PostgreSQL connection.
pub struct PgUpstreamOAuthSessionRepository<'c> {
    // Mutably borrowed connection on which every query of this repository runs.
    conn: &'c mut PgConnection,
}
68
69impl<'c> PgUpstreamOAuthSessionRepository<'c> {
70 pub fn new(conn: &'c mut PgConnection) -> Self {
73 Self { conn }
74 }
75}
76
/// Raw row of the `upstream_oauth_authorization_sessions` table, as selected
/// by the queries of [`PgUpstreamOAuthSessionRepository`].
///
/// `#[enum_def]` generates the `SessionLookupIden` enum that the dynamically
/// built `list` query uses as column aliases, so the field names here must
/// match the selected column aliases for `FromRow` decoding to work.
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // NULL until the session is completed with a link.
    upstream_oauth_link_id: Option<Uuid>,
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    id_token: Option<String>,
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    // The three timestamps below, together with the link id, determine the
    // session state when converting to the domain model.
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
95
96impl Node<Ulid> for SessionLookup {
97 fn cursor(&self) -> Ulid {
98 self.upstream_oauth_authorization_session_id.into()
99 }
100}
101
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
    type Error = DatabaseInconsistencyError;

    /// Converts a raw database row into the domain model, inferring the
    /// session state from which nullable columns are populated.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseInconsistencyError`] pointing at this row when the
    /// combination of nullable columns matches none of the known states.
    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
        let id = value.upstream_oauth_authorization_session_id.into();
        // The state is not stored as a column: it is derived from this tuple
        // of nullable columns. Arm order matters.
        let state = match (
            value.upstream_oauth_link_id,
            value.id_token,
            value.id_token_claims,
            value.extra_callback_parameters,
            value.userinfo,
            value.completed_at,
            value.consumed_at,
            value.unlinked_at,
        ) {
            // Nothing set: the session was created but never completed.
            (None, None, None, None, None, None, None, None) => {
                UpstreamOAuthAuthorizationSessionState::Pending
            }
            // Linked and completed, but neither consumed nor unlinked.
            (
                Some(link_id),
                id_token,
                id_token_claims,
                extra_callback_parameters,
                userinfo,
                Some(completed_at),
                None,
                None,
            ) => UpstreamOAuthAuthorizationSessionState::Completed {
                completed_at,
                link_id: link_id.into(),
                id_token,
                id_token_claims,
                extra_callback_parameters,
                userinfo,
            },
            // Linked, completed and consumed, but not unlinked.
            (
                Some(link_id),
                id_token,
                id_token_claims,
                extra_callback_parameters,
                userinfo,
                Some(completed_at),
                Some(consumed_at),
                None,
            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
                completed_at,
                link_id: link_id.into(),
                id_token,
                id_token_claims,
                extra_callback_parameters,
                userinfo,
                consumed_at,
            },
            // Completed then unlinked. The link id, extra callback parameters
            // and userinfo are ignored here; `consumed_at` may or may not be set.
            (
                _,
                id_token,
                id_token_claims,
                _,
                _,
                Some(completed_at),
                consumed_at,
                Some(unlinked_at),
            ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
                completed_at,
                id_token,
                id_token_claims,
                consumed_at,
                unlinked_at,
            },
            // Any other combination breaks the state-machine invariants.
            _ => {
                return Err(DatabaseInconsistencyError::on(
                    "upstream_oauth_authorization_sessions",
                )
                .row(id));
            }
        };

        Ok(Self {
            id,
            provider_id: value.upstream_oauth_provider_id.into(),
            state_str: value.state,
            nonce: value.nonce,
            code_challenge_verifier: value.code_challenge_verifier,
            created_at: value.created_at,
            state,
        })
    }
}
190
191#[async_trait]
192impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
193 type Error = DatabaseError;
194
195 #[tracing::instrument(
196 name = "db.upstream_oauth_authorization_session.lookup",
197 skip_all,
198 fields(
199 db.query.text,
200 upstream_oauth_provider.id = %id,
201 ),
202 err,
203 )]
204 async fn lookup(
205 &mut self,
206 id: Ulid,
207 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
208 let res = sqlx::query_as!(
209 SessionLookup,
210 r#"
211 SELECT
212 upstream_oauth_authorization_session_id,
213 upstream_oauth_provider_id,
214 upstream_oauth_link_id,
215 state,
216 code_challenge_verifier,
217 nonce,
218 id_token,
219 id_token_claims,
220 extra_callback_parameters,
221 userinfo,
222 created_at,
223 completed_at,
224 consumed_at,
225 unlinked_at
226 FROM upstream_oauth_authorization_sessions
227 WHERE upstream_oauth_authorization_session_id = $1
228 "#,
229 Uuid::from(id),
230 )
231 .traced()
232 .fetch_optional(&mut *self.conn)
233 .await?;
234
235 let Some(res) = res else { return Ok(None) };
236
237 Ok(Some(res.try_into()?))
238 }
239
240 #[tracing::instrument(
241 name = "db.upstream_oauth_authorization_session.add",
242 skip_all,
243 fields(
244 db.query.text,
245 %upstream_oauth_provider.id,
246 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
247 %upstream_oauth_provider.client_id,
248 upstream_oauth_authorization_session.id,
249 ),
250 err,
251 )]
252 async fn add(
253 &mut self,
254 rng: &mut (dyn RngCore + Send),
255 clock: &dyn Clock,
256 upstream_oauth_provider: &UpstreamOAuthProvider,
257 state_str: String,
258 code_challenge_verifier: Option<String>,
259 nonce: Option<String>,
260 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
261 let created_at = clock.now();
262 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
263 tracing::Span::current().record(
264 "upstream_oauth_authorization_session.id",
265 tracing::field::display(id),
266 );
267
268 sqlx::query!(
269 r#"
270 INSERT INTO upstream_oauth_authorization_sessions (
271 upstream_oauth_authorization_session_id,
272 upstream_oauth_provider_id,
273 state,
274 code_challenge_verifier,
275 nonce,
276 created_at,
277 completed_at,
278 consumed_at,
279 id_token,
280 userinfo
281 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
282 "#,
283 Uuid::from(id),
284 Uuid::from(upstream_oauth_provider.id),
285 &state_str,
286 code_challenge_verifier.as_deref(),
287 nonce,
288 created_at,
289 )
290 .traced()
291 .execute(&mut *self.conn)
292 .await?;
293
294 Ok(UpstreamOAuthAuthorizationSession {
295 id,
296 state: UpstreamOAuthAuthorizationSessionState::default(),
297 provider_id: upstream_oauth_provider.id,
298 state_str,
299 code_challenge_verifier,
300 nonce,
301 created_at,
302 })
303 }
304
305 #[tracing::instrument(
306 name = "db.upstream_oauth_authorization_session.complete_with_link",
307 skip_all,
308 fields(
309 db.query.text,
310 %upstream_oauth_authorization_session.id,
311 %upstream_oauth_link.id,
312 ),
313 err,
314 )]
315 async fn complete_with_link(
316 &mut self,
317 clock: &dyn Clock,
318 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
319 upstream_oauth_link: &UpstreamOAuthLink,
320 id_token: Option<String>,
321 id_token_claims: Option<serde_json::Value>,
322 extra_callback_parameters: Option<serde_json::Value>,
323 userinfo: Option<serde_json::Value>,
324 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
325 let completed_at = clock.now();
326
327 sqlx::query!(
328 r#"
329 UPDATE upstream_oauth_authorization_sessions
330 SET upstream_oauth_link_id = $1
331 , completed_at = $2
332 , id_token = $3
333 , id_token_claims = $4
334 , extra_callback_parameters = $5
335 , userinfo = $6
336 WHERE upstream_oauth_authorization_session_id = $7
337 "#,
338 Uuid::from(upstream_oauth_link.id),
339 completed_at,
340 id_token,
341 id_token_claims,
342 extra_callback_parameters,
343 userinfo,
344 Uuid::from(upstream_oauth_authorization_session.id),
345 )
346 .traced()
347 .execute(&mut *self.conn)
348 .await?;
349
350 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
351 .complete(
352 completed_at,
353 upstream_oauth_link,
354 id_token,
355 id_token_claims,
356 extra_callback_parameters,
357 userinfo,
358 )
359 .map_err(DatabaseError::to_invalid_operation)?;
360
361 Ok(upstream_oauth_authorization_session)
362 }
363
364 #[tracing::instrument(
366 name = "db.upstream_oauth_authorization_session.consume",
367 skip_all,
368 fields(
369 db.query.text,
370 %upstream_oauth_authorization_session.id,
371 ),
372 err,
373 )]
374 async fn consume(
375 &mut self,
376 clock: &dyn Clock,
377 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
378 browser_session: &BrowserSession,
379 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
380 let consumed_at = clock.now();
381 sqlx::query!(
382 r#"
383 UPDATE upstream_oauth_authorization_sessions
384 SET consumed_at = $1,
385 user_session_id = $2
386 WHERE upstream_oauth_authorization_session_id = $3
387 "#,
388 consumed_at,
389 Uuid::from(browser_session.id),
390 Uuid::from(upstream_oauth_authorization_session.id),
391 )
392 .traced()
393 .execute(&mut *self.conn)
394 .await?;
395
396 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
397 .consume(consumed_at)
398 .map_err(DatabaseError::to_invalid_operation)?;
399
400 Ok(upstream_oauth_authorization_session)
401 }
402
403 #[tracing::instrument(
404 name = "db.upstream_oauth_authorization_session.list",
405 skip_all,
406 fields(
407 db.query.text,
408 ),
409 err,
410 )]
411 async fn list(
412 &mut self,
413 filter: UpstreamOAuthSessionFilter<'_>,
414 pagination: Pagination,
415 ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
416 let (sql, arguments) = Query::select()
417 .expr_as(
418 Expr::col((
419 UpstreamOAuthAuthorizationSessions::Table,
420 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
421 )),
422 SessionLookupIden::UpstreamOauthAuthorizationSessionId,
423 )
424 .expr_as(
425 Expr::col((
426 UpstreamOAuthAuthorizationSessions::Table,
427 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
428 )),
429 SessionLookupIden::UpstreamOauthProviderId,
430 )
431 .expr_as(
432 Expr::col((
433 UpstreamOAuthAuthorizationSessions::Table,
434 UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
435 )),
436 SessionLookupIden::UpstreamOauthLinkId,
437 )
438 .expr_as(
439 Expr::col((
440 UpstreamOAuthAuthorizationSessions::Table,
441 UpstreamOAuthAuthorizationSessions::State,
442 )),
443 SessionLookupIden::State,
444 )
445 .expr_as(
446 Expr::col((
447 UpstreamOAuthAuthorizationSessions::Table,
448 UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
449 )),
450 SessionLookupIden::CodeChallengeVerifier,
451 )
452 .expr_as(
453 Expr::col((
454 UpstreamOAuthAuthorizationSessions::Table,
455 UpstreamOAuthAuthorizationSessions::Nonce,
456 )),
457 SessionLookupIden::Nonce,
458 )
459 .expr_as(
460 Expr::col((
461 UpstreamOAuthAuthorizationSessions::Table,
462 UpstreamOAuthAuthorizationSessions::IdToken,
463 )),
464 SessionLookupIden::IdToken,
465 )
466 .expr_as(
467 Expr::col((
468 UpstreamOAuthAuthorizationSessions::Table,
469 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
470 )),
471 SessionLookupIden::IdTokenClaims,
472 )
473 .expr_as(
474 Expr::col((
475 UpstreamOAuthAuthorizationSessions::Table,
476 UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
477 )),
478 SessionLookupIden::ExtraCallbackParameters,
479 )
480 .expr_as(
481 Expr::col((
482 UpstreamOAuthAuthorizationSessions::Table,
483 UpstreamOAuthAuthorizationSessions::Userinfo,
484 )),
485 SessionLookupIden::Userinfo,
486 )
487 .expr_as(
488 Expr::col((
489 UpstreamOAuthAuthorizationSessions::Table,
490 UpstreamOAuthAuthorizationSessions::CreatedAt,
491 )),
492 SessionLookupIden::CreatedAt,
493 )
494 .expr_as(
495 Expr::col((
496 UpstreamOAuthAuthorizationSessions::Table,
497 UpstreamOAuthAuthorizationSessions::CompletedAt,
498 )),
499 SessionLookupIden::CompletedAt,
500 )
501 .expr_as(
502 Expr::col((
503 UpstreamOAuthAuthorizationSessions::Table,
504 UpstreamOAuthAuthorizationSessions::ConsumedAt,
505 )),
506 SessionLookupIden::ConsumedAt,
507 )
508 .expr_as(
509 Expr::col((
510 UpstreamOAuthAuthorizationSessions::Table,
511 UpstreamOAuthAuthorizationSessions::UnlinkedAt,
512 )),
513 SessionLookupIden::UnlinkedAt,
514 )
515 .from(UpstreamOAuthAuthorizationSessions::Table)
516 .apply_filter(filter)
517 .generate_pagination(
518 (
519 UpstreamOAuthAuthorizationSessions::Table,
520 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
521 ),
522 pagination,
523 )
524 .build_sqlx(PostgresQueryBuilder);
525
526 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
527 .traced()
528 .fetch_all(&mut *self.conn)
529 .await?;
530
531 let page = pagination
532 .process(edges)
533 .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
534
535 Ok(page)
536 }
537
538 #[tracing::instrument(
539 name = "db.upstream_oauth_authorization_session.count",
540 skip_all,
541 fields(
542 db.query.text,
543 ),
544 err,
545 )]
546 async fn count(
547 &mut self,
548 filter: UpstreamOAuthSessionFilter<'_>,
549 ) -> Result<usize, Self::Error> {
550 let (sql, arguments) = Query::select()
551 .expr(
552 Expr::col((
553 UpstreamOAuthAuthorizationSessions::Table,
554 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
555 ))
556 .count(),
557 )
558 .from(UpstreamOAuthAuthorizationSessions::Table)
559 .apply_filter(filter)
560 .build_sqlx(PostgresQueryBuilder);
561
562 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
563 .traced()
564 .fetch_one(&mut *self.conn)
565 .await?;
566
567 count
568 .try_into()
569 .map_err(DatabaseError::to_invalid_operation)
570 }
571
572 #[tracing::instrument(
573 name = "db.upstream_oauth_authorization_session.cleanup",
574 skip_all,
575 fields(
576 db.query.text,
577 since = since.map(tracing::field::display),
578 until = %until,
579 limit = limit,
580 ),
581 err,
582 )]
583 async fn cleanup_orphaned(
584 &mut self,
585 since: Option<Ulid>,
586 until: Ulid,
587 limit: usize,
588 ) -> Result<(usize, Option<Ulid>), Self::Error> {
589 let res = sqlx::query_scalar!(
593 r#"
594 WITH to_delete AS (
595 SELECT upstream_oauth_authorization_session_id
596 FROM upstream_oauth_authorization_sessions
597 WHERE ($1::uuid IS NULL OR upstream_oauth_authorization_session_id > $1)
598 AND upstream_oauth_authorization_session_id <= $2
599 AND user_session_id IS NULL
600 ORDER BY upstream_oauth_authorization_session_id
601 LIMIT $3
602 )
603 DELETE FROM upstream_oauth_authorization_sessions
604 USING to_delete
605 WHERE upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id = to_delete.upstream_oauth_authorization_session_id
606 RETURNING upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id
607 "#,
608 since.map(Uuid::from),
609 Uuid::from(until),
610 i64::try_from(limit).unwrap_or(i64::MAX)
611 )
612 .traced()
613 .fetch_all(&mut *self.conn)
614 .await?;
615
616 let count = res.len();
617 let max_id = res.into_iter().max();
618
619 Ok((count, max_id.map(Ulid::from)))
620 }
621}