mas_storage_pg/compat/session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13    User, UserAgent,
14};
15use mas_storage::{
16    Clock, Page, Pagination,
17    compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28    DatabaseError, DatabaseInconsistencyError,
29    filter::{Filter, StatementExt, StatementWithJoinsExt},
30    iden::{CompatSessions, CompatSsoLogins},
31    pagination::QueryBuilderExt,
32    tracing::ExecuteExt,
33};
34
35/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
36pub struct PgCompatSessionRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41    /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
48struct CompatSessionLookup {
49    compat_session_id: Uuid,
50    device_id: Option<String>,
51    human_name: Option<String>,
52    user_id: Uuid,
53    user_session_id: Option<Uuid>,
54    created_at: DateTime<Utc>,
55    finished_at: Option<DateTime<Utc>>,
56    is_synapse_admin: bool,
57    user_agent: Option<String>,
58    last_active_at: Option<DateTime<Utc>>,
59    last_active_ip: Option<IpAddr>,
60}
61
62impl From<CompatSessionLookup> for CompatSession {
63    fn from(value: CompatSessionLookup) -> Self {
64        let id = value.compat_session_id.into();
65
66        let state = match value.finished_at {
67            None => CompatSessionState::Valid,
68            Some(finished_at) => CompatSessionState::Finished { finished_at },
69        };
70
71        CompatSession {
72            id,
73            state,
74            user_id: value.user_id.into(),
75            user_session_id: value.user_session_id.map(Ulid::from),
76            device: value.device_id.map(Device::from),
77            human_name: value.human_name,
78            created_at: value.created_at,
79            is_synapse_admin: value.is_synapse_admin,
80            user_agent: value.user_agent.map(UserAgent::parse),
81            last_active_at: value.last_active_at,
82            last_active_ip: value.last_active_ip,
83        }
84    }
85}
86
/// Row shape for the paginated `list` query: every selected `compat_sessions`
/// column, plus the columns of an optional LEFT JOINed `compat_sso_logins`
/// row.
///
/// `#[enum_def]` generates `CompatSessionAndSsoLoginLookupIden`, whose
/// variants are used as the column aliases in `list`. Each alias therefore
/// matches a field name here, which lets `sqlx::FromRow` map result rows back
/// onto this struct.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup {
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // The `compat_sso_login_*` fields come from the LEFT JOIN: they are all
    // NULL when no SSO login is linked to the session.
    compat_sso_login_id: Option<Uuid>,
    compat_sso_login_token: Option<String>,
    compat_sso_login_redirect_uri: Option<String>,
    compat_sso_login_created_at: Option<DateTime<Utc>>,
    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110    type Error = DatabaseInconsistencyError;
111
112    fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113        let id = value.compat_session_id.into();
114
115        let state = match value.finished_at {
116            None => CompatSessionState::Valid,
117            Some(finished_at) => CompatSessionState::Finished { finished_at },
118        };
119
120        let session = CompatSession {
121            id,
122            state,
123            user_id: value.user_id.into(),
124            device: value.device_id.map(Device::from),
125            human_name: value.human_name,
126            user_session_id: value.user_session_id.map(Ulid::from),
127            created_at: value.created_at,
128            is_synapse_admin: value.is_synapse_admin,
129            user_agent: value.user_agent.map(UserAgent::parse),
130            last_active_at: value.last_active_at,
131            last_active_ip: value.last_active_ip,
132        };
133
134        match (
135            value.compat_sso_login_id,
136            value.compat_sso_login_token,
137            value.compat_sso_login_redirect_uri,
138            value.compat_sso_login_created_at,
139            value.compat_sso_login_fulfilled_at,
140            value.compat_sso_login_exchanged_at,
141        ) {
142            (None, None, None, None, None, None) => Ok((session, None)),
143            (
144                Some(id),
145                Some(login_token),
146                Some(redirect_uri),
147                Some(created_at),
148                fulfilled_at,
149                exchanged_at,
150            ) => {
151                let id = id.into();
152                let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153                    DatabaseInconsistencyError::on("compat_sso_logins")
154                        .column("redirect_uri")
155                        .row(id)
156                        .source(e)
157                })?;
158
159                let state = match (fulfilled_at, exchanged_at) {
160                    (Some(fulfilled_at), None) => CompatSsoLoginState::Fulfilled {
161                        fulfilled_at,
162                        session_id: session.id,
163                    },
164                    (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
165                        fulfilled_at,
166                        exchanged_at,
167                        session_id: session.id,
168                    },
169                    _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
170                };
171
172                let login = CompatSsoLogin {
173                    id,
174                    redirect_uri,
175                    login_token,
176                    created_at,
177                    state,
178                };
179
180                Ok((session, Some(login)))
181            }
182            _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
183        }
184    }
185}
186
187impl Filter for CompatSessionFilter<'_> {
188    fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
189        sea_query::Condition::all()
190            .add_option(self.user().map(|user| {
191                Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
192            }))
193            .add_option(self.browser_session().map(|browser_session| {
194                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
195                    .eq(Uuid::from(browser_session.id))
196            }))
197            .add_option(self.state().map(|state| {
198                if state.is_active() {
199                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
200                } else {
201                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
202                }
203            }))
204            .add_option(self.auth_type().map(|auth_type| {
205                // In in the SELECT to list sessions, we can rely on the JOINed table, whereas
206                // in other queries we need to do a subquery
207                if has_joins {
208                    if auth_type.is_sso_login() {
209                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
210                            .is_not_null()
211                    } else {
212                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
213                            .is_null()
214                    }
215                } else {
216                    // This builds either a:
217                    // `WHERE compat_session_id = ANY(...)`
218                    // or a `WHERE compat_session_id <> ALL(...)`
219                    let compat_sso_logins = Query::select()
220                        .expr(Expr::col((
221                            CompatSsoLogins::Table,
222                            CompatSsoLogins::CompatSessionId,
223                        )))
224                        .from(CompatSsoLogins::Table)
225                        .take();
226
227                    if auth_type.is_sso_login() {
228                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
229                            .eq(Expr::any(compat_sso_logins))
230                    } else {
231                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
232                            .ne(Expr::all(compat_sso_logins))
233                    }
234                }
235            }))
236            .add_option(self.last_active_after().map(|last_active_after| {
237                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238                    .gt(last_active_after)
239            }))
240            .add_option(self.last_active_before().map(|last_active_before| {
241                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
242                    .lt(last_active_before)
243            }))
244            .add_option(self.device().map(|device| {
245                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
246            }))
247    }
248}
249
250#[async_trait]
251impl CompatSessionRepository for PgCompatSessionRepository<'_> {
252    type Error = DatabaseError;
253
254    #[tracing::instrument(
255        name = "db.compat_session.lookup",
256        skip_all,
257        fields(
258            db.query.text,
259            compat_session.id = %id,
260        ),
261        err,
262    )]
263    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
264        let res = sqlx::query_as!(
265            CompatSessionLookup,
266            r#"
267                SELECT compat_session_id
268                     , device_id
269                     , human_name
270                     , user_id
271                     , user_session_id
272                     , created_at
273                     , finished_at
274                     , is_synapse_admin
275                     , user_agent
276                     , last_active_at
277                     , last_active_ip as "last_active_ip: IpAddr"
278                FROM compat_sessions
279                WHERE compat_session_id = $1
280            "#,
281            Uuid::from(id),
282        )
283        .traced()
284        .fetch_optional(&mut *self.conn)
285        .await?;
286
287        let Some(res) = res else { return Ok(None) };
288
289        Ok(Some(res.into()))
290    }
291
292    #[tracing::instrument(
293        name = "db.compat_session.add",
294        skip_all,
295        fields(
296            db.query.text,
297            compat_session.id,
298            %user.id,
299            %user.username,
300            compat_session.device.id = device.as_str(),
301        ),
302        err,
303    )]
304    async fn add(
305        &mut self,
306        rng: &mut (dyn RngCore + Send),
307        clock: &dyn Clock,
308        user: &User,
309        device: Device,
310        browser_session: Option<&BrowserSession>,
311        is_synapse_admin: bool,
312    ) -> Result<CompatSession, Self::Error> {
313        let created_at = clock.now();
314        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
315        tracing::Span::current().record("compat_session.id", tracing::field::display(id));
316
317        sqlx::query!(
318            r#"
319                INSERT INTO compat_sessions
320                    (compat_session_id, user_id, device_id,
321                     user_session_id, created_at, is_synapse_admin)
322                VALUES ($1, $2, $3, $4, $5, $6)
323            "#,
324            Uuid::from(id),
325            Uuid::from(user.id),
326            device.as_str(),
327            browser_session.map(|s| Uuid::from(s.id)),
328            created_at,
329            is_synapse_admin,
330        )
331        .traced()
332        .execute(&mut *self.conn)
333        .await?;
334
335        Ok(CompatSession {
336            id,
337            state: CompatSessionState::default(),
338            user_id: user.id,
339            device: Some(device),
340            human_name: None,
341            user_session_id: browser_session.map(|s| s.id),
342            created_at,
343            is_synapse_admin,
344            user_agent: None,
345            last_active_at: None,
346            last_active_ip: None,
347        })
348    }
349
350    #[tracing::instrument(
351        name = "db.compat_session.finish",
352        skip_all,
353        fields(
354            db.query.text,
355            %compat_session.id,
356            user.id = %compat_session.user_id,
357            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
358        ),
359        err,
360    )]
361    async fn finish(
362        &mut self,
363        clock: &dyn Clock,
364        compat_session: CompatSession,
365    ) -> Result<CompatSession, Self::Error> {
366        let finished_at = clock.now();
367
368        let res = sqlx::query!(
369            r#"
370                UPDATE compat_sessions cs
371                SET finished_at = $2
372                WHERE compat_session_id = $1
373            "#,
374            Uuid::from(compat_session.id),
375            finished_at,
376        )
377        .traced()
378        .execute(&mut *self.conn)
379        .await?;
380
381        DatabaseError::ensure_affected_rows(&res, 1)?;
382
383        let compat_session = compat_session
384            .finish(finished_at)
385            .map_err(DatabaseError::to_invalid_operation)?;
386
387        Ok(compat_session)
388    }
389
390    #[tracing::instrument(
391        name = "db.compat_session.finish_bulk",
392        skip_all,
393        fields(db.query.text),
394        err,
395    )]
396    async fn finish_bulk(
397        &mut self,
398        clock: &dyn Clock,
399        filter: CompatSessionFilter<'_>,
400    ) -> Result<usize, Self::Error> {
401        let finished_at = clock.now();
402        let (sql, arguments) = Query::update()
403            .table(CompatSessions::Table)
404            .value(CompatSessions::FinishedAt, finished_at)
405            .apply_filter(filter)
406            .build_sqlx(PostgresQueryBuilder);
407
408        let res = sqlx::query_with(&sql, arguments)
409            .traced()
410            .execute(&mut *self.conn)
411            .await?;
412
413        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
414    }
415
416    #[tracing::instrument(
417        name = "db.compat_session.list",
418        skip_all,
419        fields(
420            db.query.text,
421        ),
422        err,
423    )]
424    async fn list(
425        &mut self,
426        filter: CompatSessionFilter<'_>,
427        pagination: Pagination,
428    ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
429        let (sql, arguments) = Query::select()
430            .expr_as(
431                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
432                CompatSessionAndSsoLoginLookupIden::CompatSessionId,
433            )
434            .expr_as(
435                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
436                CompatSessionAndSsoLoginLookupIden::DeviceId,
437            )
438            .expr_as(
439                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
440                CompatSessionAndSsoLoginLookupIden::HumanName,
441            )
442            .expr_as(
443                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
444                CompatSessionAndSsoLoginLookupIden::UserId,
445            )
446            .expr_as(
447                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
448                CompatSessionAndSsoLoginLookupIden::UserSessionId,
449            )
450            .expr_as(
451                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
452                CompatSessionAndSsoLoginLookupIden::CreatedAt,
453            )
454            .expr_as(
455                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
456                CompatSessionAndSsoLoginLookupIden::FinishedAt,
457            )
458            .expr_as(
459                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
460                CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
461            )
462            .expr_as(
463                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
464                CompatSessionAndSsoLoginLookupIden::UserAgent,
465            )
466            .expr_as(
467                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
468                CompatSessionAndSsoLoginLookupIden::LastActiveAt,
469            )
470            .expr_as(
471                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
472                CompatSessionAndSsoLoginLookupIden::LastActiveIp,
473            )
474            .expr_as(
475                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
476                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
477            )
478            .expr_as(
479                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
480                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
481            )
482            .expr_as(
483                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
484                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
485            )
486            .expr_as(
487                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
488                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
489            )
490            .expr_as(
491                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
492                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
493            )
494            .expr_as(
495                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
496                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
497            )
498            .from(CompatSessions::Table)
499            .left_join(
500                CompatSsoLogins::Table,
501                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
502                    .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
503            )
504            .apply_filter_with_joins(filter)
505            .generate_pagination(
506                (CompatSessions::Table, CompatSessions::CompatSessionId),
507                pagination,
508            )
509            .build_sqlx(PostgresQueryBuilder);
510
511        let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
512            .traced()
513            .fetch_all(&mut *self.conn)
514            .await?;
515
516        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
517
518        Ok(page)
519    }
520
521    #[tracing::instrument(
522        name = "db.compat_session.count",
523        skip_all,
524        fields(
525            db.query.text,
526        ),
527        err,
528    )]
529    async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
530        let (sql, arguments) = sea_query::Query::select()
531            .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
532            .from(CompatSessions::Table)
533            .apply_filter(filter)
534            .build_sqlx(PostgresQueryBuilder);
535
536        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
537            .traced()
538            .fetch_one(&mut *self.conn)
539            .await?;
540
541        count
542            .try_into()
543            .map_err(DatabaseError::to_invalid_operation)
544    }
545
546    #[tracing::instrument(
547        name = "db.compat_session.record_batch_activity",
548        skip_all,
549        fields(
550            db.query.text,
551        ),
552        err,
553    )]
554    async fn record_batch_activity(
555        &mut self,
556        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
557    ) -> Result<(), Self::Error> {
558        let mut ids = Vec::with_capacity(activity.len());
559        let mut last_activities = Vec::with_capacity(activity.len());
560        let mut ips = Vec::with_capacity(activity.len());
561
562        for (id, last_activity, ip) in activity {
563            ids.push(Uuid::from(id));
564            last_activities.push(last_activity);
565            ips.push(ip);
566        }
567
568        let res = sqlx::query!(
569            r#"
570                UPDATE compat_sessions
571                SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
572                  , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
573                FROM (
574                    SELECT *
575                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
576                        AS t(compat_session_id, last_active_at, last_active_ip)
577                ) AS t
578                WHERE compat_sessions.compat_session_id = t.compat_session_id
579            "#,
580            &ids,
581            &last_activities,
582            &ips as &[Option<IpAddr>],
583        )
584        .traced()
585        .execute(&mut *self.conn)
586        .await?;
587
588        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
589
590        Ok(())
591    }
592
593    #[tracing::instrument(
594        name = "db.compat_session.record_user_agent",
595        skip_all,
596        fields(
597            db.query.text,
598            %compat_session.id,
599        ),
600        err,
601    )]
602    async fn record_user_agent(
603        &mut self,
604        mut compat_session: CompatSession,
605        user_agent: UserAgent,
606    ) -> Result<CompatSession, Self::Error> {
607        let res = sqlx::query!(
608            r#"
609            UPDATE compat_sessions
610            SET user_agent = $2
611            WHERE compat_session_id = $1
612        "#,
613            Uuid::from(compat_session.id),
614            &*user_agent,
615        )
616        .traced()
617        .execute(&mut *self.conn)
618        .await?;
619
620        compat_session.user_agent = Some(user_agent);
621
622        DatabaseError::ensure_affected_rows(&res, 1)?;
623
624        Ok(compat_session)
625    }
626}