1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13 User, UserAgent,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt, StatementWithJoinsExt},
30 iden::{CompatSessions, CompatSsoLogins},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgCompatSessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
48struct CompatSessionLookup {
49 compat_session_id: Uuid,
50 device_id: Option<String>,
51 human_name: Option<String>,
52 user_id: Uuid,
53 user_session_id: Option<Uuid>,
54 created_at: DateTime<Utc>,
55 finished_at: Option<DateTime<Utc>>,
56 is_synapse_admin: bool,
57 user_agent: Option<String>,
58 last_active_at: Option<DateTime<Utc>>,
59 last_active_ip: Option<IpAddr>,
60}
61
62impl From<CompatSessionLookup> for CompatSession {
63 fn from(value: CompatSessionLookup) -> Self {
64 let id = value.compat_session_id.into();
65
66 let state = match value.finished_at {
67 None => CompatSessionState::Valid,
68 Some(finished_at) => CompatSessionState::Finished { finished_at },
69 };
70
71 CompatSession {
72 id,
73 state,
74 user_id: value.user_id.into(),
75 user_session_id: value.user_session_id.map(Ulid::from),
76 device: value.device_id.map(Device::from),
77 human_name: value.human_name,
78 created_at: value.created_at,
79 is_synapse_admin: value.is_synapse_admin,
80 user_agent: value.user_agent.map(UserAgent::parse),
81 last_active_at: value.last_active_at,
82 last_active_ip: value.last_active_ip,
83 }
84 }
85}
86
87#[derive(sqlx::FromRow)]
88#[enum_def]
89struct CompatSessionAndSsoLoginLookup {
90 compat_session_id: Uuid,
91 device_id: Option<String>,
92 human_name: Option<String>,
93 user_id: Uuid,
94 user_session_id: Option<Uuid>,
95 created_at: DateTime<Utc>,
96 finished_at: Option<DateTime<Utc>>,
97 is_synapse_admin: bool,
98 user_agent: Option<String>,
99 last_active_at: Option<DateTime<Utc>>,
100 last_active_ip: Option<IpAddr>,
101 compat_sso_login_id: Option<Uuid>,
102 compat_sso_login_token: Option<String>,
103 compat_sso_login_redirect_uri: Option<String>,
104 compat_sso_login_created_at: Option<DateTime<Utc>>,
105 compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
106 compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
107}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110 type Error = DatabaseInconsistencyError;
111
112 fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113 let id = value.compat_session_id.into();
114
115 let state = match value.finished_at {
116 None => CompatSessionState::Valid,
117 Some(finished_at) => CompatSessionState::Finished { finished_at },
118 };
119
120 let session = CompatSession {
121 id,
122 state,
123 user_id: value.user_id.into(),
124 device: value.device_id.map(Device::from),
125 human_name: value.human_name,
126 user_session_id: value.user_session_id.map(Ulid::from),
127 created_at: value.created_at,
128 is_synapse_admin: value.is_synapse_admin,
129 user_agent: value.user_agent.map(UserAgent::parse),
130 last_active_at: value.last_active_at,
131 last_active_ip: value.last_active_ip,
132 };
133
134 match (
135 value.compat_sso_login_id,
136 value.compat_sso_login_token,
137 value.compat_sso_login_redirect_uri,
138 value.compat_sso_login_created_at,
139 value.compat_sso_login_fulfilled_at,
140 value.compat_sso_login_exchanged_at,
141 ) {
142 (None, None, None, None, None, None) => Ok((session, None)),
143 (
144 Some(id),
145 Some(login_token),
146 Some(redirect_uri),
147 Some(created_at),
148 fulfilled_at,
149 exchanged_at,
150 ) => {
151 let id = id.into();
152 let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153 DatabaseInconsistencyError::on("compat_sso_logins")
154 .column("redirect_uri")
155 .row(id)
156 .source(e)
157 })?;
158
159 let state = match (fulfilled_at, exchanged_at) {
160 (Some(fulfilled_at), None) => CompatSsoLoginState::Fulfilled {
161 fulfilled_at,
162 session_id: session.id,
163 },
164 (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
165 fulfilled_at,
166 exchanged_at,
167 session_id: session.id,
168 },
169 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
170 };
171
172 let login = CompatSsoLogin {
173 id,
174 redirect_uri,
175 login_token,
176 created_at,
177 state,
178 };
179
180 Ok((session, Some(login)))
181 }
182 _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
183 }
184 }
185}
186
187impl Filter for CompatSessionFilter<'_> {
188 fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
189 sea_query::Condition::all()
190 .add_option(self.user().map(|user| {
191 Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
192 }))
193 .add_option(self.browser_session().map(|browser_session| {
194 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
195 .eq(Uuid::from(browser_session.id))
196 }))
197 .add_option(self.state().map(|state| {
198 if state.is_active() {
199 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
200 } else {
201 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
202 }
203 }))
204 .add_option(self.auth_type().map(|auth_type| {
205 if has_joins {
208 if auth_type.is_sso_login() {
209 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
210 .is_not_null()
211 } else {
212 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
213 .is_null()
214 }
215 } else {
216 let compat_sso_logins = Query::select()
220 .expr(Expr::col((
221 CompatSsoLogins::Table,
222 CompatSsoLogins::CompatSessionId,
223 )))
224 .from(CompatSsoLogins::Table)
225 .take();
226
227 if auth_type.is_sso_login() {
228 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
229 .eq(Expr::any(compat_sso_logins))
230 } else {
231 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
232 .ne(Expr::all(compat_sso_logins))
233 }
234 }
235 }))
236 .add_option(self.last_active_after().map(|last_active_after| {
237 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238 .gt(last_active_after)
239 }))
240 .add_option(self.last_active_before().map(|last_active_before| {
241 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
242 .lt(last_active_before)
243 }))
244 .add_option(self.device().map(|device| {
245 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
246 }))
247 }
248}
249
250#[async_trait]
251impl CompatSessionRepository for PgCompatSessionRepository<'_> {
252 type Error = DatabaseError;
253
254 #[tracing::instrument(
255 name = "db.compat_session.lookup",
256 skip_all,
257 fields(
258 db.query.text,
259 compat_session.id = %id,
260 ),
261 err,
262 )]
263 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
264 let res = sqlx::query_as!(
265 CompatSessionLookup,
266 r#"
267 SELECT compat_session_id
268 , device_id
269 , human_name
270 , user_id
271 , user_session_id
272 , created_at
273 , finished_at
274 , is_synapse_admin
275 , user_agent
276 , last_active_at
277 , last_active_ip as "last_active_ip: IpAddr"
278 FROM compat_sessions
279 WHERE compat_session_id = $1
280 "#,
281 Uuid::from(id),
282 )
283 .traced()
284 .fetch_optional(&mut *self.conn)
285 .await?;
286
287 let Some(res) = res else { return Ok(None) };
288
289 Ok(Some(res.into()))
290 }
291
292 #[tracing::instrument(
293 name = "db.compat_session.add",
294 skip_all,
295 fields(
296 db.query.text,
297 compat_session.id,
298 %user.id,
299 %user.username,
300 compat_session.device.id = device.as_str(),
301 ),
302 err,
303 )]
304 async fn add(
305 &mut self,
306 rng: &mut (dyn RngCore + Send),
307 clock: &dyn Clock,
308 user: &User,
309 device: Device,
310 browser_session: Option<&BrowserSession>,
311 is_synapse_admin: bool,
312 ) -> Result<CompatSession, Self::Error> {
313 let created_at = clock.now();
314 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
315 tracing::Span::current().record("compat_session.id", tracing::field::display(id));
316
317 sqlx::query!(
318 r#"
319 INSERT INTO compat_sessions
320 (compat_session_id, user_id, device_id,
321 user_session_id, created_at, is_synapse_admin)
322 VALUES ($1, $2, $3, $4, $5, $6)
323 "#,
324 Uuid::from(id),
325 Uuid::from(user.id),
326 device.as_str(),
327 browser_session.map(|s| Uuid::from(s.id)),
328 created_at,
329 is_synapse_admin,
330 )
331 .traced()
332 .execute(&mut *self.conn)
333 .await?;
334
335 Ok(CompatSession {
336 id,
337 state: CompatSessionState::default(),
338 user_id: user.id,
339 device: Some(device),
340 human_name: None,
341 user_session_id: browser_session.map(|s| s.id),
342 created_at,
343 is_synapse_admin,
344 user_agent: None,
345 last_active_at: None,
346 last_active_ip: None,
347 })
348 }
349
350 #[tracing::instrument(
351 name = "db.compat_session.finish",
352 skip_all,
353 fields(
354 db.query.text,
355 %compat_session.id,
356 user.id = %compat_session.user_id,
357 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
358 ),
359 err,
360 )]
361 async fn finish(
362 &mut self,
363 clock: &dyn Clock,
364 compat_session: CompatSession,
365 ) -> Result<CompatSession, Self::Error> {
366 let finished_at = clock.now();
367
368 let res = sqlx::query!(
369 r#"
370 UPDATE compat_sessions cs
371 SET finished_at = $2
372 WHERE compat_session_id = $1
373 "#,
374 Uuid::from(compat_session.id),
375 finished_at,
376 )
377 .traced()
378 .execute(&mut *self.conn)
379 .await?;
380
381 DatabaseError::ensure_affected_rows(&res, 1)?;
382
383 let compat_session = compat_session
384 .finish(finished_at)
385 .map_err(DatabaseError::to_invalid_operation)?;
386
387 Ok(compat_session)
388 }
389
390 #[tracing::instrument(
391 name = "db.compat_session.finish_bulk",
392 skip_all,
393 fields(db.query.text),
394 err,
395 )]
396 async fn finish_bulk(
397 &mut self,
398 clock: &dyn Clock,
399 filter: CompatSessionFilter<'_>,
400 ) -> Result<usize, Self::Error> {
401 let finished_at = clock.now();
402 let (sql, arguments) = Query::update()
403 .table(CompatSessions::Table)
404 .value(CompatSessions::FinishedAt, finished_at)
405 .apply_filter(filter)
406 .build_sqlx(PostgresQueryBuilder);
407
408 let res = sqlx::query_with(&sql, arguments)
409 .traced()
410 .execute(&mut *self.conn)
411 .await?;
412
413 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
414 }
415
416 #[tracing::instrument(
417 name = "db.compat_session.list",
418 skip_all,
419 fields(
420 db.query.text,
421 ),
422 err,
423 )]
424 async fn list(
425 &mut self,
426 filter: CompatSessionFilter<'_>,
427 pagination: Pagination,
428 ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
429 let (sql, arguments) = Query::select()
430 .expr_as(
431 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
432 CompatSessionAndSsoLoginLookupIden::CompatSessionId,
433 )
434 .expr_as(
435 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
436 CompatSessionAndSsoLoginLookupIden::DeviceId,
437 )
438 .expr_as(
439 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
440 CompatSessionAndSsoLoginLookupIden::HumanName,
441 )
442 .expr_as(
443 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
444 CompatSessionAndSsoLoginLookupIden::UserId,
445 )
446 .expr_as(
447 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
448 CompatSessionAndSsoLoginLookupIden::UserSessionId,
449 )
450 .expr_as(
451 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
452 CompatSessionAndSsoLoginLookupIden::CreatedAt,
453 )
454 .expr_as(
455 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
456 CompatSessionAndSsoLoginLookupIden::FinishedAt,
457 )
458 .expr_as(
459 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
460 CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
461 )
462 .expr_as(
463 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
464 CompatSessionAndSsoLoginLookupIden::UserAgent,
465 )
466 .expr_as(
467 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
468 CompatSessionAndSsoLoginLookupIden::LastActiveAt,
469 )
470 .expr_as(
471 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
472 CompatSessionAndSsoLoginLookupIden::LastActiveIp,
473 )
474 .expr_as(
475 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
476 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
477 )
478 .expr_as(
479 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
480 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
481 )
482 .expr_as(
483 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
484 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
485 )
486 .expr_as(
487 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
488 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
489 )
490 .expr_as(
491 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
492 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
493 )
494 .expr_as(
495 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
496 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
497 )
498 .from(CompatSessions::Table)
499 .left_join(
500 CompatSsoLogins::Table,
501 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
502 .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
503 )
504 .apply_filter_with_joins(filter)
505 .generate_pagination(
506 (CompatSessions::Table, CompatSessions::CompatSessionId),
507 pagination,
508 )
509 .build_sqlx(PostgresQueryBuilder);
510
511 let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
512 .traced()
513 .fetch_all(&mut *self.conn)
514 .await?;
515
516 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
517
518 Ok(page)
519 }
520
521 #[tracing::instrument(
522 name = "db.compat_session.count",
523 skip_all,
524 fields(
525 db.query.text,
526 ),
527 err,
528 )]
529 async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
530 let (sql, arguments) = sea_query::Query::select()
531 .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
532 .from(CompatSessions::Table)
533 .apply_filter(filter)
534 .build_sqlx(PostgresQueryBuilder);
535
536 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
537 .traced()
538 .fetch_one(&mut *self.conn)
539 .await?;
540
541 count
542 .try_into()
543 .map_err(DatabaseError::to_invalid_operation)
544 }
545
546 #[tracing::instrument(
547 name = "db.compat_session.record_batch_activity",
548 skip_all,
549 fields(
550 db.query.text,
551 ),
552 err,
553 )]
554 async fn record_batch_activity(
555 &mut self,
556 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
557 ) -> Result<(), Self::Error> {
558 let mut ids = Vec::with_capacity(activity.len());
559 let mut last_activities = Vec::with_capacity(activity.len());
560 let mut ips = Vec::with_capacity(activity.len());
561
562 for (id, last_activity, ip) in activity {
563 ids.push(Uuid::from(id));
564 last_activities.push(last_activity);
565 ips.push(ip);
566 }
567
568 let res = sqlx::query!(
569 r#"
570 UPDATE compat_sessions
571 SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
572 , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
573 FROM (
574 SELECT *
575 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
576 AS t(compat_session_id, last_active_at, last_active_ip)
577 ) AS t
578 WHERE compat_sessions.compat_session_id = t.compat_session_id
579 "#,
580 &ids,
581 &last_activities,
582 &ips as &[Option<IpAddr>],
583 )
584 .traced()
585 .execute(&mut *self.conn)
586 .await?;
587
588 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
589
590 Ok(())
591 }
592
593 #[tracing::instrument(
594 name = "db.compat_session.record_user_agent",
595 skip_all,
596 fields(
597 db.query.text,
598 %compat_session.id,
599 ),
600 err,
601 )]
602 async fn record_user_agent(
603 &mut self,
604 mut compat_session: CompatSession,
605 user_agent: UserAgent,
606 ) -> Result<CompatSession, Self::Error> {
607 let res = sqlx::query!(
608 r#"
609 UPDATE compat_sessions
610 SET user_agent = $2
611 WHERE compat_session_id = $1
612 "#,
613 Uuid::from(compat_session.id),
614 &*user_agent,
615 )
616 .traced()
617 .execute(&mut *self.conn)
618 .await?;
619
620 compat_session.user_agent = Some(user_agent);
621
622 DatabaseError::ensure_affected_rows(&res, 1)?;
623
624 Ok(compat_session)
625 }
626}