1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Session, SessionState, User};
13use mas_storage::{
14 Page, Pagination,
15 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
16 pagination::Node,
17};
18use oauth2_types::scope::{Scope, ScopeToken};
19use rand::RngCore;
20use sea_query::{
21 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
22 extension::postgres::PgExpr,
23};
24use sea_query_binder::SqlxBinder;
25use sqlx::PgConnection;
26use ulid::Ulid;
27use uuid::Uuid;
28
29use crate::{
30 DatabaseError, DatabaseInconsistencyError,
31 filter::{Filter, StatementExt},
32 iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
33 pagination::QueryBuilderExt,
34 tracing::ExecuteExt,
35};
36
/// An implementation of [`OAuth2SessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgOAuth2SessionRepository<'c> {
    // Connection every query of this repository is executed on; it is borrowed
    // mutably for as long as the repository lives.
    conn: &'c mut PgConnection,
}
41
42impl<'c> PgOAuth2SessionRepository<'c> {
43 pub fn new(conn: &'c mut PgConnection) -> Self {
46 Self { conn }
47 }
48}
49
/// One row of the `oauth2_sessions` table, as read by the `lookup` and `list`
/// queries of this repository.
///
/// `#[enum_def]` generates the `OAuthSessionLookupIden` enum used as column
/// aliases in the dynamically-built `list` query, and `sqlx::FromRow` maps
/// columns by name, so field names must match the selected column names.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    oauth2_session_id: Uuid,
    // NULL for sessions not bound to a user (see the `any_user` filter)
    user_id: Option<Uuid>,
    // Browser session the OAuth 2.0 session was created from, if any
    user_session_id: Option<Uuid>,
    oauth2_client_id: Uuid,
    // Raw scope tokens; parsed into a `Scope` when converting to a `Session`
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // NULL while the session is still valid
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    // May be cleared by `cleanup_inactive_ips`
    last_active_ip: Option<IpAddr>,
    human_name: Option<String>,
}
65
66impl Node<Ulid> for OAuthSessionLookup {
67 fn cursor(&self) -> Ulid {
68 self.oauth2_session_id.into()
69 }
70}
71
72impl TryFrom<OAuthSessionLookup> for Session {
73 type Error = DatabaseInconsistencyError;
74
75 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
76 let id = Ulid::from(value.oauth2_session_id);
77 let scope: Result<Scope, _> = value
78 .scope_list
79 .iter()
80 .map(|s| s.parse::<ScopeToken>())
81 .collect();
82 let scope = scope.map_err(|e| {
83 DatabaseInconsistencyError::on("oauth2_sessions")
84 .column("scope")
85 .row(id)
86 .source(e)
87 })?;
88
89 let state = match value.finished_at {
90 None => SessionState::Valid,
91 Some(finished_at) => SessionState::Finished { finished_at },
92 };
93
94 Ok(Session {
95 id,
96 state,
97 created_at: value.created_at,
98 client_id: value.oauth2_client_id.into(),
99 user_id: value.user_id.map(Ulid::from),
100 user_session_id: value.user_session_id.map(Ulid::from),
101 scope,
102 user_agent: value.user_agent,
103 last_active_at: value.last_active_at,
104 last_active_ip: value.last_active_ip,
105 human_name: value.human_name,
106 })
107 }
108}
109
110impl Filter for OAuth2SessionFilter<'_> {
111 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
112 sea_query::Condition::all()
113 .add_option(self.user().map(|user| {
114 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
115 }))
116 .add_option(self.client().map(|client| {
117 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
118 .eq(Uuid::from(client.id))
119 }))
120 .add_option(self.client_kind().map(|client_kind| {
121 let static_clients = Query::select()
125 .expr(Expr::col((
126 OAuth2Clients::Table,
127 OAuth2Clients::OAuth2ClientId,
128 )))
129 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
130 .from(OAuth2Clients::Table)
131 .take();
132 if client_kind.is_static() {
133 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
134 .eq(Expr::any(static_clients))
135 } else {
136 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
137 .ne(Expr::all(static_clients))
138 }
139 }))
140 .add_option(self.device().map(|device| -> SimpleExpr {
141 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
142 Condition::any()
143 .add(
144 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
145 OAuth2Sessions::Table,
146 OAuth2Sessions::ScopeList,
147 )))),
148 )
149 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
150 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
151 )))
152 .into()
153 } else {
154 Expr::val(false).into()
156 }
157 }))
158 .add_option(self.browser_session().map(|browser_session| {
159 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
160 .eq(Uuid::from(browser_session.id))
161 }))
162 .add_option(self.browser_session_filter().map(|browser_session_filter| {
163 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
164 Query::select()
165 .expr(Expr::col((
166 UserSessions::Table,
167 UserSessions::UserSessionId,
168 )))
169 .apply_filter(browser_session_filter)
170 .from(UserSessions::Table)
171 .take(),
172 )
173 }))
174 .add_option(self.state().map(|state| {
175 if state.is_active() {
176 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
177 } else {
178 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
179 }
180 }))
181 .add_option(self.scope().map(|scope| {
182 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
183 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
184 }))
185 .add_option(self.any_user().map(|any_user| {
186 if any_user {
187 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
188 } else {
189 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
190 }
191 }))
192 .add_option(self.last_active_after().map(|last_active_after| {
193 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
194 .gt(last_active_after)
195 }))
196 .add_option(self.last_active_before().map(|last_active_before| {
197 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
198 .lt(last_active_before)
199 }))
200 }
201}
202
203#[async_trait]
204impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
205 type Error = DatabaseError;
206
207 #[tracing::instrument(
208 name = "db.oauth2_session.lookup",
209 skip_all,
210 fields(
211 db.query.text,
212 session.id = %id,
213 ),
214 err,
215 )]
216 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
217 let res = sqlx::query_as!(
218 OAuthSessionLookup,
219 r#"
220 SELECT oauth2_session_id
221 , user_id
222 , user_session_id
223 , oauth2_client_id
224 , scope_list
225 , created_at
226 , finished_at
227 , user_agent
228 , last_active_at
229 , last_active_ip as "last_active_ip: IpAddr"
230 , human_name
231 FROM oauth2_sessions
232
233 WHERE oauth2_session_id = $1
234 "#,
235 Uuid::from(id),
236 )
237 .traced()
238 .fetch_optional(&mut *self.conn)
239 .await?;
240
241 let Some(session) = res else { return Ok(None) };
242
243 Ok(Some(session.try_into()?))
244 }
245
246 #[tracing::instrument(
247 name = "db.oauth2_session.add",
248 skip_all,
249 fields(
250 db.query.text,
251 %client.id,
252 session.id,
253 session.scope = %scope,
254 ),
255 err,
256 )]
257 async fn add(
258 &mut self,
259 rng: &mut (dyn RngCore + Send),
260 clock: &dyn Clock,
261 client: &Client,
262 user: Option<&User>,
263 user_session: Option<&BrowserSession>,
264 scope: Scope,
265 ) -> Result<Session, Self::Error> {
266 let created_at = clock.now();
267 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
268 tracing::Span::current().record("session.id", tracing::field::display(id));
269
270 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
271
272 sqlx::query!(
273 r#"
274 INSERT INTO oauth2_sessions
275 ( oauth2_session_id
276 , user_id
277 , user_session_id
278 , oauth2_client_id
279 , scope_list
280 , created_at
281 )
282 VALUES ($1, $2, $3, $4, $5, $6)
283 "#,
284 Uuid::from(id),
285 user.map(|u| Uuid::from(u.id)),
286 user_session.map(|s| Uuid::from(s.id)),
287 Uuid::from(client.id),
288 &scope_list,
289 created_at,
290 )
291 .traced()
292 .execute(&mut *self.conn)
293 .await?;
294
295 Ok(Session {
296 id,
297 state: SessionState::Valid,
298 created_at,
299 user_id: user.map(|u| u.id),
300 user_session_id: user_session.map(|s| s.id),
301 client_id: client.id,
302 scope,
303 user_agent: None,
304 last_active_at: None,
305 last_active_ip: None,
306 human_name: None,
307 })
308 }
309
310 #[tracing::instrument(
311 name = "db.oauth2_session.finish_bulk",
312 skip_all,
313 fields(
314 db.query.text,
315 ),
316 err,
317 )]
318 async fn finish_bulk(
319 &mut self,
320 clock: &dyn Clock,
321 filter: OAuth2SessionFilter<'_>,
322 ) -> Result<usize, Self::Error> {
323 let finished_at = clock.now();
324 let (sql, arguments) = Query::update()
325 .table(OAuth2Sessions::Table)
326 .value(OAuth2Sessions::FinishedAt, finished_at)
327 .apply_filter(filter)
328 .build_sqlx(PostgresQueryBuilder);
329
330 let res = sqlx::query_with(&sql, arguments)
331 .traced()
332 .execute(&mut *self.conn)
333 .await?;
334
335 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
336 }
337
338 #[tracing::instrument(
339 name = "db.oauth2_session.finish",
340 skip_all,
341 fields(
342 db.query.text,
343 %session.id,
344 %session.scope,
345 client.id = %session.client_id,
346 ),
347 err,
348 )]
349 async fn finish(
350 &mut self,
351 clock: &dyn Clock,
352 session: Session,
353 ) -> Result<Session, Self::Error> {
354 let finished_at = clock.now();
355 let res = sqlx::query!(
356 r#"
357 UPDATE oauth2_sessions
358 SET finished_at = $2
359 WHERE oauth2_session_id = $1
360 "#,
361 Uuid::from(session.id),
362 finished_at,
363 )
364 .traced()
365 .execute(&mut *self.conn)
366 .await?;
367
368 DatabaseError::ensure_affected_rows(&res, 1)?;
369
370 session
371 .finish(finished_at)
372 .map_err(DatabaseError::to_invalid_operation)
373 }
374
375 #[tracing::instrument(
376 name = "db.oauth2_session.list",
377 skip_all,
378 fields(
379 db.query.text,
380 ),
381 err,
382 )]
383 async fn list(
384 &mut self,
385 filter: OAuth2SessionFilter<'_>,
386 pagination: Pagination,
387 ) -> Result<Page<Session>, Self::Error> {
388 let (sql, arguments) = Query::select()
389 .expr_as(
390 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
391 OAuthSessionLookupIden::Oauth2SessionId,
392 )
393 .expr_as(
394 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
395 OAuthSessionLookupIden::UserId,
396 )
397 .expr_as(
398 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
399 OAuthSessionLookupIden::UserSessionId,
400 )
401 .expr_as(
402 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
403 OAuthSessionLookupIden::Oauth2ClientId,
404 )
405 .expr_as(
406 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
407 OAuthSessionLookupIden::ScopeList,
408 )
409 .expr_as(
410 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
411 OAuthSessionLookupIden::CreatedAt,
412 )
413 .expr_as(
414 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
415 OAuthSessionLookupIden::FinishedAt,
416 )
417 .expr_as(
418 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
419 OAuthSessionLookupIden::UserAgent,
420 )
421 .expr_as(
422 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
423 OAuthSessionLookupIden::LastActiveAt,
424 )
425 .expr_as(
426 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
427 OAuthSessionLookupIden::LastActiveIp,
428 )
429 .expr_as(
430 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
431 OAuthSessionLookupIden::HumanName,
432 )
433 .from(OAuth2Sessions::Table)
434 .apply_filter(filter)
435 .generate_pagination(
436 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
437 pagination,
438 )
439 .build_sqlx(PostgresQueryBuilder);
440
441 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
442 .traced()
443 .fetch_all(&mut *self.conn)
444 .await?;
445
446 let page = pagination.process(edges).try_map(Session::try_from)?;
447
448 Ok(page)
449 }
450
451 #[tracing::instrument(
452 name = "db.oauth2_session.count",
453 skip_all,
454 fields(
455 db.query.text,
456 ),
457 err,
458 )]
459 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
460 let (sql, arguments) = Query::select()
461 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
462 .from(OAuth2Sessions::Table)
463 .apply_filter(filter)
464 .build_sqlx(PostgresQueryBuilder);
465
466 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
467 .traced()
468 .fetch_one(&mut *self.conn)
469 .await?;
470
471 count
472 .try_into()
473 .map_err(DatabaseError::to_invalid_operation)
474 }
475
476 #[tracing::instrument(
477 name = "db.oauth2_session.record_batch_activity",
478 skip_all,
479 fields(
480 db.query.text,
481 ),
482 err,
483 )]
484 async fn record_batch_activity(
485 &mut self,
486 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
487 ) -> Result<(), Self::Error> {
488 activities.sort_unstable();
491 let mut ids = Vec::with_capacity(activities.len());
492 let mut last_activities = Vec::with_capacity(activities.len());
493 let mut ips = Vec::with_capacity(activities.len());
494
495 for (id, last_activity, ip) in activities {
496 ids.push(Uuid::from(id));
497 last_activities.push(last_activity);
498 ips.push(ip);
499 }
500
501 let res = sqlx::query!(
502 r#"
503 UPDATE oauth2_sessions
504 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
505 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
506 FROM (
507 SELECT *
508 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
509 AS t(oauth2_session_id, last_active_at, last_active_ip)
510 ) AS t
511 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
512 "#,
513 &ids,
514 &last_activities,
515 &ips as &[Option<IpAddr>],
516 )
517 .traced()
518 .execute(&mut *self.conn)
519 .await?;
520
521 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
522
523 Ok(())
524 }
525
526 #[tracing::instrument(
527 name = "db.oauth2_session.record_user_agent",
528 skip_all,
529 fields(
530 db.query.text,
531 %session.id,
532 %session.scope,
533 client.id = %session.client_id,
534 session.user_agent = user_agent,
535 ),
536 err,
537 )]
538 async fn record_user_agent(
539 &mut self,
540 mut session: Session,
541 user_agent: String,
542 ) -> Result<Session, Self::Error> {
543 let res = sqlx::query!(
544 r#"
545 UPDATE oauth2_sessions
546 SET user_agent = $2
547 WHERE oauth2_session_id = $1
548 "#,
549 Uuid::from(session.id),
550 &*user_agent,
551 )
552 .traced()
553 .execute(&mut *self.conn)
554 .await?;
555
556 session.user_agent = Some(user_agent);
557
558 DatabaseError::ensure_affected_rows(&res, 1)?;
559
560 Ok(session)
561 }
562
563 #[tracing::instrument(
564 name = "repository.oauth2_session.set_human_name",
565 skip(self),
566 fields(
567 client.id = %session.client_id,
568 session.human_name = ?human_name,
569 ),
570 err,
571 )]
572 async fn set_human_name(
573 &mut self,
574 mut session: Session,
575 human_name: Option<String>,
576 ) -> Result<Session, Self::Error> {
577 let res = sqlx::query!(
578 r#"
579 UPDATE oauth2_sessions
580 SET human_name = $2
581 WHERE oauth2_session_id = $1
582 "#,
583 Uuid::from(session.id),
584 human_name.as_deref(),
585 )
586 .traced()
587 .execute(&mut *self.conn)
588 .await?;
589
590 session.human_name = human_name;
591
592 DatabaseError::ensure_affected_rows(&res, 1)?;
593
594 Ok(session)
595 }
596
597 #[tracing::instrument(
598 name = "db.oauth2_session.cleanup_finished",
599 skip_all,
600 fields(
601 db.query.text,
602 since = since.map(tracing::field::display),
603 until = %until,
604 limit = limit,
605 ),
606 err,
607 )]
608 async fn cleanup_finished(
609 &mut self,
610 since: Option<DateTime<Utc>>,
611 until: DateTime<Utc>,
612 limit: usize,
613 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
614 let res = sqlx::query!(
615 r#"
616 WITH
617 to_delete AS (
618 SELECT oauth2_session_id, finished_at
619 FROM oauth2_sessions
620 WHERE finished_at IS NOT NULL
621 AND ($1::timestamptz IS NULL OR finished_at >= $1)
622 AND finished_at < $2
623 ORDER BY finished_at ASC
624 LIMIT $3
625 FOR UPDATE
626 ),
627 deleted_refresh_tokens AS (
628 DELETE FROM oauth2_refresh_tokens USING to_delete
629 WHERE oauth2_refresh_tokens.oauth2_session_id = to_delete.oauth2_session_id
630 ),
631 deleted_access_tokens AS (
632 DELETE FROM oauth2_access_tokens USING to_delete
633 WHERE oauth2_access_tokens.oauth2_session_id = to_delete.oauth2_session_id
634 ),
635 deleted_sessions AS (
636 DELETE FROM oauth2_sessions USING to_delete
637 WHERE oauth2_sessions.oauth2_session_id = to_delete.oauth2_session_id
638 RETURNING oauth2_sessions.finished_at
639 )
640 SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
641 "#,
642 since,
643 until,
644 i64::try_from(limit).unwrap_or(i64::MAX),
645 )
646 .traced()
647 .fetch_one(&mut *self.conn)
648 .await?;
649
650 Ok((
651 res.count.try_into().unwrap_or(usize::MAX),
652 res.last_finished_at,
653 ))
654 }
655
656 #[tracing::instrument(
657 name = "db.oauth2_session.cleanup_inactive_ips",
658 skip_all,
659 fields(
660 db.query.text,
661 since = since.map(tracing::field::display),
662 threshold = %threshold,
663 limit = limit,
664 ),
665 err,
666 )]
667 async fn cleanup_inactive_ips(
668 &mut self,
669 since: Option<DateTime<Utc>>,
670 threshold: DateTime<Utc>,
671 limit: usize,
672 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
673 let res = sqlx::query!(
674 r#"
675 WITH to_update AS (
676 SELECT oauth2_session_id, last_active_at
677 FROM oauth2_sessions
678 WHERE last_active_ip IS NOT NULL
679 AND last_active_at IS NOT NULL
680 AND ($1::timestamptz IS NULL OR last_active_at >= $1)
681 AND last_active_at < $2
682 ORDER BY last_active_at ASC
683 LIMIT $3
684 FOR UPDATE
685 ),
686 updated AS (
687 UPDATE oauth2_sessions
688 SET last_active_ip = NULL
689 FROM to_update
690 WHERE oauth2_sessions.oauth2_session_id = to_update.oauth2_session_id
691 RETURNING oauth2_sessions.last_active_at
692 )
693 SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
694 "#,
695 since,
696 threshold,
697 i64::try_from(limit).unwrap_or(i64::MAX),
698 )
699 .traced()
700 .fetch_one(&mut *self.conn)
701 .await?;
702
703 Ok((
704 res.count.try_into().unwrap_or(usize::MAX),
705 res.last_active_at,
706 ))
707 }
708}