mas_storage_pg/upstream_oauth2/
mod.rs1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// End-to-end exercise of the provider, session and link repositories.
    ///
    /// Creates a provider, runs an upstream OAuth session through
    /// pending -> completed -> consumed, and creates a link that gets
    /// associated with a user. It then checks the provider counts before
    /// and after disabling and deleting the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // A freshly migrated database has no providers
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: None,
                    brand_name: None,
                    scope: Scope::from_iter([OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    token_endpoint_signing_alg: None,
                    client_id: "client-id".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports: UpstreamOAuthProviderClaimsImports::default(),
                    token_endpoint_override: None,
                    authorization_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    jwks_uri_override: None,
                    discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    ui_order: 0,
                },
            )
            .await
            .unwrap();

        // Read it back from the database to make sure it round-trips
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                "some-nonce".to_owned(),
            )
            .await
            .unwrap();

        // A new session is pending: not yet completed, consumed or linked
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // Looking the link up by ID must return the link that was just
        // created, not merely "some" row
        let link_by_id = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link_by_id.id, link.id);
        assert_eq!(link_by_id.subject, "a-subject");
        assert_eq!(link_by_id.provider_id, provider.id);

        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Completing the session attaches it to the link
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // Every filter criterion matches the single link we created
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Before disabling: one provider, enabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // After disabling: still one provider in total, now counted as disabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Deleting removes it entirely
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Checks cursor pagination of the provider repository over 20 providers.
    ///
    /// Covers forward (`first`/`after`) and backward (`last`/`before`)
    /// pages, a bounded `after`..`before` window, and the `enabled_only` and
    /// `disabled_only` filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        for idx in 0..20 {
            let client_id = format!("client-{idx}");
            let provider = repo
                .upstream_oauth_provider()
                .add(
                    &mut rng,
                    &clock,
                    UpstreamOAuthProviderParams {
                        issuer: None,
                        human_name: None,
                        brand_name: None,
                        scope: scope.clone(),
                        token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                        fetch_userinfo: false,
                        userinfo_signed_response_alg: None,
                        token_endpoint_signing_alg: None,
                        id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                        client_id,
                        encrypted_client_secret: None,
                        claims_imports: UpstreamOAuthProviderClaimsImports::default(),
                        token_endpoint_override: None,
                        authorization_endpoint_override: None,
                        userinfo_endpoint_override: None,
                        jwks_uri_override: None,
                        discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
                        pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
                        response_mode: None,
                        additional_authorization_parameters: Vec::new(),
                        ui_order: 0,
                    },
                )
                .await
                .unwrap();
            ids.push(provider.id);
            // Advance the mock clock by 10 seconds between inserts. The page
            // assertions below compare against `ids`, i.e. they rely on the
            // providers being listed in insertion order.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // First forward page: the first 10 providers, with more to come
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // All providers are enabled, so `enabled_only` yields the same page
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Second forward page: the remaining 10, and nothing after
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last backward page: the final 10, with earlier ones available
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Backward page before the 11th provider: the first 10, nothing before
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Bounded window: both cursors are exclusive, leaving ids[6] and ids[7]
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // No provider was disabled, so the disabled-only listing is empty
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }
}