mas_storage/oauth2/
session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Device, Session, User};
13use oauth2_types::scope::Scope;
14use rand_core::RngCore;
15use ulid::Ulid;
16
17use crate::{Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
18
/// State of an OAuth 2.0 session, as used when filtering sessions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// The session is active
    Active,
    /// The session is finished
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if the state is [`OAuth2SessionState::Active`]
    pub fn is_active(self) -> bool {
        self == Self::Active
    }

    /// Returns `true` if the state is [`OAuth2SessionState::Finished`]
    pub fn is_finished(self) -> bool {
        self == Self::Finished
    }
}
34
/// Kind of client, as used when filtering sessions by the client that owns
/// them
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// A static client
    Static,
    /// A dynamic client
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if the kind is [`ClientKind::Static`]
    pub fn is_static(self) -> bool {
        self == Self::Static
    }
}
46
47/// Filter parameters for listing OAuth 2.0 sessions
48#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
49pub struct OAuth2SessionFilter<'a> {
50    user: Option<&'a User>,
51    any_user: Option<bool>,
52    browser_session: Option<&'a BrowserSession>,
53    browser_session_filter: Option<BrowserSessionFilter<'a>>,
54    device: Option<&'a Device>,
55    client: Option<&'a Client>,
56    client_kind: Option<ClientKind>,
57    state: Option<OAuth2SessionState>,
58    scope: Option<&'a Scope>,
59    last_active_before: Option<DateTime<Utc>>,
60    last_active_after: Option<DateTime<Utc>>,
61}
62
63impl<'a> OAuth2SessionFilter<'a> {
64    /// Create a new [`OAuth2SessionFilter`] with default values
65    #[must_use]
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    /// List sessions for a specific user
71    #[must_use]
72    pub fn for_user(mut self, user: &'a User) -> Self {
73        self.user = Some(user);
74        self
75    }
76
77    /// Get the user filter
78    ///
79    /// Returns [`None`] if no user filter was set
80    #[must_use]
81    pub fn user(&self) -> Option<&'a User> {
82        self.user
83    }
84
85    /// List sessions which belong to any user
86    #[must_use]
87    pub fn for_any_user(mut self) -> Self {
88        self.any_user = Some(true);
89        self
90    }
91
92    /// List sessions which belong to no user
93    #[must_use]
94    pub fn for_no_user(mut self) -> Self {
95        self.any_user = Some(false);
96        self
97    }
98
99    /// Get the 'any user' filter
100    ///
101    /// Returns [`None`] if no 'any user' filter was set
102    #[must_use]
103    pub fn any_user(&self) -> Option<bool> {
104        self.any_user
105    }
106
107    /// List sessions started by a specific browser session
108    #[must_use]
109    pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
110        self.browser_session = Some(browser_session);
111        self
112    }
113
114    /// List sessions started by a set of browser sessions
115    #[must_use]
116    pub fn for_browser_sessions(
117        mut self,
118        browser_session_filter: BrowserSessionFilter<'a>,
119    ) -> Self {
120        self.browser_session_filter = Some(browser_session_filter);
121        self
122    }
123
124    /// Get the browser session filter
125    ///
126    /// Returns [`None`] if no browser session filter was set
127    #[must_use]
128    pub fn browser_session(&self) -> Option<&'a BrowserSession> {
129        self.browser_session
130    }
131
132    /// Get the browser sessions filter
133    ///
134    /// Returns [`None`] if no browser session filter was set
135    #[must_use]
136    pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
137        self.browser_session_filter
138    }
139
140    /// List sessions for a specific client
141    #[must_use]
142    pub fn for_client(mut self, client: &'a Client) -> Self {
143        self.client = Some(client);
144        self
145    }
146
147    /// Get the client filter
148    ///
149    /// Returns [`None`] if no client filter was set
150    #[must_use]
151    pub fn client(&self) -> Option<&'a Client> {
152        self.client
153    }
154
155    /// List only static clients
156    #[must_use]
157    pub fn only_static_clients(mut self) -> Self {
158        self.client_kind = Some(ClientKind::Static);
159        self
160    }
161
162    /// List only dynamic clients
163    #[must_use]
164    pub fn only_dynamic_clients(mut self) -> Self {
165        self.client_kind = Some(ClientKind::Dynamic);
166        self
167    }
168
169    /// Get the client kind filter
170    ///
171    /// Returns [`None`] if no client kind filter was set
172    #[must_use]
173    pub fn client_kind(&self) -> Option<ClientKind> {
174        self.client_kind
175    }
176
177    /// Only return sessions with a last active time before the given time
178    #[must_use]
179    pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
180        self.last_active_before = Some(last_active_before);
181        self
182    }
183
184    /// Only return sessions with a last active time after the given time
185    #[must_use]
186    pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
187        self.last_active_after = Some(last_active_after);
188        self
189    }
190
191    /// Get the last active before filter
192    ///
193    /// Returns [`None`] if no client filter was set
194    #[must_use]
195    pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
196        self.last_active_before
197    }
198
199    /// Get the last active after filter
200    ///
201    /// Returns [`None`] if no client filter was set
202    #[must_use]
203    pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
204        self.last_active_after
205    }
206
207    /// Only return active sessions
208    #[must_use]
209    pub fn active_only(mut self) -> Self {
210        self.state = Some(OAuth2SessionState::Active);
211        self
212    }
213
214    /// Only return finished sessions
215    #[must_use]
216    pub fn finished_only(mut self) -> Self {
217        self.state = Some(OAuth2SessionState::Finished);
218        self
219    }
220
221    /// Get the state filter
222    ///
223    /// Returns [`None`] if no state filter was set
224    #[must_use]
225    pub fn state(&self) -> Option<OAuth2SessionState> {
226        self.state
227    }
228
229    /// Only return sessions with the given scope
230    #[must_use]
231    pub fn with_scope(mut self, scope: &'a Scope) -> Self {
232        self.scope = Some(scope);
233        self
234    }
235
236    /// Get the scope filter
237    ///
238    /// Returns [`None`] if no scope filter was set
239    #[must_use]
240    pub fn scope(&self) -> Option<&'a Scope> {
241        self.scope
242    }
243
244    /// Only return sessions that have the given device in their scope
245    #[must_use]
246    pub fn for_device(mut self, device: &'a Device) -> Self {
247        self.device = Some(device);
248        self
249    }
250
251    /// Get the device filter
252    ///
253    /// Returns [`None`] if no device filter was set
254    #[must_use]
255    pub fn device(&self) -> Option<&'a Device> {
256        self.device
257    }
258}
259
260/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
261/// saved in the storage backend
262#[async_trait]
263pub trait OAuth2SessionRepository: Send + Sync {
264    /// The error type returned by the repository
265    type Error;
266
267    /// Lookup an [`Session`] by its ID
268    ///
269    /// Returns `None` if no [`Session`] was found
270    ///
271    /// # Parameters
272    ///
273    /// * `id`: The ID of the [`Session`] to lookup
274    ///
275    /// # Errors
276    ///
277    /// Returns [`Self::Error`] if the underlying repository fails
278    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
279
280    /// Create a new [`Session`] with the given parameters
281    ///
282    /// Returns the newly created [`Session`]
283    ///
284    /// # Parameters
285    ///
286    /// * `rng`: The random number generator to use
287    /// * `clock`: The clock used to generate timestamps
288    /// * `client`: The [`Client`] which created the [`Session`]
289    /// * `user`: The [`User`] for which the session should be created, if any
290    /// * `user_session`: The [`BrowserSession`] of the user which completed the
291    ///   authorization, if any
292    /// * `scope`: The [`Scope`] of the [`Session`]
293    ///
294    /// # Errors
295    ///
296    /// Returns [`Self::Error`] if the underlying repository fails
297    async fn add(
298        &mut self,
299        rng: &mut (dyn RngCore + Send),
300        clock: &dyn Clock,
301        client: &Client,
302        user: Option<&User>,
303        user_session: Option<&BrowserSession>,
304        scope: Scope,
305    ) -> Result<Session, Self::Error>;
306
307    /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
308    ///
309    /// Returns the newly created [`Session`]
310    ///
311    /// # Parameters
312    ///
313    /// * `rng`: The random number generator to use
314    /// * `clock`: The clock used to generate timestamps
315    /// * `client`: The [`Client`] which created the [`Session`]
316    /// * `user_session`: The [`BrowserSession`] of the user which completed the
317    ///   authorization
318    /// * `scope`: The [`Scope`] of the [`Session`]
319    ///
320    /// # Errors
321    ///
322    /// Returns [`Self::Error`] if the underlying repository fails
323    async fn add_from_browser_session(
324        &mut self,
325        rng: &mut (dyn RngCore + Send),
326        clock: &dyn Clock,
327        client: &Client,
328        user_session: &BrowserSession,
329        scope: Scope,
330    ) -> Result<Session, Self::Error> {
331        self.add(
332            rng,
333            clock,
334            client,
335            Some(&user_session.user),
336            Some(user_session),
337            scope,
338        )
339        .await
340    }
341
342    /// Create a new [`Session`] for a [`Client`] using the client credentials
343    /// flow
344    ///
345    /// Returns the newly created [`Session`]
346    ///
347    /// # Parameters
348    ///
349    /// * `rng`: The random number generator to use
350    /// * `clock`: The clock used to generate timestamps
351    /// * `client`: The [`Client`] which created the [`Session`]
352    /// * `scope`: The [`Scope`] of the [`Session`]
353    ///
354    /// # Errors
355    ///
356    /// Returns [`Self::Error`] if the underlying repository fails
357    async fn add_from_client_credentials(
358        &mut self,
359        rng: &mut (dyn RngCore + Send),
360        clock: &dyn Clock,
361        client: &Client,
362        scope: Scope,
363    ) -> Result<Session, Self::Error> {
364        self.add(rng, clock, client, None, None, scope).await
365    }
366
367    /// Mark a [`Session`] as finished
368    ///
369    /// Returns the updated [`Session`]
370    ///
371    /// # Parameters
372    ///
373    /// * `clock`: The clock used to generate timestamps
374    /// * `session`: The [`Session`] to mark as finished
375    ///
376    /// # Errors
377    ///
378    /// Returns [`Self::Error`] if the underlying repository fails
379    async fn finish(&mut self, clock: &dyn Clock, session: Session)
380    -> Result<Session, Self::Error>;
381
382    /// Mark all the [`Session`] matching the given filter as finished
383    ///
384    /// Returns the number of sessions affected
385    ///
386    /// # Parameters
387    ///
388    /// * `clock`: The clock used to generate timestamps
389    /// * `filter`: The filter parameters
390    ///
391    /// # Errors
392    ///
393    /// Returns [`Self::Error`] if the underlying repository fails
394    async fn finish_bulk(
395        &mut self,
396        clock: &dyn Clock,
397        filter: OAuth2SessionFilter<'_>,
398    ) -> Result<usize, Self::Error>;
399
400    /// List [`Session`]s matching the given filter and pagination parameters
401    ///
402    /// # Parameters
403    ///
404    /// * `filter`: The filter parameters
405    /// * `pagination`: The pagination parameters
406    ///
407    /// # Errors
408    ///
409    /// Returns [`Self::Error`] if the underlying repository fails
410    async fn list(
411        &mut self,
412        filter: OAuth2SessionFilter<'_>,
413        pagination: Pagination,
414    ) -> Result<Page<Session>, Self::Error>;
415
416    /// Count [`Session`]s matching the given filter
417    ///
418    /// # Parameters
419    ///
420    /// * `filter`: The filter parameters
421    ///
422    /// # Errors
423    ///
424    /// Returns [`Self::Error`] if the underlying repository fails
425    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
426
427    /// Record a batch of [`Session`] activity
428    ///
429    /// # Parameters
430    ///
431    /// * `activity`: A list of tuples containing the session ID, the last
432    ///   activity timestamp and the IP address of the client
433    ///
434    /// # Errors
435    ///
436    /// Returns [`Self::Error`] if the underlying repository fails
437    async fn record_batch_activity(
438        &mut self,
439        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
440    ) -> Result<(), Self::Error>;
441
442    /// Record the user agent of a [`Session`]
443    ///
444    /// # Parameters
445    ///
446    /// * `session`: The [`Session`] to record the user agent for
447    /// * `user_agent`: The user agent to record
448    async fn record_user_agent(
449        &mut self,
450        session: Session,
451        user_agent: String,
452    ) -> Result<Session, Self::Error>;
453
454    /// Set the human name of a [`Session`]
455    ///
456    /// # Parameters
457    ///
458    /// * `session`: The [`Session`] to set the human name for
459    /// * `human_name`: The human name to set
460    async fn set_human_name(
461        &mut self,
462        session: Session,
463        human_name: Option<String>,
464    ) -> Result<Session, Self::Error>;
465
466    /// Cleanup finished [`Session`]s
467    ///
468    /// Deletes sessions finished between `since` and `until`. Returns the
469    /// number of deleted sessions and the timestamp of the last deleted
470    /// session for pagination.
471    ///
472    /// # Parameters
473    ///
474    /// * `since`: The earliest finish time to delete (exclusive). If `None`,
475    ///   starts from the beginning.
476    /// * `until`: The latest finish time to delete (exclusive)
477    /// * `limit`: Maximum number of sessions to delete in this batch
478    ///
479    /// # Errors
480    ///
481    /// Returns [`Self::Error`] if the underlying repository fails
482    async fn cleanup_finished(
483        &mut self,
484        since: Option<DateTime<Utc>>,
485        until: DateTime<Utc>,
486        limit: usize,
487    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
488
489    /// Clear IP addresses from sessions inactive since the threshold
490    ///
491    /// Sets `last_active_ip` to `NULL` for sessions where `last_active_at` is
492    /// before the threshold. Returns the number of sessions affected and the
493    /// last `last_active_at` timestamp processed for pagination.
494    ///
495    /// # Parameters
496    ///
497    /// * `since`: Only process sessions with `last_active_at` at or after this
498    ///   timestamp (exclusive). If `None`, starts from the beginning.
499    /// * `threshold`: Clear IPs for sessions with `last_active_at` before this
500    ///   time
501    /// * `limit`: Maximum number of sessions to update in this batch
502    ///
503    /// # Errors
504    ///
505    /// Returns [`Self::Error`] if the underlying repository fails
506    async fn cleanup_inactive_ips(
507        &mut self,
508        since: Option<DateTime<Utc>>,
509        threshold: DateTime<Utc>,
510        limit: usize,
511    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
512}
513
// Pass the method signatures of `OAuth2SessionRepository` to the
// `repository_impl!` macro imported from the crate root. The list mirrors the
// trait's thirteen methods one for one, including the two which have default
// bodies on the trait.
// NOTE(review): the macro is defined elsewhere; presumably it generates
// forwarding implementations of the trait, so this list likely has to be kept
// in sync with the trait by hand. Confirm against the macro definition.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: String,
    ) -> Result<Session, Self::Error>;

    async fn set_human_name(
        &mut self,
        session: Session,
        human_name: Option<String>,
    ) -> Result<Session, Self::Error>;

    async fn cleanup_finished(
        &mut self,
        since: Option<DateTime<Utc>>,
        until: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;

    async fn cleanup_inactive_ips(
        &mut self,
        since: Option<DateTime<Utc>>,
        threshold: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
);