mas_storage_pg/oauth2/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User, UserAgent};
12use mas_storage::{
13    Clock, Page, Pagination,
14    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError, DatabaseInconsistencyError,
26    filter::{Filter, StatementExt},
27    iden::{OAuth2Clients, OAuth2Sessions},
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
33pub struct PgOAuth2SessionRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
/// A raw row of the `oauth2_sessions` table, converted into a [`Session`] by
/// its `TryFrom` implementation.
///
/// `#[enum_def]` also generates the `OAuthSessionLookupIden` enum. The `list`
/// query uses its variants as column aliases, so the dynamically-built
/// `SELECT` lines up with these field names for `sqlx::FromRow`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    /// ID of the session (a ULID stored as a UUID).
    oauth2_session_id: Uuid,
    /// User the session belongs to, or `None` for sessions without a user.
    user_id: Option<Uuid>,
    /// Browser session this session is linked to, if any.
    user_session_id: Option<Uuid>,
    /// Client the session was issued to.
    oauth2_client_id: Uuid,
    /// Granted scope, one scope token per entry.
    scope_list: Vec<String>,
    /// When the session was created.
    created_at: DateTime<Utc>,
    /// When the session was finished; `None` while it is still valid.
    finished_at: Option<DateTime<Utc>>,
    /// Raw user agent string recorded for this session, if any.
    user_agent: Option<String>,
    /// Time of the last recorded activity, if any.
    last_active_at: Option<DateTime<Utc>>,
    /// IP address of the last recorded activity, if any.
    last_active_ip: Option<IpAddr>,
}
59
60impl TryFrom<OAuthSessionLookup> for Session {
61    type Error = DatabaseInconsistencyError;
62
63    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
64        let id = Ulid::from(value.oauth2_session_id);
65        let scope: Result<Scope, _> = value
66            .scope_list
67            .iter()
68            .map(|s| s.parse::<ScopeToken>())
69            .collect();
70        let scope = scope.map_err(|e| {
71            DatabaseInconsistencyError::on("oauth2_sessions")
72                .column("scope")
73                .row(id)
74                .source(e)
75        })?;
76
77        let state = match value.finished_at {
78            None => SessionState::Valid,
79            Some(finished_at) => SessionState::Finished { finished_at },
80        };
81
82        Ok(Session {
83            id,
84            state,
85            created_at: value.created_at,
86            client_id: value.oauth2_client_id.into(),
87            user_id: value.user_id.map(Ulid::from),
88            user_session_id: value.user_session_id.map(Ulid::from),
89            scope,
90            user_agent: value.user_agent.map(UserAgent::parse),
91            last_active_at: value.last_active_at,
92            last_active_ip: value.last_active_ip,
93        })
94    }
95}
96
97impl Filter for OAuth2SessionFilter<'_> {
98    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99        sea_query::Condition::all()
100            .add_option(self.user().map(|user| {
101                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
102            }))
103            .add_option(self.client().map(|client| {
104                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
105                    .eq(Uuid::from(client.id))
106            }))
107            .add_option(self.client_kind().map(|client_kind| {
108                // This builds either a:
109                // `WHERE oauth2_client_id = ANY(...)`
110                // or a `WHERE oauth2_client_id <> ALL(...)`
111                let static_clients = Query::select()
112                    .expr(Expr::col((
113                        OAuth2Clients::Table,
114                        OAuth2Clients::OAuth2ClientId,
115                    )))
116                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
117                    .from(OAuth2Clients::Table)
118                    .take();
119                if client_kind.is_static() {
120                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
121                        .eq(Expr::any(static_clients))
122                } else {
123                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
124                        .ne(Expr::all(static_clients))
125                }
126            }))
127            .add_option(self.device().map(|device| {
128                if let Ok(scope_token) = device.to_scope_token() {
129                    Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
130                        OAuth2Sessions::Table,
131                        OAuth2Sessions::ScopeList,
132                    ))))
133                } else {
134                    // If the device ID can't be encoded as a scope token, match no rows
135                    Expr::val(false).into()
136                }
137            }))
138            .add_option(self.browser_session().map(|browser_session| {
139                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
140                    .eq(Uuid::from(browser_session.id))
141            }))
142            .add_option(self.state().map(|state| {
143                if state.is_active() {
144                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
145                } else {
146                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
147                }
148            }))
149            .add_option(self.scope().map(|scope| {
150                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
151                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
152            }))
153            .add_option(self.any_user().map(|any_user| {
154                if any_user {
155                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
156                } else {
157                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
158                }
159            }))
160            .add_option(self.last_active_after().map(|last_active_after| {
161                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
162                    .gt(last_active_after)
163            }))
164            .add_option(self.last_active_before().map(|last_active_before| {
165                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
166                    .lt(last_active_before)
167            }))
168    }
169}
170
171#[async_trait]
172impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
173    type Error = DatabaseError;
174
175    #[tracing::instrument(
176        name = "db.oauth2_session.lookup",
177        skip_all,
178        fields(
179            db.query.text,
180            session.id = %id,
181        ),
182        err,
183    )]
184    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
185        let res = sqlx::query_as!(
186            OAuthSessionLookup,
187            r#"
188                SELECT oauth2_session_id
189                     , user_id
190                     , user_session_id
191                     , oauth2_client_id
192                     , scope_list
193                     , created_at
194                     , finished_at
195                     , user_agent
196                     , last_active_at
197                     , last_active_ip as "last_active_ip: IpAddr"
198                FROM oauth2_sessions
199
200                WHERE oauth2_session_id = $1
201            "#,
202            Uuid::from(id),
203        )
204        .traced()
205        .fetch_optional(&mut *self.conn)
206        .await?;
207
208        let Some(session) = res else { return Ok(None) };
209
210        Ok(Some(session.try_into()?))
211    }
212
213    #[tracing::instrument(
214        name = "db.oauth2_session.add",
215        skip_all,
216        fields(
217            db.query.text,
218            %client.id,
219            session.id,
220            session.scope = %scope,
221        ),
222        err,
223    )]
224    async fn add(
225        &mut self,
226        rng: &mut (dyn RngCore + Send),
227        clock: &dyn Clock,
228        client: &Client,
229        user: Option<&User>,
230        user_session: Option<&BrowserSession>,
231        scope: Scope,
232    ) -> Result<Session, Self::Error> {
233        let created_at = clock.now();
234        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
235        tracing::Span::current().record("session.id", tracing::field::display(id));
236
237        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
238
239        sqlx::query!(
240            r#"
241                INSERT INTO oauth2_sessions
242                    ( oauth2_session_id
243                    , user_id
244                    , user_session_id
245                    , oauth2_client_id
246                    , scope_list
247                    , created_at
248                    )
249                VALUES ($1, $2, $3, $4, $5, $6)
250            "#,
251            Uuid::from(id),
252            user.map(|u| Uuid::from(u.id)),
253            user_session.map(|s| Uuid::from(s.id)),
254            Uuid::from(client.id),
255            &scope_list,
256            created_at,
257        )
258        .traced()
259        .execute(&mut *self.conn)
260        .await?;
261
262        Ok(Session {
263            id,
264            state: SessionState::Valid,
265            created_at,
266            user_id: user.map(|u| u.id),
267            user_session_id: user_session.map(|s| s.id),
268            client_id: client.id,
269            scope,
270            user_agent: None,
271            last_active_at: None,
272            last_active_ip: None,
273        })
274    }
275
276    #[tracing::instrument(
277        name = "db.oauth2_session.finish_bulk",
278        skip_all,
279        fields(
280            db.query.text,
281        ),
282        err,
283    )]
284    async fn finish_bulk(
285        &mut self,
286        clock: &dyn Clock,
287        filter: OAuth2SessionFilter<'_>,
288    ) -> Result<usize, Self::Error> {
289        let finished_at = clock.now();
290        let (sql, arguments) = Query::update()
291            .table(OAuth2Sessions::Table)
292            .value(OAuth2Sessions::FinishedAt, finished_at)
293            .apply_filter(filter)
294            .build_sqlx(PostgresQueryBuilder);
295
296        let res = sqlx::query_with(&sql, arguments)
297            .traced()
298            .execute(&mut *self.conn)
299            .await?;
300
301        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
302    }
303
304    #[tracing::instrument(
305        name = "db.oauth2_session.finish",
306        skip_all,
307        fields(
308            db.query.text,
309            %session.id,
310            %session.scope,
311            client.id = %session.client_id,
312        ),
313        err,
314    )]
315    async fn finish(
316        &mut self,
317        clock: &dyn Clock,
318        session: Session,
319    ) -> Result<Session, Self::Error> {
320        let finished_at = clock.now();
321        let res = sqlx::query!(
322            r#"
323                UPDATE oauth2_sessions
324                SET finished_at = $2
325                WHERE oauth2_session_id = $1
326            "#,
327            Uuid::from(session.id),
328            finished_at,
329        )
330        .traced()
331        .execute(&mut *self.conn)
332        .await?;
333
334        DatabaseError::ensure_affected_rows(&res, 1)?;
335
336        session
337            .finish(finished_at)
338            .map_err(DatabaseError::to_invalid_operation)
339    }
340
341    #[tracing::instrument(
342        name = "db.oauth2_session.list",
343        skip_all,
344        fields(
345            db.query.text,
346        ),
347        err,
348    )]
349    async fn list(
350        &mut self,
351        filter: OAuth2SessionFilter<'_>,
352        pagination: Pagination,
353    ) -> Result<Page<Session>, Self::Error> {
354        let (sql, arguments) = Query::select()
355            .expr_as(
356                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
357                OAuthSessionLookupIden::Oauth2SessionId,
358            )
359            .expr_as(
360                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
361                OAuthSessionLookupIden::UserId,
362            )
363            .expr_as(
364                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
365                OAuthSessionLookupIden::UserSessionId,
366            )
367            .expr_as(
368                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
369                OAuthSessionLookupIden::Oauth2ClientId,
370            )
371            .expr_as(
372                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
373                OAuthSessionLookupIden::ScopeList,
374            )
375            .expr_as(
376                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
377                OAuthSessionLookupIden::CreatedAt,
378            )
379            .expr_as(
380                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
381                OAuthSessionLookupIden::FinishedAt,
382            )
383            .expr_as(
384                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
385                OAuthSessionLookupIden::UserAgent,
386            )
387            .expr_as(
388                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
389                OAuthSessionLookupIden::LastActiveAt,
390            )
391            .expr_as(
392                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
393                OAuthSessionLookupIden::LastActiveIp,
394            )
395            .from(OAuth2Sessions::Table)
396            .apply_filter(filter)
397            .generate_pagination(
398                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
399                pagination,
400            )
401            .build_sqlx(PostgresQueryBuilder);
402
403        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
404            .traced()
405            .fetch_all(&mut *self.conn)
406            .await?;
407
408        let page = pagination.process(edges).try_map(Session::try_from)?;
409
410        Ok(page)
411    }
412
413    #[tracing::instrument(
414        name = "db.oauth2_session.count",
415        skip_all,
416        fields(
417            db.query.text,
418        ),
419        err,
420    )]
421    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
422        let (sql, arguments) = Query::select()
423            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
424            .from(OAuth2Sessions::Table)
425            .apply_filter(filter)
426            .build_sqlx(PostgresQueryBuilder);
427
428        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
429            .traced()
430            .fetch_one(&mut *self.conn)
431            .await?;
432
433        count
434            .try_into()
435            .map_err(DatabaseError::to_invalid_operation)
436    }
437
438    #[tracing::instrument(
439        name = "db.oauth2_session.record_batch_activity",
440        skip_all,
441        fields(
442            db.query.text,
443        ),
444        err,
445    )]
446    async fn record_batch_activity(
447        &mut self,
448        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
449    ) -> Result<(), Self::Error> {
450        let mut ids = Vec::with_capacity(activity.len());
451        let mut last_activities = Vec::with_capacity(activity.len());
452        let mut ips = Vec::with_capacity(activity.len());
453
454        for (id, last_activity, ip) in activity {
455            ids.push(Uuid::from(id));
456            last_activities.push(last_activity);
457            ips.push(ip);
458        }
459
460        let res = sqlx::query!(
461            r#"
462                UPDATE oauth2_sessions
463                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
464                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
465                FROM (
466                    SELECT *
467                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
468                        AS t(oauth2_session_id, last_active_at, last_active_ip)
469                ) AS t
470                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
471            "#,
472            &ids,
473            &last_activities,
474            &ips as &[Option<IpAddr>],
475        )
476        .traced()
477        .execute(&mut *self.conn)
478        .await?;
479
480        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
481
482        Ok(())
483    }
484
485    #[tracing::instrument(
486        name = "db.oauth2_session.record_user_agent",
487        skip_all,
488        fields(
489            db.query.text,
490            %session.id,
491            %session.scope,
492            client.id = %session.client_id,
493            session.user_agent = %user_agent.raw,
494        ),
495        err,
496    )]
497    async fn record_user_agent(
498        &mut self,
499        mut session: Session,
500        user_agent: UserAgent,
501    ) -> Result<Session, Self::Error> {
502        let res = sqlx::query!(
503            r#"
504                UPDATE oauth2_sessions
505                SET user_agent = $2
506                WHERE oauth2_session_id = $1
507            "#,
508            Uuid::from(session.id),
509            &*user_agent,
510        )
511        .traced()
512        .execute(&mut *self.conn)
513        .await?;
514
515        session.user_agent = Some(user_agent);
516
517        DatabaseError::ensure_affected_rows(&res, 1)?;
518
519        Ok(session)
520    }
521}