mas_storage/upstream_oauth2/session.rs
1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use mas_data_model::{
10 BrowserSession, Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthLink,
11 UpstreamOAuthProvider,
12};
13use rand_core::RngCore;
14use ulid::Ulid;
15
16use crate::{Pagination, pagination::Page, repository_impl};
17
18/// Filter parameters for listing upstream OAuth sessions
19#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
20pub struct UpstreamOAuthSessionFilter<'a> {
21 provider: Option<&'a UpstreamOAuthProvider>,
22 sub_claim: Option<&'a str>,
23 sid_claim: Option<&'a str>,
24}
25
26impl<'a> UpstreamOAuthSessionFilter<'a> {
27 /// Create a new [`UpstreamOAuthSessionFilter`] with default values
28 #[must_use]
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 /// Set the upstream OAuth provider for which to list sessions
34 #[must_use]
35 pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
36 self.provider = Some(provider);
37 self
38 }
39
40 /// Get the upstream OAuth provider filter
41 ///
42 /// Returns [`None`] if no filter was set
43 #[must_use]
44 pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
45 self.provider
46 }
47
48 /// Set the `sub` claim to filter by
49 #[must_use]
50 pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
51 self.sub_claim = Some(sub_claim);
52 self
53 }
54
55 /// Get the `sub` claim filter
56 ///
57 /// Returns [`None`] if no filter was set
58 #[must_use]
59 pub fn sub_claim(&self) -> Option<&str> {
60 self.sub_claim
61 }
62
63 /// Set the `sid` claim to filter by
64 #[must_use]
65 pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
66 self.sid_claim = Some(sid_claim);
67 self
68 }
69
70 /// Get the `sid` claim filter
71 ///
72 /// Returns [`None`] if no filter was set
73 #[must_use]
74 pub fn sid_claim(&self) -> Option<&str> {
75 self.sid_claim
76 }
77}
78
/// An [`UpstreamOAuthSessionRepository`] helps interact with
/// [`UpstreamOAuthAuthorizationSession`]s saved in the storage backend
///
/// All methods are `async` and take `&mut self`. The error type is chosen by
/// each implementation through the [`Self::Error`] associated type.
#[async_trait]
pub trait UpstreamOAuthSessionRepository: Send + Sync {
    /// The error type returned by the repository
    type Error;

    /// Lookup a session by its ID
    ///
    /// Returns [`None`] if the session does not exist
    ///
    /// # Parameters
    ///
    /// * `id`: the ID of the session to lookup
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Add a session to the storage backend
    ///
    /// Returns the newly created session
    ///
    /// # Parameters
    ///
    /// * `rng`: the random number generator to use
    /// * `clock`: the clock source
    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
    ///   create the session
    /// * `state`: the authorization grant `state` parameter sent to the
    ///   upstream OAuth provider
    /// * `code_challenge_verifier`: the code challenge verifier used in this
    ///   session, if PKCE is being used
    /// * `nonce`: the `nonce` used in this session if in OIDC mode
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as completed and associate the given link
    ///
    /// The session is taken by value; the updated session is returned.
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to update
    /// * `upstream_oauth_link`: the link to associate with the session
    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
    ///   present
    /// * `id_token_claims`: the claims contained in the ID token, if present
    /// * `extra_callback_parameters`: the extra query parameters returned in
    ///   the callback, if any
    /// * `userinfo`: the user info returned by the upstream OAuth provider, if
    ///   requested
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    // Each piece of callback data is passed separately, which pushes the
    // argument count past clippy's default threshold; this is intentional.
    #[expect(clippy::too_many_arguments)]
    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// Mark a session as consumed
    ///
    /// The session is taken by value; the updated session is returned.
    ///
    /// # Parameters
    ///
    /// * `clock`: the clock source
    /// * `upstream_oauth_authorization_session`: the session to consume
    /// * `browser_session`: the browser session that was authenticated with
    ///   this authorization session
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        browser_session: &BrowserSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
    /// pagination
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    /// * `pagination`: The pagination parameters
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    /// Count the number of [`UpstreamOAuthAuthorizationSession`] with the given
    /// filter
    ///
    /// # Parameters
    ///
    /// * `filter`: The filter to apply
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
    -> Result<usize, Self::Error>;

    /// Cleanup old authorization sessions that are not linked to a user session
    ///
    /// This will delete sessions with IDs up to and including `until`.
    /// Authorization sessions with a user session linked must be kept around to
    /// avoid breaking features like OIDC Backchannel Logout.
    ///
    /// Returns the number of sessions deleted and the cursor for the next batch
    /// (to be passed back as `since` on the next call).
    ///
    /// NOTE(review): presumably the cursor is the greatest ID handled in this
    /// batch and is `None` once there is nothing left to delete — confirm
    /// against the backend implementation.
    ///
    /// # Parameters
    ///
    /// * `since`: The cursor to start from (exclusive), or `None` to start from
    ///   the beginning
    /// * `until`: The maximum ULID to delete (inclusive upper bound)
    /// * `limit`: The maximum number of sessions to delete in this batch
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the underlying repository fails
    async fn cleanup_orphaned(
        &mut self,
        since: Option<Ulid>,
        until: Ulid,
        limit: usize,
    ) -> Result<(usize, Option<Ulid>), Self::Error>;
}
239
// Invoke the crate-local `repository_impl!` macro with the full list of
// `UpstreamOAuthSessionRepository` method signatures. These signatures mirror
// the trait definition and should be kept in sync with it.
// NOTE(review): presumably the macro generates forwarding impls of the trait
// for wrapper types (e.g. boxed or borrowed repositories) — confirm against
// the macro definition in the crate root.
repository_impl!(UpstreamOAuthSessionRepository:
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        browser_session: &BrowserSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn cleanup_orphaned(
        &mut self,
        since: Option<Ulid>,
        until: Ulid,
        limit: usize,
    ) -> Result<(usize, Option<Ulid>), Self::Error>;
);