1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13 Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
14 UpstreamOAuthAuthorizationSession, User,
15};
16use mas_storage::{
17 Page, Pagination,
18 pagination::Node,
19 user::{BrowserSessionFilter, BrowserSessionRepository},
20};
21use rand::RngCore;
22use sea_query::{Expr, PostgresQueryBuilder, Query};
23use sea_query_binder::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29 DatabaseError, DatabaseInconsistencyError,
30 filter::StatementExt,
31 iden::{UpstreamOAuthAuthorizationSessions, UserSessions, Users},
32 pagination::QueryBuilderExt,
33 tracing::ExecuteExt,
34};
35
/// PostgreSQL-backed implementation of [`BrowserSessionRepository`].
///
/// Holds a mutable borrow of a [`PgConnection`]; every repository method
/// executes its queries on that connection.
pub struct PgBrowserSessionRepository<'c> {
    // Connection all the queries of this repository run on
    conn: &'c mut PgConnection,
}
41
impl<'c> PgBrowserSessionRepository<'c> {
    /// Create a new [`PgBrowserSessionRepository`] which issues its queries
    /// on the given PostgreSQL connection.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
49
/// Flat database row for a browser session joined with its owning user.
///
/// The field names are the column aliases selected by the `lookup` and `list`
/// queries. `#[sea_query::enum_def]` generates the `SessionLookupIden` enum
/// that `list` uses to alias its selected columns.
// The shared `user_` prefix comes from mirroring the column aliases
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns read from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns read from `users`
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
68
69impl Node<Ulid> for SessionLookup {
70 fn cursor(&self) -> Ulid {
71 self.user_id.into()
72 }
73}
74
75impl TryFrom<SessionLookup> for BrowserSession {
76 type Error = DatabaseInconsistencyError;
77
78 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
79 let id = Ulid::from(value.user_id);
80 let user = User {
81 id,
82 username: value.user_username,
83 sub: id.to_string(),
84 created_at: value.user_created_at,
85 locked_at: value.user_locked_at,
86 deactivated_at: value.user_deactivated_at,
87 can_request_admin: value.user_can_request_admin,
88 is_guest: value.user_is_guest,
89 };
90
91 Ok(BrowserSession {
92 id: value.user_session_id.into(),
93 user,
94 created_at: value.user_session_created_at,
95 finished_at: value.user_session_finished_at,
96 user_agent: value.user_session_user_agent,
97 last_active_at: value.user_session_last_active_at,
98 last_active_ip: value.user_session_last_active_ip,
99 })
100 }
101}
102
/// Row of the `user_session_authentications` table, as selected by
/// `get_last_authentication`.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    // The two columns below identify how the authentication happened; the
    // conversion to `Authentication` rejects rows where both are set
    user_password_id: Option<Uuid>,
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
109
110impl TryFrom<AuthenticationLookup> for Authentication {
111 type Error = DatabaseInconsistencyError;
112
113 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
114 let id = Ulid::from(value.user_session_authentication_id);
115 let authentication_method = match (
116 value.user_password_id.map(Into::into),
117 value
118 .upstream_oauth_authorization_session_id
119 .map(Into::into),
120 ) {
121 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
122 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
123 upstream_oauth2_session_id,
124 },
125 (None, None) => AuthenticationMethod::Unknown,
126 _ => {
127 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
128 }
129 };
130
131 Ok(Authentication {
132 id,
133 created_at: value.created_at,
134 authentication_method,
135 })
136 }
137}
138
139impl crate::filter::Filter for BrowserSessionFilter<'_> {
140 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
141 sea_query::Condition::all()
142 .add_option(self.user().map(|user| {
143 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
144 }))
145 .add_option(self.state().map(|state| {
146 if state.is_active() {
147 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
148 } else {
149 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
150 }
151 }))
152 .add_option(self.last_active_after().map(|last_active_after| {
153 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
154 }))
155 .add_option(self.last_active_before().map(|last_active_before| {
156 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
157 }))
158 .add_option(self.linked_to_upstream_sessions().map(|filter| {
159 Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
160 Query::select()
161 .expr(Expr::col((
162 UpstreamOAuthAuthorizationSessions::Table,
163 UpstreamOAuthAuthorizationSessions::UserSessionId,
164 )))
165 .from(UpstreamOAuthAuthorizationSessions::Table)
166 .apply_filter(filter)
167 .take(),
168 )
169 }))
170 }
171}
172
173#[async_trait]
174impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
175 type Error = DatabaseError;
176
177 #[tracing::instrument(
178 name = "db.browser_session.lookup",
179 skip_all,
180 fields(
181 db.query.text,
182 user_session.id = %id,
183 ),
184 err,
185 )]
186 async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
187 let res = sqlx::query_as!(
188 SessionLookup,
189 r#"
190 SELECT s.user_session_id
191 , s.created_at AS "user_session_created_at"
192 , s.finished_at AS "user_session_finished_at"
193 , s.user_agent AS "user_session_user_agent"
194 , s.last_active_at AS "user_session_last_active_at"
195 , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
196 , u.user_id
197 , u.username AS "user_username"
198 , u.created_at AS "user_created_at"
199 , u.locked_at AS "user_locked_at"
200 , u.deactivated_at AS "user_deactivated_at"
201 , u.can_request_admin AS "user_can_request_admin"
202 , u.is_guest AS "user_is_guest"
203 FROM user_sessions s
204 INNER JOIN users u
205 USING (user_id)
206 WHERE s.user_session_id = $1
207 "#,
208 Uuid::from(id),
209 )
210 .traced()
211 .fetch_optional(&mut *self.conn)
212 .await?;
213
214 let Some(res) = res else { return Ok(None) };
215
216 Ok(Some(res.try_into()?))
217 }
218
219 #[tracing::instrument(
220 name = "db.browser_session.add",
221 skip_all,
222 fields(
223 db.query.text,
224 %user.id,
225 user_session.id,
226 ),
227 err,
228 )]
229 async fn add(
230 &mut self,
231 rng: &mut (dyn RngCore + Send),
232 clock: &dyn Clock,
233 user: &User,
234 user_agent: Option<String>,
235 ) -> Result<BrowserSession, Self::Error> {
236 let created_at = clock.now();
237 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
238 tracing::Span::current().record("user_session.id", tracing::field::display(id));
239
240 sqlx::query!(
241 r#"
242 INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
243 VALUES ($1, $2, $3, $4)
244 "#,
245 Uuid::from(id),
246 Uuid::from(user.id),
247 created_at,
248 user_agent.as_deref(),
249 )
250 .traced()
251 .execute(&mut *self.conn)
252 .await?;
253
254 let session = BrowserSession {
255 id,
256 user: user.clone(),
258 created_at,
259 finished_at: None,
260 user_agent,
261 last_active_at: None,
262 last_active_ip: None,
263 };
264
265 Ok(session)
266 }
267
268 #[tracing::instrument(
269 name = "db.browser_session.finish",
270 skip_all,
271 fields(
272 db.query.text,
273 %user_session.id,
274 ),
275 err,
276 )]
277 async fn finish(
278 &mut self,
279 clock: &dyn Clock,
280 mut user_session: BrowserSession,
281 ) -> Result<BrowserSession, Self::Error> {
282 let finished_at = clock.now();
283 let res = sqlx::query!(
284 r#"
285 UPDATE user_sessions
286 SET finished_at = $1
287 WHERE user_session_id = $2
288 "#,
289 finished_at,
290 Uuid::from(user_session.id),
291 )
292 .traced()
293 .execute(&mut *self.conn)
294 .await?;
295
296 user_session.finished_at = Some(finished_at);
297
298 DatabaseError::ensure_affected_rows(&res, 1)?;
299
300 Ok(user_session)
301 }
302
303 #[tracing::instrument(
304 name = "db.browser_session.finish_bulk",
305 skip_all,
306 fields(
307 db.query.text,
308 ),
309 err,
310 )]
311 async fn finish_bulk(
312 &mut self,
313 clock: &dyn Clock,
314 filter: BrowserSessionFilter<'_>,
315 ) -> Result<usize, Self::Error> {
316 let finished_at = clock.now();
317 let (sql, arguments) = sea_query::Query::update()
318 .table(UserSessions::Table)
319 .value(UserSessions::FinishedAt, finished_at)
320 .apply_filter(filter)
321 .build_sqlx(PostgresQueryBuilder);
322
323 let res = sqlx::query_with(&sql, arguments)
324 .traced()
325 .execute(&mut *self.conn)
326 .await?;
327
328 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
329 }
330
/// List the browser sessions matching `filter`, one page at a time.
///
/// Each selected column is aliased to the matching `SessionLookupIden`
/// variant so that rows decode into `SessionLookup`. Pagination is done on
/// `user_sessions.user_session_id`.
#[tracing::instrument(
    name = "db.browser_session.list",
    skip_all,
    fields(
        db.query.text,
    ),
    err,
)]
async fn list(
    &mut self,
    filter: BrowserSessionFilter<'_>,
    pagination: Pagination,
) -> Result<Page<BrowserSession>, Self::Error> {
    let (sql, arguments) = sea_query::Query::select()
        // Columns coming from `user_sessions`
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
            SessionLookupIden::UserSessionId,
        )
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
            SessionLookupIden::UserSessionCreatedAt,
        )
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
            SessionLookupIden::UserSessionFinishedAt,
        )
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::UserAgent)),
            SessionLookupIden::UserSessionUserAgent,
        )
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
            SessionLookupIden::UserSessionLastActiveAt,
        )
        .expr_as(
            Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
            SessionLookupIden::UserSessionLastActiveIp,
        )
        // Columns coming from the joined `users` table
        .expr_as(
            Expr::col((Users::Table, Users::UserId)),
            SessionLookupIden::UserId,
        )
        .expr_as(
            Expr::col((Users::Table, Users::Username)),
            SessionLookupIden::UserUsername,
        )
        .expr_as(
            Expr::col((Users::Table, Users::CreatedAt)),
            SessionLookupIden::UserCreatedAt,
        )
        .expr_as(
            Expr::col((Users::Table, Users::LockedAt)),
            SessionLookupIden::UserLockedAt,
        )
        .expr_as(
            Expr::col((Users::Table, Users::DeactivatedAt)),
            SessionLookupIden::UserDeactivatedAt,
        )
        .expr_as(
            Expr::col((Users::Table, Users::CanRequestAdmin)),
            SessionLookupIden::UserCanRequestAdmin,
        )
        .expr_as(
            Expr::col((Users::Table, Users::IsGuest)),
            SessionLookupIden::UserIsGuest,
        )
        .from(UserSessions::Table)
        // Each session is joined with the user owning it
        .inner_join(
            Users::Table,
            Expr::col((UserSessions::Table, UserSessions::UserId))
                .equals((Users::Table, Users::UserId)),
        )
        .apply_filter(filter)
        // The session ID is the pagination column
        .generate_pagination(
            (UserSessions::Table, UserSessions::UserSessionId),
            pagination,
        )
        .build_sqlx(PostgresQueryBuilder);

    let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
        .traced()
        .fetch_all(&mut *self.conn)
        .await?;

    // Let the pagination turn the raw rows into a page, then convert every
    // row into a `BrowserSession`
    let page = pagination
        .process(edges)
        .try_map(BrowserSession::try_from)?;

    Ok(page)
}
421
422 #[tracing::instrument(
423 name = "db.browser_session.count",
424 skip_all,
425 fields(
426 db.query.text,
427 ),
428 err,
429 )]
430 async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
431 let (sql, arguments) = sea_query::Query::select()
432 .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
433 .from(UserSessions::Table)
434 .apply_filter(filter)
435 .build_sqlx(PostgresQueryBuilder);
436
437 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
438 .traced()
439 .fetch_one(&mut *self.conn)
440 .await?;
441
442 count
443 .try_into()
444 .map_err(DatabaseError::to_invalid_operation)
445 }
446
447 #[tracing::instrument(
448 name = "db.browser_session.authenticate_with_password",
449 skip_all,
450 fields(
451 db.query.text,
452 %user_session.id,
453 %user_password.id,
454 user_session_authentication.id,
455 ),
456 err,
457 )]
458 async fn authenticate_with_password(
459 &mut self,
460 rng: &mut (dyn RngCore + Send),
461 clock: &dyn Clock,
462 user_session: &BrowserSession,
463 user_password: &Password,
464 ) -> Result<Authentication, Self::Error> {
465 let created_at = clock.now();
466 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
467 tracing::Span::current().record(
468 "user_session_authentication.id",
469 tracing::field::display(id),
470 );
471
472 sqlx::query!(
473 r#"
474 INSERT INTO user_session_authentications
475 (user_session_authentication_id, user_session_id, created_at, user_password_id)
476 VALUES ($1, $2, $3, $4)
477 "#,
478 Uuid::from(id),
479 Uuid::from(user_session.id),
480 created_at,
481 Uuid::from(user_password.id),
482 )
483 .traced()
484 .execute(&mut *self.conn)
485 .await?;
486
487 Ok(Authentication {
488 id,
489 created_at,
490 authentication_method: AuthenticationMethod::Password {
491 user_password_id: user_password.id,
492 },
493 })
494 }
495
496 #[tracing::instrument(
497 name = "db.browser_session.authenticate_with_upstream",
498 skip_all,
499 fields(
500 db.query.text,
501 %user_session.id,
502 %upstream_oauth_session.id,
503 user_session_authentication.id,
504 ),
505 err,
506 )]
507 async fn authenticate_with_upstream(
508 &mut self,
509 rng: &mut (dyn RngCore + Send),
510 clock: &dyn Clock,
511 user_session: &BrowserSession,
512 upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
513 ) -> Result<Authentication, Self::Error> {
514 let created_at = clock.now();
515 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
516 tracing::Span::current().record(
517 "user_session_authentication.id",
518 tracing::field::display(id),
519 );
520
521 sqlx::query!(
522 r#"
523 INSERT INTO user_session_authentications
524 (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
525 VALUES ($1, $2, $3, $4)
526 "#,
527 Uuid::from(id),
528 Uuid::from(user_session.id),
529 created_at,
530 Uuid::from(upstream_oauth_session.id),
531 )
532 .traced()
533 .execute(&mut *self.conn)
534 .await?;
535
536 Ok(Authentication {
537 id,
538 created_at,
539 authentication_method: AuthenticationMethod::UpstreamOAuth2 {
540 upstream_oauth2_session_id: upstream_oauth_session.id,
541 },
542 })
543 }
544
545 #[tracing::instrument(
546 name = "db.browser_session.get_last_authentication",
547 skip_all,
548 fields(
549 db.query.text,
550 %user_session.id,
551 ),
552 err,
553 )]
554 async fn get_last_authentication(
555 &mut self,
556 user_session: &BrowserSession,
557 ) -> Result<Option<Authentication>, Self::Error> {
558 let authentication = sqlx::query_as!(
559 AuthenticationLookup,
560 r#"
561 SELECT user_session_authentication_id
562 , created_at
563 , user_password_id
564 , upstream_oauth_authorization_session_id
565 FROM user_session_authentications
566 WHERE user_session_id = $1
567 ORDER BY created_at DESC
568 LIMIT 1
569 "#,
570 Uuid::from(user_session.id),
571 )
572 .traced()
573 .fetch_optional(&mut *self.conn)
574 .await?;
575
576 let Some(authentication) = authentication else {
577 return Ok(None);
578 };
579
580 let authentication = Authentication::try_from(authentication)?;
581 Ok(Some(authentication))
582 }
583
584 #[tracing::instrument(
585 name = "db.browser_session.record_batch_activity",
586 skip_all,
587 fields(
588 db.query.text,
589 ),
590 err,
591 )]
592 async fn record_batch_activity(
593 &mut self,
594 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
595 ) -> Result<(), Self::Error> {
596 activities.sort_unstable();
599 let mut ids = Vec::with_capacity(activities.len());
600 let mut last_activities = Vec::with_capacity(activities.len());
601 let mut ips = Vec::with_capacity(activities.len());
602
603 for (id, last_activity, ip) in activities {
604 ids.push(Uuid::from(id));
605 last_activities.push(last_activity);
606 ips.push(ip);
607 }
608
609 let res = sqlx::query!(
610 r#"
611 UPDATE user_sessions
612 SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
613 , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
614 FROM (
615 SELECT *
616 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
617 AS t(user_session_id, last_active_at, last_active_ip)
618 ) AS t
619 WHERE user_sessions.user_session_id = t.user_session_id
620 "#,
621 &ids,
622 &last_activities,
623 &ips as &[Option<IpAddr>],
624 )
625 .traced()
626 .execute(&mut *self.conn)
627 .await?;
628
629 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
630
631 Ok(())
632 }
633
634 #[tracing::instrument(
635 name = "db.browser_session.cleanup_finished",
636 skip_all,
637 fields(
638 db.query.text,
639 since = since.map(tracing::field::display),
640 until = %until,
641 limit = limit,
642 ),
643 err,
644 )]
645 async fn cleanup_finished(
646 &mut self,
647 since: Option<DateTime<Utc>>,
648 until: DateTime<Utc>,
649 limit: usize,
650 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
651 let res = sqlx::query!(
652 r#"
653 WITH
654 to_delete AS (
655 SELECT user_session_id, finished_at
656 FROM user_sessions us
657 WHERE us.finished_at IS NOT NULL
658 AND ($1::timestamptz IS NULL OR us.finished_at >= $1)
659 AND us.finished_at < $2
660 -- Only delete if no oauth2_sessions reference this user_session
661 AND NOT EXISTS (
662 SELECT 1 FROM oauth2_sessions os
663 WHERE os.user_session_id = us.user_session_id
664 )
665 -- Only delete if no compat_sessions reference this user_session
666 AND NOT EXISTS (
667 SELECT 1 FROM compat_sessions cs
668 WHERE cs.user_session_id = us.user_session_id
669 )
670 ORDER BY us.finished_at ASC
671 LIMIT $3
672 FOR UPDATE OF us
673 ),
674 deleted_authentications AS (
675 DELETE FROM user_session_authentications USING to_delete
676 WHERE user_session_authentications.user_session_id = to_delete.user_session_id
677 ),
678 deleted_sessions AS (
679 DELETE FROM user_sessions USING to_delete
680 WHERE user_sessions.user_session_id = to_delete.user_session_id
681 RETURNING user_sessions.finished_at
682 )
683 SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
684 "#,
685 since,
686 until,
687 i64::try_from(limit).unwrap_or(i64::MAX),
688 )
689 .traced()
690 .fetch_one(&mut *self.conn)
691 .await?;
692
693 Ok((
694 res.count.try_into().unwrap_or(usize::MAX),
695 res.last_finished_at,
696 ))
697 }
698
699 #[tracing::instrument(
700 name = "db.browser_session.cleanup_inactive_ips",
701 skip_all,
702 fields(
703 db.query.text,
704 since = since.map(tracing::field::display),
705 threshold = %threshold,
706 limit = limit,
707 ),
708 err,
709 )]
710 async fn cleanup_inactive_ips(
711 &mut self,
712 since: Option<DateTime<Utc>>,
713 threshold: DateTime<Utc>,
714 limit: usize,
715 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
716 let res = sqlx::query!(
717 r#"
718 WITH to_update AS (
719 SELECT user_session_id, last_active_at
720 FROM user_sessions
721 WHERE last_active_ip IS NOT NULL
722 AND last_active_at IS NOT NULL
723 AND ($1::timestamptz IS NULL OR last_active_at >= $1)
724 AND last_active_at < $2
725 ORDER BY last_active_at ASC
726 LIMIT $3
727 FOR UPDATE
728 ),
729 updated AS (
730 UPDATE user_sessions
731 SET last_active_ip = NULL
732 FROM to_update
733 WHERE user_sessions.user_session_id = to_update.user_session_id
734 RETURNING user_sessions.last_active_at
735 )
736 SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
737 "#,
738 since,
739 threshold,
740 i64::try_from(limit).unwrap_or(i64::MAX),
741 )
742 .traced()
743 .fetch_one(&mut *self.conn)
744 .await?;
745
746 Ok((
747 res.count.try_into().unwrap_or(usize::MAX),
748 res.last_active_at,
749 ))
750 }
751}