mas_storage/upstream_oauth2/session.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use mas_data_model::{
10    BrowserSession, Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthLink,
11    UpstreamOAuthProvider,
12};
13use rand_core::RngCore;
14use ulid::Ulid;
15
16use crate::{Pagination, pagination::Page, repository_impl};
17
18/// Filter parameters for listing upstream OAuth sessions
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
20pub struct UpstreamOAuthSessionFilter<'a> {
21    provider: Option<&'a UpstreamOAuthProvider>,
22    sub_claim: Option<&'a str>,
23    sid_claim: Option<&'a str>,
24}
25
26impl<'a> UpstreamOAuthSessionFilter<'a> {
27    /// Create a new [`UpstreamOAuthSessionFilter`] with default values
28    #[must_use]
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    /// Set the upstream OAuth provider for which to list sessions
34    #[must_use]
35    pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
36        self.provider = Some(provider);
37        self
38    }
39
40    /// Get the upstream OAuth provider filter
41    ///
42    /// Returns [`None`] if no filter was set
43    #[must_use]
44    pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
45        self.provider
46    }
47
48    /// Set the `sub` claim to filter by
49    #[must_use]
50    pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
51        self.sub_claim = Some(sub_claim);
52        self
53    }
54
55    /// Get the `sub` claim filter
56    ///
57    /// Returns [`None`] if no filter was set
58    #[must_use]
59    pub fn sub_claim(&self) -> Option<&str> {
60        self.sub_claim
61    }
62
63    /// Set the `sid` claim to filter by
64    #[must_use]
65    pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
66        self.sid_claim = Some(sid_claim);
67        self
68    }
69
70    /// Get the `sid` claim filter
71    ///
72    /// Returns [`None`] if no filter was set
73    #[must_use]
74    pub fn sid_claim(&self) -> Option<&str> {
75        self.sid_claim
76    }
77}
78
79/// An [`UpstreamOAuthSessionRepository`] helps interacting with
80/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend
81#[async_trait]
82pub trait UpstreamOAuthSessionRepository: Send + Sync {
83    /// The error type returned by the repository
84    type Error;
85
86    /// Lookup a session by its ID
87    ///
88    /// Returns `None` if the session does not exist
89    ///
90    /// # Parameters
91    ///
92    /// * `id`: the ID of the session to lookup
93    ///
94    /// # Errors
95    ///
96    /// Returns [`Self::Error`] if the underlying repository fails
97    async fn lookup(
98        &mut self,
99        id: Ulid,
100    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
101
102    /// Add a session to the database
103    ///
104    /// Returns the newly created session
105    ///
106    /// # Parameters
107    ///
108    /// * `rng`: the random number generator to use
109    /// * `clock`: the clock source
110    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
111    ///   create the session
112    /// * `state`: the authorization grant `state` parameter sent to the
113    ///   upstream OAuth provider
114    /// * `code_challenge_verifier`: the code challenge verifier used in this
115    ///   session, if PKCE is being used
116    /// * `nonce`: the `nonce` used in this session if in OIDC mode
117    ///
118    /// # Errors
119    ///
120    /// Returns [`Self::Error`] if the underlying repository fails
121    async fn add(
122        &mut self,
123        rng: &mut (dyn RngCore + Send),
124        clock: &dyn Clock,
125        upstream_oauth_provider: &UpstreamOAuthProvider,
126        state: String,
127        code_challenge_verifier: Option<String>,
128        nonce: Option<String>,
129    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
130
131    /// Mark a session as completed and associate the given link
132    ///
133    /// Returns the updated session
134    ///
135    /// # Parameters
136    ///
137    /// * `clock`: the clock source
138    /// * `upstream_oauth_authorization_session`: the session to update
139    /// * `upstream_oauth_link`: the link to associate with the session
140    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
141    ///   present
142    /// * `id_token_claims`: the claims contained in the ID token, if present
143    /// * `extra_callback_parameters`: the extra query parameters returned in
144    ///   the callback, if any
145    /// * `userinfo`: the user info returned by the upstream OAuth provider, if
146    ///   requested
147    ///
148    /// # Errors
149    ///
150    /// Returns [`Self::Error`] if the underlying repository fails
151    #[expect(clippy::too_many_arguments)]
152    async fn complete_with_link(
153        &mut self,
154        clock: &dyn Clock,
155        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
156        upstream_oauth_link: &UpstreamOAuthLink,
157        id_token: Option<String>,
158        id_token_claims: Option<serde_json::Value>,
159        extra_callback_parameters: Option<serde_json::Value>,
160        userinfo: Option<serde_json::Value>,
161    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
162
163    /// Mark a session as consumed
164    ///
165    /// Returns the updated session
166    ///
167    /// # Parameters
168    ///
169    /// * `clock`: the clock source
170    /// * `upstream_oauth_authorization_session`: the session to consume
171    /// * `browser_session`: the browser session that was authenticated with
172    ///   this authorization session
173    ///
174    /// # Errors
175    ///
176    /// Returns [`Self::Error`] if the underlying repository fails
177    async fn consume(
178        &mut self,
179        clock: &dyn Clock,
180        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
181        browser_session: &BrowserSession,
182    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
183
184    /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
185    /// pagination
186    ///
187    /// # Parameters
188    ///
189    /// * `filter`: The filter to apply
190    /// * `pagination`: The pagination parameters
191    ///
192    /// # Errors
193    ///
194    /// Returns [`Self::Error`] if the underlying repository fails
195    async fn list(
196        &mut self,
197        filter: UpstreamOAuthSessionFilter<'_>,
198        pagination: Pagination,
199    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;
200
201    /// Count the number of [`UpstreamOAuthAuthorizationSession`] with the given
202    /// filter
203    ///
204    /// # Parameters
205    ///
206    /// * `filter`: The filter to apply
207    ///
208    /// # Errors
209    ///
210    /// Returns [`Self::Error`] if the underlying repository fails
211    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
212    -> Result<usize, Self::Error>;
213
214    /// Cleanup old authorization sessions that are not linked to a user session
215    ///
216    /// This will delete sessions with IDs up to and including `until`.
217    /// Authorization sessions with a user session linked must be kept around to
218    /// avoid breaking features like OIDC Backchannel Logout.
219    ///
220    /// Returns the number of sessions deleted and the cursor for the next batch
221    ///
222    /// # Parameters
223    ///
224    /// * `since`: The cursor to start from (exclusive), or `None` to start from
225    ///   the beginning
226    /// * `until`: The maximum ULID to delete (inclusive upper bound)
227    /// * `limit`: The maximum number of sessions to delete in this batch
228    ///
229    /// # Errors
230    ///
231    /// Returns [`Self::Error`] if the underlying repository fails
232    async fn cleanup_orphaned(
233        &mut self,
234        since: Option<Ulid>,
235        until: Ulid,
236        limit: usize,
237    ) -> Result<(usize, Option<Ulid>), Self::Error>;
238}
239
// Invoke the crate's `repository_impl!` macro for
// `UpstreamOAuthSessionRepository`.
//
// The method list below repeats, signature for signature, the trait
// declaration of `UpstreamOAuthSessionRepository` in this file.
// NOTE(review): the macro definition is not visible here; presumably it
// generates forwarding impls from these signatures, so they should be kept in
// sync with the trait. Only plain `//` comments are used inside the
// invocation because it is unknown whether the macro's matcher accepts
// attributes such as `///` doc comments.
repository_impl!(UpstreamOAuthSessionRepository:
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        browser_session: &BrowserSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn cleanup_orphaned(
        &mut self,
        since: Option<Ulid>,
        until: Ulid,
        limit: usize,
    ) -> Result<(usize, Option<Ulid>), Self::Error>;
);