mas_storage_pg/user/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Authentication, AuthenticationMethod, BrowserSession, Password,
13    UpstreamOAuthAuthorizationSession, User, UserAgent,
14};
15use mas_storage::{
16    Clock, Page, Pagination,
17    user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError, DatabaseInconsistencyError,
28    filter::StatementExt,
29    iden::{UserSessions, Users},
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
34/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
35/// connection
36pub struct PgBrowserSessionRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgBrowserSessionRepository<'c> {
41    /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// Flat row produced by the browser session queries: the columns of a
/// `user_sessions` row (prefixed with `user_session_`) next to the columns of
/// the `users` row it is joined with (prefixed with `user_`).
///
/// The `enum_def` attribute generates the `SessionLookupIden` identifiers,
/// which the dynamically-built `list` query uses as column aliases.
// NOTE(review): the lint is presumably silenced because every field shares
// the `user_` prefix — confirm.
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns coming from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns coming from `users`
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
}
65
66impl TryFrom<SessionLookup> for BrowserSession {
67    type Error = DatabaseInconsistencyError;
68
69    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
70        let id = Ulid::from(value.user_id);
71        let user = User {
72            id,
73            username: value.user_username,
74            sub: id.to_string(),
75            created_at: value.user_created_at,
76            locked_at: value.user_locked_at,
77            deactivated_at: value.user_deactivated_at,
78            can_request_admin: value.user_can_request_admin,
79        };
80
81        Ok(BrowserSession {
82            id: value.user_session_id.into(),
83            user,
84            created_at: value.user_session_created_at,
85            finished_at: value.user_session_finished_at,
86            user_agent: value.user_session_user_agent.map(UserAgent::parse),
87            last_active_at: value.user_session_last_active_at,
88            last_active_ip: value.user_session_last_active_ip,
89        })
90    }
91}
92
/// Row selected from `user_session_authentications` when looking up the last
/// authentication of a session.
///
/// The two optional IDs identify how the authentication happened; the
/// conversion into [`Authentication`] rejects rows where both are set.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    // Set when the authentication was done with a password
    user_password_id: Option<Uuid>,
    // Set when the authentication was done through an upstream OAuth 2.0
    // authorization session
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
99
100impl TryFrom<AuthenticationLookup> for Authentication {
101    type Error = DatabaseInconsistencyError;
102
103    fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
104        let id = Ulid::from(value.user_session_authentication_id);
105        let authentication_method = match (
106            value.user_password_id.map(Into::into),
107            value
108                .upstream_oauth_authorization_session_id
109                .map(Into::into),
110        ) {
111            (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
112            (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
113                upstream_oauth2_session_id,
114            },
115            (None, None) => AuthenticationMethod::Unknown,
116            _ => {
117                return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
118            }
119        };
120
121        Ok(Authentication {
122            id,
123            created_at: value.created_at,
124            authentication_method,
125        })
126    }
127}
128
129impl crate::filter::Filter for BrowserSessionFilter<'_> {
130    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
131        sea_query::Condition::all()
132            .add_option(self.user().map(|user| {
133                Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
134            }))
135            .add_option(self.state().map(|state| {
136                if state.is_active() {
137                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
138                } else {
139                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
140                }
141            }))
142            .add_option(self.last_active_after().map(|last_active_after| {
143                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
144            }))
145            .add_option(self.last_active_before().map(|last_active_before| {
146                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147            }))
148    }
149}
150
151#[async_trait]
152impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
153    type Error = DatabaseError;
154
155    #[tracing::instrument(
156        name = "db.browser_session.lookup",
157        skip_all,
158        fields(
159            db.query.text,
160            user_session.id = %id,
161        ),
162        err,
163    )]
164    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
165        let res = sqlx::query_as!(
166            SessionLookup,
167            r#"
168                SELECT s.user_session_id
169                     , s.created_at            AS "user_session_created_at"
170                     , s.finished_at           AS "user_session_finished_at"
171                     , s.user_agent            AS "user_session_user_agent"
172                     , s.last_active_at        AS "user_session_last_active_at"
173                     , s.last_active_ip        AS "user_session_last_active_ip: IpAddr"
174                     , u.user_id
175                     , u.username              AS "user_username"
176                     , u.created_at            AS "user_created_at"
177                     , u.locked_at             AS "user_locked_at"
178                     , u.deactivated_at        AS "user_deactivated_at"
179                     , u.can_request_admin     AS "user_can_request_admin"
180                FROM user_sessions s
181                INNER JOIN users u
182                    USING (user_id)
183                WHERE s.user_session_id = $1
184            "#,
185            Uuid::from(id),
186        )
187        .traced()
188        .fetch_optional(&mut *self.conn)
189        .await?;
190
191        let Some(res) = res else { return Ok(None) };
192
193        Ok(Some(res.try_into()?))
194    }
195
196    #[tracing::instrument(
197        name = "db.browser_session.add",
198        skip_all,
199        fields(
200            db.query.text,
201            %user.id,
202            user_session.id,
203        ),
204        err,
205    )]
206    async fn add(
207        &mut self,
208        rng: &mut (dyn RngCore + Send),
209        clock: &dyn Clock,
210        user: &User,
211        user_agent: Option<UserAgent>,
212    ) -> Result<BrowserSession, Self::Error> {
213        let created_at = clock.now();
214        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
215        tracing::Span::current().record("user_session.id", tracing::field::display(id));
216
217        sqlx::query!(
218            r#"
219                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
220                VALUES ($1, $2, $3, $4)
221            "#,
222            Uuid::from(id),
223            Uuid::from(user.id),
224            created_at,
225            user_agent.as_deref(),
226        )
227        .traced()
228        .execute(&mut *self.conn)
229        .await?;
230
231        let session = BrowserSession {
232            id,
233            // XXX
234            user: user.clone(),
235            created_at,
236            finished_at: None,
237            user_agent,
238            last_active_at: None,
239            last_active_ip: None,
240        };
241
242        Ok(session)
243    }
244
245    #[tracing::instrument(
246        name = "db.browser_session.finish",
247        skip_all,
248        fields(
249            db.query.text,
250            %user_session.id,
251        ),
252        err,
253    )]
254    async fn finish(
255        &mut self,
256        clock: &dyn Clock,
257        mut user_session: BrowserSession,
258    ) -> Result<BrowserSession, Self::Error> {
259        let finished_at = clock.now();
260        let res = sqlx::query!(
261            r#"
262                UPDATE user_sessions
263                SET finished_at = $1
264                WHERE user_session_id = $2
265            "#,
266            finished_at,
267            Uuid::from(user_session.id),
268        )
269        .traced()
270        .execute(&mut *self.conn)
271        .await?;
272
273        user_session.finished_at = Some(finished_at);
274
275        DatabaseError::ensure_affected_rows(&res, 1)?;
276
277        Ok(user_session)
278    }
279
280    #[tracing::instrument(
281        name = "db.browser_session.finish_bulk",
282        skip_all,
283        fields(
284            db.query.text,
285        ),
286        err,
287    )]
288    async fn finish_bulk(
289        &mut self,
290        clock: &dyn Clock,
291        filter: BrowserSessionFilter<'_>,
292    ) -> Result<usize, Self::Error> {
293        let finished_at = clock.now();
294        let (sql, arguments) = sea_query::Query::update()
295            .table(UserSessions::Table)
296            .value(UserSessions::FinishedAt, finished_at)
297            .apply_filter(filter)
298            .build_sqlx(PostgresQueryBuilder);
299
300        let res = sqlx::query_with(&sql, arguments)
301            .traced()
302            .execute(&mut *self.conn)
303            .await?;
304
305        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
306    }
307
308    #[tracing::instrument(
309        name = "db.browser_session.list",
310        skip_all,
311        fields(
312            db.query.text,
313        ),
314        err,
315    )]
316    async fn list(
317        &mut self,
318        filter: BrowserSessionFilter<'_>,
319        pagination: Pagination,
320    ) -> Result<Page<BrowserSession>, Self::Error> {
321        let (sql, arguments) = sea_query::Query::select()
322            .expr_as(
323                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
324                SessionLookupIden::UserSessionId,
325            )
326            .expr_as(
327                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
328                SessionLookupIden::UserSessionCreatedAt,
329            )
330            .expr_as(
331                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
332                SessionLookupIden::UserSessionFinishedAt,
333            )
334            .expr_as(
335                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
336                SessionLookupIden::UserSessionUserAgent,
337            )
338            .expr_as(
339                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
340                SessionLookupIden::UserSessionLastActiveAt,
341            )
342            .expr_as(
343                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
344                SessionLookupIden::UserSessionLastActiveIp,
345            )
346            .expr_as(
347                Expr::col((Users::Table, Users::UserId)),
348                SessionLookupIden::UserId,
349            )
350            .expr_as(
351                Expr::col((Users::Table, Users::Username)),
352                SessionLookupIden::UserUsername,
353            )
354            .expr_as(
355                Expr::col((Users::Table, Users::CreatedAt)),
356                SessionLookupIden::UserCreatedAt,
357            )
358            .expr_as(
359                Expr::col((Users::Table, Users::LockedAt)),
360                SessionLookupIden::UserLockedAt,
361            )
362            .expr_as(
363                Expr::col((Users::Table, Users::DeactivatedAt)),
364                SessionLookupIden::UserDeactivatedAt,
365            )
366            .expr_as(
367                Expr::col((Users::Table, Users::CanRequestAdmin)),
368                SessionLookupIden::UserCanRequestAdmin,
369            )
370            .from(UserSessions::Table)
371            .inner_join(
372                Users::Table,
373                Expr::col((UserSessions::Table, UserSessions::UserId))
374                    .equals((Users::Table, Users::UserId)),
375            )
376            .apply_filter(filter)
377            .generate_pagination(
378                (UserSessions::Table, UserSessions::UserSessionId),
379                pagination,
380            )
381            .build_sqlx(PostgresQueryBuilder);
382
383        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
384            .traced()
385            .fetch_all(&mut *self.conn)
386            .await?;
387
388        let page = pagination
389            .process(edges)
390            .try_map(BrowserSession::try_from)?;
391
392        Ok(page)
393    }
394
395    #[tracing::instrument(
396        name = "db.browser_session.count",
397        skip_all,
398        fields(
399            db.query.text,
400        ),
401        err,
402    )]
403    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
404        let (sql, arguments) = sea_query::Query::select()
405            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
406            .from(UserSessions::Table)
407            .apply_filter(filter)
408            .build_sqlx(PostgresQueryBuilder);
409
410        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
411            .traced()
412            .fetch_one(&mut *self.conn)
413            .await?;
414
415        count
416            .try_into()
417            .map_err(DatabaseError::to_invalid_operation)
418    }
419
420    #[tracing::instrument(
421        name = "db.browser_session.authenticate_with_password",
422        skip_all,
423        fields(
424            db.query.text,
425            %user_session.id,
426            %user_password.id,
427            user_session_authentication.id,
428        ),
429        err,
430    )]
431    async fn authenticate_with_password(
432        &mut self,
433        rng: &mut (dyn RngCore + Send),
434        clock: &dyn Clock,
435        user_session: &BrowserSession,
436        user_password: &Password,
437    ) -> Result<Authentication, Self::Error> {
438        let created_at = clock.now();
439        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
440        tracing::Span::current().record(
441            "user_session_authentication.id",
442            tracing::field::display(id),
443        );
444
445        sqlx::query!(
446            r#"
447                INSERT INTO user_session_authentications
448                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
449                VALUES ($1, $2, $3, $4)
450            "#,
451            Uuid::from(id),
452            Uuid::from(user_session.id),
453            created_at,
454            Uuid::from(user_password.id),
455        )
456        .traced()
457        .execute(&mut *self.conn)
458        .await?;
459
460        Ok(Authentication {
461            id,
462            created_at,
463            authentication_method: AuthenticationMethod::Password {
464                user_password_id: user_password.id,
465            },
466        })
467    }
468
469    #[tracing::instrument(
470        name = "db.browser_session.authenticate_with_upstream",
471        skip_all,
472        fields(
473            db.query.text,
474            %user_session.id,
475            %upstream_oauth_session.id,
476            user_session_authentication.id,
477        ),
478        err,
479    )]
480    async fn authenticate_with_upstream(
481        &mut self,
482        rng: &mut (dyn RngCore + Send),
483        clock: &dyn Clock,
484        user_session: &BrowserSession,
485        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
486    ) -> Result<Authentication, Self::Error> {
487        let created_at = clock.now();
488        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
489        tracing::Span::current().record(
490            "user_session_authentication.id",
491            tracing::field::display(id),
492        );
493
494        sqlx::query!(
495            r#"
496                INSERT INTO user_session_authentications
497                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
498                VALUES ($1, $2, $3, $4)
499            "#,
500            Uuid::from(id),
501            Uuid::from(user_session.id),
502            created_at,
503            Uuid::from(upstream_oauth_session.id),
504        )
505        .traced()
506        .execute(&mut *self.conn)
507        .await?;
508
509        Ok(Authentication {
510            id,
511            created_at,
512            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
513                upstream_oauth2_session_id: upstream_oauth_session.id,
514            },
515        })
516    }
517
518    #[tracing::instrument(
519        name = "db.browser_session.get_last_authentication",
520        skip_all,
521        fields(
522            db.query.text,
523            %user_session.id,
524        ),
525        err,
526    )]
527    async fn get_last_authentication(
528        &mut self,
529        user_session: &BrowserSession,
530    ) -> Result<Option<Authentication>, Self::Error> {
531        let authentication = sqlx::query_as!(
532            AuthenticationLookup,
533            r#"
534                SELECT user_session_authentication_id
535                     , created_at
536                     , user_password_id
537                     , upstream_oauth_authorization_session_id
538                FROM user_session_authentications
539                WHERE user_session_id = $1
540                ORDER BY created_at DESC
541                LIMIT 1
542            "#,
543            Uuid::from(user_session.id),
544        )
545        .traced()
546        .fetch_optional(&mut *self.conn)
547        .await?;
548
549        let Some(authentication) = authentication else {
550            return Ok(None);
551        };
552
553        let authentication = Authentication::try_from(authentication)?;
554        Ok(Some(authentication))
555    }
556
557    #[tracing::instrument(
558        name = "db.browser_session.record_batch_activity",
559        skip_all,
560        fields(
561            db.query.text,
562        ),
563        err,
564    )]
565    async fn record_batch_activity(
566        &mut self,
567        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
568    ) -> Result<(), Self::Error> {
569        let mut ids = Vec::with_capacity(activity.len());
570        let mut last_activities = Vec::with_capacity(activity.len());
571        let mut ips = Vec::with_capacity(activity.len());
572
573        for (id, last_activity, ip) in activity {
574            ids.push(Uuid::from(id));
575            last_activities.push(last_activity);
576            ips.push(ip);
577        }
578
579        let res = sqlx::query!(
580            r#"
581                UPDATE user_sessions
582                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
583                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
584                FROM (
585                    SELECT *
586                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
587                        AS t(user_session_id, last_active_at, last_active_ip)
588                ) AS t
589                WHERE user_sessions.user_session_id = t.user_session_id
590            "#,
591            &ids,
592            &last_activities,
593            &ips as &[Option<IpAddr>],
594        )
595        .traced()
596        .execute(&mut *self.conn)
597        .await?;
598
599        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
600
601        Ok(())
602    }
603}