1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Authentication, AuthenticationMethod, BrowserSession, Password,
13 UpstreamOAuthAuthorizationSession, User, UserAgent,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::StatementExt,
29 iden::{UserSessions, Users},
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
34pub struct PgBrowserSessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgBrowserSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// One row of the `user_sessions` / `users` join.
///
/// It is read in two places:
/// - `lookup`, through `sqlx::query_as!`;
/// - `list`, through sea-query, which aliases its selected columns with the
///   `SessionLookupIden` variants generated by `#[sea_query::enum_def]`.
///
/// Field names are the SQL column aliases. `user_session_*` fields come from
/// `user_sessions`; the remaining `user_*` fields come from `users`.
// Every field starts with `user_`, which `clippy::struct_field_names` would
// flag. The prefix is kept because the names must match the SQL aliases.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    /// `None` while the session is still active.
    user_session_finished_at: Option<DateTime<Utc>>,
    /// Raw user agent string; parsed with `UserAgent::parse` on conversion.
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
}
65
66impl TryFrom<SessionLookup> for BrowserSession {
67 type Error = DatabaseInconsistencyError;
68
69 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
70 let id = Ulid::from(value.user_id);
71 let user = User {
72 id,
73 username: value.user_username,
74 sub: id.to_string(),
75 created_at: value.user_created_at,
76 locked_at: value.user_locked_at,
77 deactivated_at: value.user_deactivated_at,
78 can_request_admin: value.user_can_request_admin,
79 };
80
81 Ok(BrowserSession {
82 id: value.user_session_id.into(),
83 user,
84 created_at: value.user_session_created_at,
85 finished_at: value.user_session_finished_at,
86 user_agent: value.user_session_user_agent.map(UserAgent::parse),
87 last_active_at: value.user_session_last_active_at,
88 last_active_ip: value.user_session_last_active_ip,
89 })
90 }
91}
92
/// One row of `user_session_authentications`, as read by
/// `get_last_authentication` through `sqlx::query_as!`.
///
/// At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` is expected to be set. The
/// `TryFrom` conversion into [`Authentication`] rejects rows where both are.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    /// Set when the session was authenticated with a password.
    user_password_id: Option<Uuid>,
    /// Set when the session was authenticated through an upstream OAuth
    /// authorization session.
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
99
100impl TryFrom<AuthenticationLookup> for Authentication {
101 type Error = DatabaseInconsistencyError;
102
103 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
104 let id = Ulid::from(value.user_session_authentication_id);
105 let authentication_method = match (
106 value.user_password_id.map(Into::into),
107 value
108 .upstream_oauth_authorization_session_id
109 .map(Into::into),
110 ) {
111 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
112 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
113 upstream_oauth2_session_id,
114 },
115 (None, None) => AuthenticationMethod::Unknown,
116 _ => {
117 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
118 }
119 };
120
121 Ok(Authentication {
122 id,
123 created_at: value.created_at,
124 authentication_method,
125 })
126 }
127}
128
129impl crate::filter::Filter for BrowserSessionFilter<'_> {
130 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
131 sea_query::Condition::all()
132 .add_option(self.user().map(|user| {
133 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
134 }))
135 .add_option(self.state().map(|state| {
136 if state.is_active() {
137 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
138 } else {
139 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
140 }
141 }))
142 .add_option(self.last_active_after().map(|last_active_after| {
143 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
144 }))
145 .add_option(self.last_active_before().map(|last_active_before| {
146 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147 }))
148 }
149}
150
151#[async_trait]
152impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
153 type Error = DatabaseError;
154
155 #[tracing::instrument(
156 name = "db.browser_session.lookup",
157 skip_all,
158 fields(
159 db.query.text,
160 user_session.id = %id,
161 ),
162 err,
163 )]
164 async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
165 let res = sqlx::query_as!(
166 SessionLookup,
167 r#"
168 SELECT s.user_session_id
169 , s.created_at AS "user_session_created_at"
170 , s.finished_at AS "user_session_finished_at"
171 , s.user_agent AS "user_session_user_agent"
172 , s.last_active_at AS "user_session_last_active_at"
173 , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
174 , u.user_id
175 , u.username AS "user_username"
176 , u.created_at AS "user_created_at"
177 , u.locked_at AS "user_locked_at"
178 , u.deactivated_at AS "user_deactivated_at"
179 , u.can_request_admin AS "user_can_request_admin"
180 FROM user_sessions s
181 INNER JOIN users u
182 USING (user_id)
183 WHERE s.user_session_id = $1
184 "#,
185 Uuid::from(id),
186 )
187 .traced()
188 .fetch_optional(&mut *self.conn)
189 .await?;
190
191 let Some(res) = res else { return Ok(None) };
192
193 Ok(Some(res.try_into()?))
194 }
195
196 #[tracing::instrument(
197 name = "db.browser_session.add",
198 skip_all,
199 fields(
200 db.query.text,
201 %user.id,
202 user_session.id,
203 ),
204 err,
205 )]
206 async fn add(
207 &mut self,
208 rng: &mut (dyn RngCore + Send),
209 clock: &dyn Clock,
210 user: &User,
211 user_agent: Option<UserAgent>,
212 ) -> Result<BrowserSession, Self::Error> {
213 let created_at = clock.now();
214 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
215 tracing::Span::current().record("user_session.id", tracing::field::display(id));
216
217 sqlx::query!(
218 r#"
219 INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
220 VALUES ($1, $2, $3, $4)
221 "#,
222 Uuid::from(id),
223 Uuid::from(user.id),
224 created_at,
225 user_agent.as_deref(),
226 )
227 .traced()
228 .execute(&mut *self.conn)
229 .await?;
230
231 let session = BrowserSession {
232 id,
233 user: user.clone(),
235 created_at,
236 finished_at: None,
237 user_agent,
238 last_active_at: None,
239 last_active_ip: None,
240 };
241
242 Ok(session)
243 }
244
245 #[tracing::instrument(
246 name = "db.browser_session.finish",
247 skip_all,
248 fields(
249 db.query.text,
250 %user_session.id,
251 ),
252 err,
253 )]
254 async fn finish(
255 &mut self,
256 clock: &dyn Clock,
257 mut user_session: BrowserSession,
258 ) -> Result<BrowserSession, Self::Error> {
259 let finished_at = clock.now();
260 let res = sqlx::query!(
261 r#"
262 UPDATE user_sessions
263 SET finished_at = $1
264 WHERE user_session_id = $2
265 "#,
266 finished_at,
267 Uuid::from(user_session.id),
268 )
269 .traced()
270 .execute(&mut *self.conn)
271 .await?;
272
273 user_session.finished_at = Some(finished_at);
274
275 DatabaseError::ensure_affected_rows(&res, 1)?;
276
277 Ok(user_session)
278 }
279
280 #[tracing::instrument(
281 name = "db.browser_session.finish_bulk",
282 skip_all,
283 fields(
284 db.query.text,
285 ),
286 err,
287 )]
288 async fn finish_bulk(
289 &mut self,
290 clock: &dyn Clock,
291 filter: BrowserSessionFilter<'_>,
292 ) -> Result<usize, Self::Error> {
293 let finished_at = clock.now();
294 let (sql, arguments) = sea_query::Query::update()
295 .table(UserSessions::Table)
296 .value(UserSessions::FinishedAt, finished_at)
297 .apply_filter(filter)
298 .build_sqlx(PostgresQueryBuilder);
299
300 let res = sqlx::query_with(&sql, arguments)
301 .traced()
302 .execute(&mut *self.conn)
303 .await?;
304
305 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
306 }
307
308 #[tracing::instrument(
309 name = "db.browser_session.list",
310 skip_all,
311 fields(
312 db.query.text,
313 ),
314 err,
315 )]
316 async fn list(
317 &mut self,
318 filter: BrowserSessionFilter<'_>,
319 pagination: Pagination,
320 ) -> Result<Page<BrowserSession>, Self::Error> {
321 let (sql, arguments) = sea_query::Query::select()
322 .expr_as(
323 Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
324 SessionLookupIden::UserSessionId,
325 )
326 .expr_as(
327 Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
328 SessionLookupIden::UserSessionCreatedAt,
329 )
330 .expr_as(
331 Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
332 SessionLookupIden::UserSessionFinishedAt,
333 )
334 .expr_as(
335 Expr::col((UserSessions::Table, UserSessions::UserAgent)),
336 SessionLookupIden::UserSessionUserAgent,
337 )
338 .expr_as(
339 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
340 SessionLookupIden::UserSessionLastActiveAt,
341 )
342 .expr_as(
343 Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
344 SessionLookupIden::UserSessionLastActiveIp,
345 )
346 .expr_as(
347 Expr::col((Users::Table, Users::UserId)),
348 SessionLookupIden::UserId,
349 )
350 .expr_as(
351 Expr::col((Users::Table, Users::Username)),
352 SessionLookupIden::UserUsername,
353 )
354 .expr_as(
355 Expr::col((Users::Table, Users::CreatedAt)),
356 SessionLookupIden::UserCreatedAt,
357 )
358 .expr_as(
359 Expr::col((Users::Table, Users::LockedAt)),
360 SessionLookupIden::UserLockedAt,
361 )
362 .expr_as(
363 Expr::col((Users::Table, Users::DeactivatedAt)),
364 SessionLookupIden::UserDeactivatedAt,
365 )
366 .expr_as(
367 Expr::col((Users::Table, Users::CanRequestAdmin)),
368 SessionLookupIden::UserCanRequestAdmin,
369 )
370 .from(UserSessions::Table)
371 .inner_join(
372 Users::Table,
373 Expr::col((UserSessions::Table, UserSessions::UserId))
374 .equals((Users::Table, Users::UserId)),
375 )
376 .apply_filter(filter)
377 .generate_pagination(
378 (UserSessions::Table, UserSessions::UserSessionId),
379 pagination,
380 )
381 .build_sqlx(PostgresQueryBuilder);
382
383 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
384 .traced()
385 .fetch_all(&mut *self.conn)
386 .await?;
387
388 let page = pagination
389 .process(edges)
390 .try_map(BrowserSession::try_from)?;
391
392 Ok(page)
393 }
394
395 #[tracing::instrument(
396 name = "db.browser_session.count",
397 skip_all,
398 fields(
399 db.query.text,
400 ),
401 err,
402 )]
403 async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
404 let (sql, arguments) = sea_query::Query::select()
405 .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
406 .from(UserSessions::Table)
407 .apply_filter(filter)
408 .build_sqlx(PostgresQueryBuilder);
409
410 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
411 .traced()
412 .fetch_one(&mut *self.conn)
413 .await?;
414
415 count
416 .try_into()
417 .map_err(DatabaseError::to_invalid_operation)
418 }
419
420 #[tracing::instrument(
421 name = "db.browser_session.authenticate_with_password",
422 skip_all,
423 fields(
424 db.query.text,
425 %user_session.id,
426 %user_password.id,
427 user_session_authentication.id,
428 ),
429 err,
430 )]
431 async fn authenticate_with_password(
432 &mut self,
433 rng: &mut (dyn RngCore + Send),
434 clock: &dyn Clock,
435 user_session: &BrowserSession,
436 user_password: &Password,
437 ) -> Result<Authentication, Self::Error> {
438 let created_at = clock.now();
439 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
440 tracing::Span::current().record(
441 "user_session_authentication.id",
442 tracing::field::display(id),
443 );
444
445 sqlx::query!(
446 r#"
447 INSERT INTO user_session_authentications
448 (user_session_authentication_id, user_session_id, created_at, user_password_id)
449 VALUES ($1, $2, $3, $4)
450 "#,
451 Uuid::from(id),
452 Uuid::from(user_session.id),
453 created_at,
454 Uuid::from(user_password.id),
455 )
456 .traced()
457 .execute(&mut *self.conn)
458 .await?;
459
460 Ok(Authentication {
461 id,
462 created_at,
463 authentication_method: AuthenticationMethod::Password {
464 user_password_id: user_password.id,
465 },
466 })
467 }
468
469 #[tracing::instrument(
470 name = "db.browser_session.authenticate_with_upstream",
471 skip_all,
472 fields(
473 db.query.text,
474 %user_session.id,
475 %upstream_oauth_session.id,
476 user_session_authentication.id,
477 ),
478 err,
479 )]
480 async fn authenticate_with_upstream(
481 &mut self,
482 rng: &mut (dyn RngCore + Send),
483 clock: &dyn Clock,
484 user_session: &BrowserSession,
485 upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
486 ) -> Result<Authentication, Self::Error> {
487 let created_at = clock.now();
488 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
489 tracing::Span::current().record(
490 "user_session_authentication.id",
491 tracing::field::display(id),
492 );
493
494 sqlx::query!(
495 r#"
496 INSERT INTO user_session_authentications
497 (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
498 VALUES ($1, $2, $3, $4)
499 "#,
500 Uuid::from(id),
501 Uuid::from(user_session.id),
502 created_at,
503 Uuid::from(upstream_oauth_session.id),
504 )
505 .traced()
506 .execute(&mut *self.conn)
507 .await?;
508
509 Ok(Authentication {
510 id,
511 created_at,
512 authentication_method: AuthenticationMethod::UpstreamOAuth2 {
513 upstream_oauth2_session_id: upstream_oauth_session.id,
514 },
515 })
516 }
517
518 #[tracing::instrument(
519 name = "db.browser_session.get_last_authentication",
520 skip_all,
521 fields(
522 db.query.text,
523 %user_session.id,
524 ),
525 err,
526 )]
527 async fn get_last_authentication(
528 &mut self,
529 user_session: &BrowserSession,
530 ) -> Result<Option<Authentication>, Self::Error> {
531 let authentication = sqlx::query_as!(
532 AuthenticationLookup,
533 r#"
534 SELECT user_session_authentication_id
535 , created_at
536 , user_password_id
537 , upstream_oauth_authorization_session_id
538 FROM user_session_authentications
539 WHERE user_session_id = $1
540 ORDER BY created_at DESC
541 LIMIT 1
542 "#,
543 Uuid::from(user_session.id),
544 )
545 .traced()
546 .fetch_optional(&mut *self.conn)
547 .await?;
548
549 let Some(authentication) = authentication else {
550 return Ok(None);
551 };
552
553 let authentication = Authentication::try_from(authentication)?;
554 Ok(Some(authentication))
555 }
556
557 #[tracing::instrument(
558 name = "db.browser_session.record_batch_activity",
559 skip_all,
560 fields(
561 db.query.text,
562 ),
563 err,
564 )]
565 async fn record_batch_activity(
566 &mut self,
567 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
568 ) -> Result<(), Self::Error> {
569 let mut ids = Vec::with_capacity(activity.len());
570 let mut last_activities = Vec::with_capacity(activity.len());
571 let mut ips = Vec::with_capacity(activity.len());
572
573 for (id, last_activity, ip) in activity {
574 ids.push(Uuid::from(id));
575 last_activities.push(last_activity);
576 ips.push(ip);
577 }
578
579 let res = sqlx::query!(
580 r#"
581 UPDATE user_sessions
582 SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
583 , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
584 FROM (
585 SELECT *
586 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
587 AS t(user_session_id, last_active_at, last_active_ip)
588 ) AS t
589 WHERE user_sessions.user_session_id = t.user_session_id
590 "#,
591 &ids,
592 &last_activities,
593 &ips as &[Option<IpAddr>],
594 )
595 .traced()
596 .execute(&mut *self.conn)
597 .await?;
598
599 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
600
601 Ok(())
602 }
603}