mas_storage/oauth2/session.rs
1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Device, Session, User};
13use oauth2_types::scope::Scope;
14use rand_core::RngCore;
15use ulid::Ulid;
16
17use crate::{Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
18
/// The lifecycle state of an OAuth 2.0 session, as selected by
/// [`OAuth2SessionFilter::active_only`] and
/// [`OAuth2SessionFilter::finished_only`]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// Selected by [`OAuth2SessionFilter::active_only`]
    Active,
    /// Selected by [`OAuth2SessionFilter::finished_only`]
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` only for [`OAuth2SessionState::Active`]
    pub fn is_active(self) -> bool {
        self == Self::Active
    }

    /// Returns `true` only for [`OAuth2SessionState::Finished`]
    pub fn is_finished(self) -> bool {
        self == Self::Finished
    }
}
34
/// The kind of client an OAuth 2.0 session belongs to, as selected by
/// [`OAuth2SessionFilter::only_static_clients`] and
/// [`OAuth2SessionFilter::only_dynamic_clients`]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// Selected by [`OAuth2SessionFilter::only_static_clients`]
    Static,
    /// Selected by [`OAuth2SessionFilter::only_dynamic_clients`]
    Dynamic,
}

impl ClientKind {
    /// Returns `true` only for [`ClientKind::Static`]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` only for [`ClientKind::Dynamic`]
    ///
    /// Counterpart of [`Self::is_static`], mirroring the
    /// `is_active`/`is_finished` pair on [`OAuth2SessionState`].
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
46
47/// Filter parameters for listing OAuth 2.0 sessions
48#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
49pub struct OAuth2SessionFilter<'a> {
50 user: Option<&'a User>,
51 any_user: Option<bool>,
52 browser_session: Option<&'a BrowserSession>,
53 browser_session_filter: Option<BrowserSessionFilter<'a>>,
54 device: Option<&'a Device>,
55 client: Option<&'a Client>,
56 client_kind: Option<ClientKind>,
57 state: Option<OAuth2SessionState>,
58 scope: Option<&'a Scope>,
59 last_active_before: Option<DateTime<Utc>>,
60 last_active_after: Option<DateTime<Utc>>,
61}
62
63impl<'a> OAuth2SessionFilter<'a> {
64 /// Create a new [`OAuth2SessionFilter`] with default values
65 #[must_use]
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 /// List sessions for a specific user
71 #[must_use]
72 pub fn for_user(mut self, user: &'a User) -> Self {
73 self.user = Some(user);
74 self
75 }
76
77 /// Get the user filter
78 ///
79 /// Returns [`None`] if no user filter was set
80 #[must_use]
81 pub fn user(&self) -> Option<&'a User> {
82 self.user
83 }
84
85 /// List sessions which belong to any user
86 #[must_use]
87 pub fn for_any_user(mut self) -> Self {
88 self.any_user = Some(true);
89 self
90 }
91
92 /// List sessions which belong to no user
93 #[must_use]
94 pub fn for_no_user(mut self) -> Self {
95 self.any_user = Some(false);
96 self
97 }
98
99 /// Get the 'any user' filter
100 ///
101 /// Returns [`None`] if no 'any user' filter was set
102 #[must_use]
103 pub fn any_user(&self) -> Option<bool> {
104 self.any_user
105 }
106
107 /// List sessions started by a specific browser session
108 #[must_use]
109 pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
110 self.browser_session = Some(browser_session);
111 self
112 }
113
114 /// List sessions started by a set of browser sessions
115 #[must_use]
116 pub fn for_browser_sessions(
117 mut self,
118 browser_session_filter: BrowserSessionFilter<'a>,
119 ) -> Self {
120 self.browser_session_filter = Some(browser_session_filter);
121 self
122 }
123
124 /// Get the browser session filter
125 ///
126 /// Returns [`None`] if no browser session filter was set
127 #[must_use]
128 pub fn browser_session(&self) -> Option<&'a BrowserSession> {
129 self.browser_session
130 }
131
132 /// Get the browser sessions filter
133 ///
134 /// Returns [`None`] if no browser session filter was set
135 #[must_use]
136 pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
137 self.browser_session_filter
138 }
139
140 /// List sessions for a specific client
141 #[must_use]
142 pub fn for_client(mut self, client: &'a Client) -> Self {
143 self.client = Some(client);
144 self
145 }
146
147 /// Get the client filter
148 ///
149 /// Returns [`None`] if no client filter was set
150 #[must_use]
151 pub fn client(&self) -> Option<&'a Client> {
152 self.client
153 }
154
155 /// List only static clients
156 #[must_use]
157 pub fn only_static_clients(mut self) -> Self {
158 self.client_kind = Some(ClientKind::Static);
159 self
160 }
161
162 /// List only dynamic clients
163 #[must_use]
164 pub fn only_dynamic_clients(mut self) -> Self {
165 self.client_kind = Some(ClientKind::Dynamic);
166 self
167 }
168
169 /// Get the client kind filter
170 ///
171 /// Returns [`None`] if no client kind filter was set
172 #[must_use]
173 pub fn client_kind(&self) -> Option<ClientKind> {
174 self.client_kind
175 }
176
177 /// Only return sessions with a last active time before the given time
178 #[must_use]
179 pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
180 self.last_active_before = Some(last_active_before);
181 self
182 }
183
184 /// Only return sessions with a last active time after the given time
185 #[must_use]
186 pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
187 self.last_active_after = Some(last_active_after);
188 self
189 }
190
191 /// Get the last active before filter
192 ///
193 /// Returns [`None`] if no client filter was set
194 #[must_use]
195 pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
196 self.last_active_before
197 }
198
199 /// Get the last active after filter
200 ///
201 /// Returns [`None`] if no client filter was set
202 #[must_use]
203 pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
204 self.last_active_after
205 }
206
207 /// Only return active sessions
208 #[must_use]
209 pub fn active_only(mut self) -> Self {
210 self.state = Some(OAuth2SessionState::Active);
211 self
212 }
213
214 /// Only return finished sessions
215 #[must_use]
216 pub fn finished_only(mut self) -> Self {
217 self.state = Some(OAuth2SessionState::Finished);
218 self
219 }
220
221 /// Get the state filter
222 ///
223 /// Returns [`None`] if no state filter was set
224 #[must_use]
225 pub fn state(&self) -> Option<OAuth2SessionState> {
226 self.state
227 }
228
229 /// Only return sessions with the given scope
230 #[must_use]
231 pub fn with_scope(mut self, scope: &'a Scope) -> Self {
232 self.scope = Some(scope);
233 self
234 }
235
236 /// Get the scope filter
237 ///
238 /// Returns [`None`] if no scope filter was set
239 #[must_use]
240 pub fn scope(&self) -> Option<&'a Scope> {
241 self.scope
242 }
243
244 /// Only return sessions that have the given device in their scope
245 #[must_use]
246 pub fn for_device(mut self, device: &'a Device) -> Self {
247 self.device = Some(device);
248 self
249 }
250
251 /// Get the device filter
252 ///
253 /// Returns [`None`] if no device filter was set
254 #[must_use]
255 pub fn device(&self) -> Option<&'a Device> {
256 self.device
257 }
258}
259
260/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
261/// saved in the storage backend
262#[async_trait]
263pub trait OAuth2SessionRepository: Send + Sync {
264 /// The error type returned by the repository
265 type Error;
266
267 /// Lookup an [`Session`] by its ID
268 ///
269 /// Returns `None` if no [`Session`] was found
270 ///
271 /// # Parameters
272 ///
273 /// * `id`: The ID of the [`Session`] to lookup
274 ///
275 /// # Errors
276 ///
277 /// Returns [`Self::Error`] if the underlying repository fails
278 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
279
280 /// Create a new [`Session`] with the given parameters
281 ///
282 /// Returns the newly created [`Session`]
283 ///
284 /// # Parameters
285 ///
286 /// * `rng`: The random number generator to use
287 /// * `clock`: The clock used to generate timestamps
288 /// * `client`: The [`Client`] which created the [`Session`]
289 /// * `user`: The [`User`] for which the session should be created, if any
290 /// * `user_session`: The [`BrowserSession`] of the user which completed the
291 /// authorization, if any
292 /// * `scope`: The [`Scope`] of the [`Session`]
293 ///
294 /// # Errors
295 ///
296 /// Returns [`Self::Error`] if the underlying repository fails
297 async fn add(
298 &mut self,
299 rng: &mut (dyn RngCore + Send),
300 clock: &dyn Clock,
301 client: &Client,
302 user: Option<&User>,
303 user_session: Option<&BrowserSession>,
304 scope: Scope,
305 ) -> Result<Session, Self::Error>;
306
307 /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
308 ///
309 /// Returns the newly created [`Session`]
310 ///
311 /// # Parameters
312 ///
313 /// * `rng`: The random number generator to use
314 /// * `clock`: The clock used to generate timestamps
315 /// * `client`: The [`Client`] which created the [`Session`]
316 /// * `user_session`: The [`BrowserSession`] of the user which completed the
317 /// authorization
318 /// * `scope`: The [`Scope`] of the [`Session`]
319 ///
320 /// # Errors
321 ///
322 /// Returns [`Self::Error`] if the underlying repository fails
323 async fn add_from_browser_session(
324 &mut self,
325 rng: &mut (dyn RngCore + Send),
326 clock: &dyn Clock,
327 client: &Client,
328 user_session: &BrowserSession,
329 scope: Scope,
330 ) -> Result<Session, Self::Error> {
331 self.add(
332 rng,
333 clock,
334 client,
335 Some(&user_session.user),
336 Some(user_session),
337 scope,
338 )
339 .await
340 }
341
342 /// Create a new [`Session`] for a [`Client`] using the client credentials
343 /// flow
344 ///
345 /// Returns the newly created [`Session`]
346 ///
347 /// # Parameters
348 ///
349 /// * `rng`: The random number generator to use
350 /// * `clock`: The clock used to generate timestamps
351 /// * `client`: The [`Client`] which created the [`Session`]
352 /// * `scope`: The [`Scope`] of the [`Session`]
353 ///
354 /// # Errors
355 ///
356 /// Returns [`Self::Error`] if the underlying repository fails
357 async fn add_from_client_credentials(
358 &mut self,
359 rng: &mut (dyn RngCore + Send),
360 clock: &dyn Clock,
361 client: &Client,
362 scope: Scope,
363 ) -> Result<Session, Self::Error> {
364 self.add(rng, clock, client, None, None, scope).await
365 }
366
367 /// Mark a [`Session`] as finished
368 ///
369 /// Returns the updated [`Session`]
370 ///
371 /// # Parameters
372 ///
373 /// * `clock`: The clock used to generate timestamps
374 /// * `session`: The [`Session`] to mark as finished
375 ///
376 /// # Errors
377 ///
378 /// Returns [`Self::Error`] if the underlying repository fails
379 async fn finish(&mut self, clock: &dyn Clock, session: Session)
380 -> Result<Session, Self::Error>;
381
382 /// Mark all the [`Session`] matching the given filter as finished
383 ///
384 /// Returns the number of sessions affected
385 ///
386 /// # Parameters
387 ///
388 /// * `clock`: The clock used to generate timestamps
389 /// * `filter`: The filter parameters
390 ///
391 /// # Errors
392 ///
393 /// Returns [`Self::Error`] if the underlying repository fails
394 async fn finish_bulk(
395 &mut self,
396 clock: &dyn Clock,
397 filter: OAuth2SessionFilter<'_>,
398 ) -> Result<usize, Self::Error>;
399
400 /// List [`Session`]s matching the given filter and pagination parameters
401 ///
402 /// # Parameters
403 ///
404 /// * `filter`: The filter parameters
405 /// * `pagination`: The pagination parameters
406 ///
407 /// # Errors
408 ///
409 /// Returns [`Self::Error`] if the underlying repository fails
410 async fn list(
411 &mut self,
412 filter: OAuth2SessionFilter<'_>,
413 pagination: Pagination,
414 ) -> Result<Page<Session>, Self::Error>;
415
416 /// Count [`Session`]s matching the given filter
417 ///
418 /// # Parameters
419 ///
420 /// * `filter`: The filter parameters
421 ///
422 /// # Errors
423 ///
424 /// Returns [`Self::Error`] if the underlying repository fails
425 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
426
427 /// Record a batch of [`Session`] activity
428 ///
429 /// # Parameters
430 ///
431 /// * `activity`: A list of tuples containing the session ID, the last
432 /// activity timestamp and the IP address of the client
433 ///
434 /// # Errors
435 ///
436 /// Returns [`Self::Error`] if the underlying repository fails
437 async fn record_batch_activity(
438 &mut self,
439 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
440 ) -> Result<(), Self::Error>;
441
442 /// Record the user agent of a [`Session`]
443 ///
444 /// # Parameters
445 ///
446 /// * `session`: The [`Session`] to record the user agent for
447 /// * `user_agent`: The user agent to record
448 async fn record_user_agent(
449 &mut self,
450 session: Session,
451 user_agent: String,
452 ) -> Result<Session, Self::Error>;
453
454 /// Set the human name of a [`Session`]
455 ///
456 /// # Parameters
457 ///
458 /// * `session`: The [`Session`] to set the human name for
459 /// * `human_name`: The human name to set
460 async fn set_human_name(
461 &mut self,
462 session: Session,
463 human_name: Option<String>,
464 ) -> Result<Session, Self::Error>;
465
466 /// Cleanup finished [`Session`]s
467 ///
468 /// Deletes sessions finished between `since` and `until`. Returns the
469 /// number of deleted sessions and the timestamp of the last deleted
470 /// session for pagination.
471 ///
472 /// # Parameters
473 ///
474 /// * `since`: The earliest finish time to delete (exclusive). If `None`,
475 /// starts from the beginning.
476 /// * `until`: The latest finish time to delete (exclusive)
477 /// * `limit`: Maximum number of sessions to delete in this batch
478 ///
479 /// # Errors
480 ///
481 /// Returns [`Self::Error`] if the underlying repository fails
482 async fn cleanup_finished(
483 &mut self,
484 since: Option<DateTime<Utc>>,
485 until: DateTime<Utc>,
486 limit: usize,
487 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
488
489 /// Clear IP addresses from sessions inactive since the threshold
490 ///
491 /// Sets `last_active_ip` to `NULL` for sessions where `last_active_at` is
492 /// before the threshold. Returns the number of sessions affected and the
493 /// last `last_active_at` timestamp processed for pagination.
494 ///
495 /// # Parameters
496 ///
497 /// * `since`: Only process sessions with `last_active_at` at or after this
498 /// timestamp (exclusive). If `None`, starts from the beginning.
499 /// * `threshold`: Clear IPs for sessions with `last_active_at` before this
500 /// time
501 /// * `limit`: Maximum number of sessions to update in this batch
502 ///
503 /// # Errors
504 ///
505 /// Returns [`Self::Error`] if the underlying repository fails
506 async fn cleanup_inactive_ips(
507 &mut self,
508 since: Option<DateTime<Utc>>,
509 threshold: DateTime<Utc>,
510 limit: usize,
511 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
512}
513
// Invokes the crate's `repository_impl!` macro for `OAuth2SessionRepository`.
//
// NOTE(review): `repository_impl!` is defined elsewhere in this crate.
// Presumably it generates implementations of the trait for wrapper repository
// types, with each listed method delegating to the wrapped repository.
// Confirm against the macro definition.
//
// Keep this list in sync with the trait definition above. It currently
// mirrors every method signature of `OAuth2SessionRepository`, including
// `add_from_browser_session` and `add_from_client_credentials`, which have
// default bodies in the trait.
//
// Only plain `//` comments are used inside the invocation. `///` doc comments
// would be passed to the macro as `#[doc]` attributes, which its input pattern
// may not accept.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    // Defaulted in the trait. NOTE(review): listing it here presumably makes
    // wrappers forward to the inner repository's version (including any
    // override) rather than the trait default. Confirm against the macro.
    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    // Defaulted in the trait as well. See the note on
    // `add_from_browser_session`.
    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: String,
    ) -> Result<Session, Self::Error>;

    async fn set_human_name(
        &mut self,
        session: Session,
        human_name: Option<String>,
    ) -> Result<Session, Self::Error>;

    async fn cleanup_finished(
        &mut self,
        since: Option<DateTime<Utc>>,
        until: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;

    async fn cleanup_inactive_ips(
        &mut self,
        since: Option<DateTime<Utc>>,
        threshold: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
);