mas_storage/oauth2/session.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Device, Session, User, UserAgent};
12use oauth2_types::scope::Scope;
13use rand_core::RngCore;
14use ulid::Ulid;
15
16use crate::{Clock, Pagination, pagination::Page, repository_impl};
17
/// The lifecycle state of an OAuth 2.0 session, as used by
/// [`OAuth2SessionFilter`] to restrict listings to active or finished sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// The session has not been finished yet
    Active,

    /// The session has been marked as finished, e.g. through
    /// [`OAuth2SessionRepository::finish`]
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if this is the [`OAuth2SessionState::Active`] state
    #[must_use]
    pub fn is_active(self) -> bool {
        matches!(self, Self::Active)
    }

    /// Returns `true` if this is the [`OAuth2SessionState::Finished`] state
    #[must_use]
    pub fn is_finished(self) -> bool {
        matches!(self, Self::Finished)
    }
}
33
/// The kind of OAuth 2.0 client a session was created by, as used by
/// [`OAuth2SessionFilter`] to restrict listings to one kind of client
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// A statically defined client
    // NOTE(review): presumably a client declared in the server configuration
    // rather than registered at runtime — confirm.
    Static,

    /// A dynamically registered client
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if this is the [`ClientKind::Static`] kind
    #[must_use]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` if this is the [`ClientKind::Dynamic`] kind
    #[must_use]
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
45
46/// Filter parameters for listing OAuth 2.0 sessions
47#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
48pub struct OAuth2SessionFilter<'a> {
49 user: Option<&'a User>,
50 any_user: Option<bool>,
51 browser_session: Option<&'a BrowserSession>,
52 device: Option<&'a Device>,
53 client: Option<&'a Client>,
54 client_kind: Option<ClientKind>,
55 state: Option<OAuth2SessionState>,
56 scope: Option<&'a Scope>,
57 last_active_before: Option<DateTime<Utc>>,
58 last_active_after: Option<DateTime<Utc>>,
59}
60
61impl<'a> OAuth2SessionFilter<'a> {
62 /// Create a new [`OAuth2SessionFilter`] with default values
63 #[must_use]
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 /// List sessions for a specific user
69 #[must_use]
70 pub fn for_user(mut self, user: &'a User) -> Self {
71 self.user = Some(user);
72 self
73 }
74
75 /// Get the user filter
76 ///
77 /// Returns [`None`] if no user filter was set
78 #[must_use]
79 pub fn user(&self) -> Option<&'a User> {
80 self.user
81 }
82
83 /// List sessions which belong to any user
84 #[must_use]
85 pub fn for_any_user(mut self) -> Self {
86 self.any_user = Some(true);
87 self
88 }
89
90 /// List sessions which belong to no user
91 #[must_use]
92 pub fn for_no_user(mut self) -> Self {
93 self.any_user = Some(false);
94 self
95 }
96
97 /// Get the 'any user' filter
98 ///
99 /// Returns [`None`] if no 'any user' filter was set
100 #[must_use]
101 pub fn any_user(&self) -> Option<bool> {
102 self.any_user
103 }
104
105 /// List sessions started by a specific browser session
106 #[must_use]
107 pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
108 self.browser_session = Some(browser_session);
109 self
110 }
111
112 /// Get the browser session filter
113 ///
114 /// Returns [`None`] if no browser session filter was set
115 #[must_use]
116 pub fn browser_session(&self) -> Option<&'a BrowserSession> {
117 self.browser_session
118 }
119
120 /// List sessions for a specific client
121 #[must_use]
122 pub fn for_client(mut self, client: &'a Client) -> Self {
123 self.client = Some(client);
124 self
125 }
126
127 /// Get the client filter
128 ///
129 /// Returns [`None`] if no client filter was set
130 #[must_use]
131 pub fn client(&self) -> Option<&'a Client> {
132 self.client
133 }
134
135 /// List only static clients
136 #[must_use]
137 pub fn only_static_clients(mut self) -> Self {
138 self.client_kind = Some(ClientKind::Static);
139 self
140 }
141
142 /// List only dynamic clients
143 #[must_use]
144 pub fn only_dynamic_clients(mut self) -> Self {
145 self.client_kind = Some(ClientKind::Dynamic);
146 self
147 }
148
149 /// Get the client kind filter
150 ///
151 /// Returns [`None`] if no client kind filter was set
152 #[must_use]
153 pub fn client_kind(&self) -> Option<ClientKind> {
154 self.client_kind
155 }
156
157 /// Only return sessions with a last active time before the given time
158 #[must_use]
159 pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
160 self.last_active_before = Some(last_active_before);
161 self
162 }
163
164 /// Only return sessions with a last active time after the given time
165 #[must_use]
166 pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
167 self.last_active_after = Some(last_active_after);
168 self
169 }
170
171 /// Get the last active before filter
172 ///
173 /// Returns [`None`] if no client filter was set
174 #[must_use]
175 pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
176 self.last_active_before
177 }
178
179 /// Get the last active after filter
180 ///
181 /// Returns [`None`] if no client filter was set
182 #[must_use]
183 pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
184 self.last_active_after
185 }
186
187 /// Only return active sessions
188 #[must_use]
189 pub fn active_only(mut self) -> Self {
190 self.state = Some(OAuth2SessionState::Active);
191 self
192 }
193
194 /// Only return finished sessions
195 #[must_use]
196 pub fn finished_only(mut self) -> Self {
197 self.state = Some(OAuth2SessionState::Finished);
198 self
199 }
200
201 /// Get the state filter
202 ///
203 /// Returns [`None`] if no state filter was set
204 #[must_use]
205 pub fn state(&self) -> Option<OAuth2SessionState> {
206 self.state
207 }
208
209 /// Only return sessions with the given scope
210 #[must_use]
211 pub fn with_scope(mut self, scope: &'a Scope) -> Self {
212 self.scope = Some(scope);
213 self
214 }
215
216 /// Get the scope filter
217 ///
218 /// Returns [`None`] if no scope filter was set
219 #[must_use]
220 pub fn scope(&self) -> Option<&'a Scope> {
221 self.scope
222 }
223
224 /// Only return sessions that have the given device in their scope
225 #[must_use]
226 pub fn for_device(mut self, device: &'a Device) -> Self {
227 self.device = Some(device);
228 self
229 }
230
231 /// Get the device filter
232 ///
233 /// Returns [`None`] if no device filter was set
234 #[must_use]
235 pub fn device(&self) -> Option<&'a Device> {
236 self.device
237 }
238}
239
240/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
241/// saved in the storage backend
242#[async_trait]
243pub trait OAuth2SessionRepository: Send + Sync {
244 /// The error type returned by the repository
245 type Error;
246
247 /// Lookup an [`Session`] by its ID
248 ///
249 /// Returns `None` if no [`Session`] was found
250 ///
251 /// # Parameters
252 ///
253 /// * `id`: The ID of the [`Session`] to lookup
254 ///
255 /// # Errors
256 ///
257 /// Returns [`Self::Error`] if the underlying repository fails
258 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
259
260 /// Create a new [`Session`] with the given parameters
261 ///
262 /// Returns the newly created [`Session`]
263 ///
264 /// # Parameters
265 ///
266 /// * `rng`: The random number generator to use
267 /// * `clock`: The clock used to generate timestamps
268 /// * `client`: The [`Client`] which created the [`Session`]
269 /// * `user`: The [`User`] for which the session should be created, if any
270 /// * `user_session`: The [`BrowserSession`] of the user which completed the
271 /// authorization, if any
272 /// * `scope`: The [`Scope`] of the [`Session`]
273 ///
274 /// # Errors
275 ///
276 /// Returns [`Self::Error`] if the underlying repository fails
277 async fn add(
278 &mut self,
279 rng: &mut (dyn RngCore + Send),
280 clock: &dyn Clock,
281 client: &Client,
282 user: Option<&User>,
283 user_session: Option<&BrowserSession>,
284 scope: Scope,
285 ) -> Result<Session, Self::Error>;
286
287 /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
288 ///
289 /// Returns the newly created [`Session`]
290 ///
291 /// # Parameters
292 ///
293 /// * `rng`: The random number generator to use
294 /// * `clock`: The clock used to generate timestamps
295 /// * `client`: The [`Client`] which created the [`Session`]
296 /// * `user_session`: The [`BrowserSession`] of the user which completed the
297 /// authorization
298 /// * `scope`: The [`Scope`] of the [`Session`]
299 ///
300 /// # Errors
301 ///
302 /// Returns [`Self::Error`] if the underlying repository fails
303 async fn add_from_browser_session(
304 &mut self,
305 rng: &mut (dyn RngCore + Send),
306 clock: &dyn Clock,
307 client: &Client,
308 user_session: &BrowserSession,
309 scope: Scope,
310 ) -> Result<Session, Self::Error> {
311 self.add(
312 rng,
313 clock,
314 client,
315 Some(&user_session.user),
316 Some(user_session),
317 scope,
318 )
319 .await
320 }
321
322 /// Create a new [`Session`] for a [`Client`] using the client credentials
323 /// flow
324 ///
325 /// Returns the newly created [`Session`]
326 ///
327 /// # Parameters
328 ///
329 /// * `rng`: The random number generator to use
330 /// * `clock`: The clock used to generate timestamps
331 /// * `client`: The [`Client`] which created the [`Session`]
332 /// * `scope`: The [`Scope`] of the [`Session`]
333 ///
334 /// # Errors
335 ///
336 /// Returns [`Self::Error`] if the underlying repository fails
337 async fn add_from_client_credentials(
338 &mut self,
339 rng: &mut (dyn RngCore + Send),
340 clock: &dyn Clock,
341 client: &Client,
342 scope: Scope,
343 ) -> Result<Session, Self::Error> {
344 self.add(rng, clock, client, None, None, scope).await
345 }
346
347 /// Mark a [`Session`] as finished
348 ///
349 /// Returns the updated [`Session`]
350 ///
351 /// # Parameters
352 ///
353 /// * `clock`: The clock used to generate timestamps
354 /// * `session`: The [`Session`] to mark as finished
355 ///
356 /// # Errors
357 ///
358 /// Returns [`Self::Error`] if the underlying repository fails
359 async fn finish(&mut self, clock: &dyn Clock, session: Session)
360 -> Result<Session, Self::Error>;
361
362 /// Mark all the [`Session`] matching the given filter as finished
363 ///
364 /// Returns the number of sessions affected
365 ///
366 /// # Parameters
367 ///
368 /// * `clock`: The clock used to generate timestamps
369 /// * `filter`: The filter parameters
370 ///
371 /// # Errors
372 ///
373 /// Returns [`Self::Error`] if the underlying repository fails
374 async fn finish_bulk(
375 &mut self,
376 clock: &dyn Clock,
377 filter: OAuth2SessionFilter<'_>,
378 ) -> Result<usize, Self::Error>;
379
380 /// List [`Session`]s matching the given filter and pagination parameters
381 ///
382 /// # Parameters
383 ///
384 /// * `filter`: The filter parameters
385 /// * `pagination`: The pagination parameters
386 ///
387 /// # Errors
388 ///
389 /// Returns [`Self::Error`] if the underlying repository fails
390 async fn list(
391 &mut self,
392 filter: OAuth2SessionFilter<'_>,
393 pagination: Pagination,
394 ) -> Result<Page<Session>, Self::Error>;
395
396 /// Count [`Session`]s matching the given filter
397 ///
398 /// # Parameters
399 ///
400 /// * `filter`: The filter parameters
401 ///
402 /// # Errors
403 ///
404 /// Returns [`Self::Error`] if the underlying repository fails
405 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
406
407 /// Record a batch of [`Session`] activity
408 ///
409 /// # Parameters
410 ///
411 /// * `activity`: A list of tuples containing the session ID, the last
412 /// activity timestamp and the IP address of the client
413 ///
414 /// # Errors
415 ///
416 /// Returns [`Self::Error`] if the underlying repository fails
417 async fn record_batch_activity(
418 &mut self,
419 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
420 ) -> Result<(), Self::Error>;
421
422 /// Record the user agent of a [`Session`]
423 ///
424 /// # Parameters
425 ///
426 /// * `session`: The [`Session`] to record the user agent for
427 /// * `user_agent`: The user agent to record
428 async fn record_user_agent(
429 &mut self,
430 session: Session,
431 user_agent: UserAgent,
432 ) -> Result<Session, Self::Error>;
433}
434
// Forwarding implementations of `OAuth2SessionRepository`.
//
// NOTE(review): `repository_impl!` is defined elsewhere in this crate. It
// presumably generates impls of the trait for wrapper/pointer types that
// delegate every listed method to an inner repository; confirm against the
// macro definition.
//
// The signatures below mirror every method of the trait declaration
// (including the ones with default bodies). Keep this list in sync whenever a
// trait method is added, removed or has its signature changed.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: UserAgent,
    ) -> Result<Session, Self::Error>;
);