1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
19#[cfg(test)]
20mod tests {
21 use chrono::Duration;
22 use mas_data_model::{
23 UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
24 UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
25 };
26 use mas_iana::jose::JsonWebSignatureAlg;
27 use mas_storage::{
28 Pagination, RepositoryAccess,
29 upstream_oauth2::{
30 UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
31 UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
32 UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
33 },
34 user::UserRepository,
35 };
36 use oauth2_types::scope::{OPENID, Scope};
37 use rand::SeedableRng;
38 use sqlx::PgPool;
39
40 use crate::PgRepository;
41
42 #[sqlx::test(migrator = "crate::MIGRATOR")]
43 async fn test_repository(pool: PgPool) {
44 let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
45 let clock = MockClock::default();
46 let mut repo = PgRepository::from_pool(&pool).await.unwrap();
47
48 let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
50 assert!(all_providers.is_empty());
51
52 let provider = repo
54 .upstream_oauth_provider()
55 .add(
56 &mut rng,
57 &clock,
58 UpstreamOAuthProviderParams {
59 issuer: Some("https://example.com/".to_owned()),
60 human_name: None,
61 brand_name: None,
62 scope: Scope::from_iter([OPENID]),
63 token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
64 id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
65 fetch_userinfo: false,
66 userinfo_signed_response_alg: None,
67 token_endpoint_signing_alg: None,
68 client_id: "client-id".to_owned(),
69 encrypted_client_secret: None,
70 claims_imports: UpstreamOAuthProviderClaimsImports::default(),
71 token_endpoint_override: None,
72 authorization_endpoint_override: None,
73 userinfo_endpoint_override: None,
74 jwks_uri_override: None,
75 discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
76 pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
77 response_mode: None,
78 additional_authorization_parameters: Vec::new(),
79 forward_login_hint: false,
80 ui_order: 0,
81 on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
82 },
83 )
84 .await
85 .unwrap();
86
87 let provider = repo
89 .upstream_oauth_provider()
90 .lookup(provider.id)
91 .await
92 .unwrap()
93 .expect("provider to be found in the database");
94 assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
95 assert_eq!(provider.client_id, "client-id");
96
97 let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
99 assert_eq!(providers.len(), 1);
100 assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
101 assert_eq!(providers[0].client_id, "client-id");
102
103 let session = repo
105 .upstream_oauth_session()
106 .add(
107 &mut rng,
108 &clock,
109 &provider,
110 "some-state".to_owned(),
111 None,
112 Some("some-nonce".to_owned()),
113 )
114 .await
115 .unwrap();
116
117 let session = repo
119 .upstream_oauth_session()
120 .lookup(session.id)
121 .await
122 .unwrap()
123 .expect("session to be found in the database");
124 assert_eq!(session.provider_id, provider.id);
125 assert_eq!(session.link_id(), None);
126 assert!(session.is_pending());
127 assert!(!session.is_completed());
128 assert!(!session.is_consumed());
129
130 let link = repo
132 .upstream_oauth_link()
133 .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
134 .await
135 .unwrap();
136
137 repo.upstream_oauth_link()
139 .lookup(link.id)
140 .await
141 .unwrap()
142 .expect("link to be found in database");
143
144 let link = repo
146 .upstream_oauth_link()
147 .find_by_subject(&provider, "a-subject")
148 .await
149 .unwrap()
150 .expect("link to be found in database");
151 assert_eq!(link.subject, "a-subject");
152 assert_eq!(link.provider_id, provider.id);
153
154 let session = repo
155 .upstream_oauth_session()
156 .complete_with_link(&clock, session, &link, None, None, None, None)
157 .await
158 .unwrap();
159 let session = repo
161 .upstream_oauth_session()
162 .lookup(session.id)
163 .await
164 .unwrap()
165 .expect("session to be found in the database");
166 assert!(session.is_completed());
167 assert!(!session.is_consumed());
168 assert_eq!(session.link_id(), Some(link.id));
169
170 let user = repo
172 .user()
173 .add(&mut rng, &clock, "john".to_owned())
174 .await
175 .unwrap();
176 let browser_session = repo
177 .browser_session()
178 .add(&mut rng, &clock, &user, None)
179 .await
180 .unwrap();
181
182 let session = repo
183 .upstream_oauth_session()
184 .consume(&clock, session, &browser_session)
185 .await
186 .unwrap();
187
188 let session = repo
190 .upstream_oauth_session()
191 .lookup(session.id)
192 .await
193 .unwrap()
194 .expect("session to be found in the database");
195 assert!(session.is_consumed());
196
197 repo.upstream_oauth_link()
198 .associate_to_user(&link, &user)
199 .await
200 .unwrap();
201
202 let filter = UpstreamOAuthLinkFilter::new()
204 .for_user(&user)
205 .for_provider(&provider)
206 .for_subject("a-subject")
207 .enabled_providers_only();
208
209 let links = repo
210 .upstream_oauth_link()
211 .list(filter, Pagination::first(10))
212 .await
213 .unwrap();
214 assert!(!links.has_previous_page);
215 assert!(!links.has_next_page);
216 assert_eq!(links.edges.len(), 1);
217 assert_eq!(links.edges[0].node.id, link.id);
218 assert_eq!(links.edges[0].node.user_id, Some(user.id));
219
220 assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
221
222 assert_eq!(
224 repo.upstream_oauth_provider()
225 .count(UpstreamOAuthProviderFilter::new())
226 .await
227 .unwrap(),
228 1
229 );
230 assert_eq!(
231 repo.upstream_oauth_provider()
232 .count(UpstreamOAuthProviderFilter::new().enabled_only())
233 .await
234 .unwrap(),
235 1
236 );
237 assert_eq!(
238 repo.upstream_oauth_provider()
239 .count(UpstreamOAuthProviderFilter::new().disabled_only())
240 .await
241 .unwrap(),
242 0
243 );
244
245 repo.upstream_oauth_provider()
247 .disable(&clock, provider.clone())
248 .await
249 .unwrap();
250
251 assert_eq!(
253 repo.upstream_oauth_provider()
254 .count(UpstreamOAuthProviderFilter::new())
255 .await
256 .unwrap(),
257 1
258 );
259 assert_eq!(
260 repo.upstream_oauth_provider()
261 .count(UpstreamOAuthProviderFilter::new().enabled_only())
262 .await
263 .unwrap(),
264 0
265 );
266 assert_eq!(
267 repo.upstream_oauth_provider()
268 .count(UpstreamOAuthProviderFilter::new().disabled_only())
269 .await
270 .unwrap(),
271 1
272 );
273
274 let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
276
277 let session_count = repo
279 .upstream_oauth_session()
280 .count(session_filter)
281 .await
282 .unwrap();
283 assert_eq!(session_count, 1);
284
285 let session_page = repo
287 .upstream_oauth_session()
288 .list(session_filter, Pagination::first(10))
289 .await
290 .unwrap();
291
292 assert_eq!(session_page.edges.len(), 1);
293 assert_eq!(session_page.edges[0].node.id, session.id);
294 assert!(!session_page.has_next_page);
295 assert!(!session_page.has_previous_page);
296
297 repo.upstream_oauth_provider()
299 .delete(provider)
300 .await
301 .unwrap();
302 assert_eq!(
303 repo.upstream_oauth_provider()
304 .count(UpstreamOAuthProviderFilter::new())
305 .await
306 .unwrap(),
307 0
308 );
309 }
310
311 #[sqlx::test(migrator = "crate::MIGRATOR")]
314 async fn test_provider_repository_pagination(pool: PgPool) {
315 let scope = Scope::from_iter([OPENID]);
316
317 let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
318 let clock = MockClock::default();
319 let mut repo = PgRepository::from_pool(&pool).await.unwrap();
320
321 let filter = UpstreamOAuthProviderFilter::new();
322
323 assert_eq!(
325 repo.upstream_oauth_provider().count(filter).await.unwrap(),
326 0
327 );
328
329 let mut ids = Vec::with_capacity(20);
330 for idx in 0..20 {
332 let client_id = format!("client-{idx}");
333 let provider = repo
334 .upstream_oauth_provider()
335 .add(
336 &mut rng,
337 &clock,
338 UpstreamOAuthProviderParams {
339 issuer: None,
340 human_name: None,
341 brand_name: None,
342 scope: scope.clone(),
343 token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
344 fetch_userinfo: false,
345 userinfo_signed_response_alg: None,
346 token_endpoint_signing_alg: None,
347 id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
348 client_id,
349 encrypted_client_secret: None,
350 claims_imports: UpstreamOAuthProviderClaimsImports::default(),
351 token_endpoint_override: None,
352 authorization_endpoint_override: None,
353 userinfo_endpoint_override: None,
354 jwks_uri_override: None,
355 discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
356 pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
357 response_mode: None,
358 additional_authorization_parameters: Vec::new(),
359 forward_login_hint: false,
360 ui_order: 0,
361 on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
362 },
363 )
364 .await
365 .unwrap();
366 ids.push(provider.id);
367 clock.advance(Duration::microseconds(10 * 1000 * 1000));
368 }
369
370 assert_eq!(
372 repo.upstream_oauth_provider().count(filter).await.unwrap(),
373 20
374 );
375
376 let page = repo
378 .upstream_oauth_provider()
379 .list(filter, Pagination::first(10))
380 .await
381 .unwrap();
382
383 assert!(page.has_next_page);
385 let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
386 assert_eq!(&edge_ids, &ids[..10]);
387
388 let other_page = repo
391 .upstream_oauth_provider()
392 .list(filter.enabled_only(), Pagination::first(10))
393 .await
394 .unwrap();
395
396 assert_eq!(page, other_page);
397
398 let page = repo
400 .upstream_oauth_provider()
401 .list(filter, Pagination::first(10).after(ids[9]))
402 .await
403 .unwrap();
404
405 assert!(!page.has_next_page);
407 let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
408 assert_eq!(&edge_ids, &ids[10..]);
409
410 let page = repo
412 .upstream_oauth_provider()
413 .list(filter, Pagination::last(10))
414 .await
415 .unwrap();
416
417 assert!(page.has_previous_page);
419 let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
420 assert_eq!(&edge_ids, &ids[10..]);
421
422 let page = repo
424 .upstream_oauth_provider()
425 .list(filter, Pagination::last(10).before(ids[10]))
426 .await
427 .unwrap();
428
429 assert!(!page.has_previous_page);
431 let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
432 assert_eq!(&edge_ids, &ids[..10]);
433
434 let page = repo
436 .upstream_oauth_provider()
437 .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
438 .await
439 .unwrap();
440
441 assert!(!page.has_next_page);
443 let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
444 assert_eq!(&edge_ids, &ids[6..8]);
445
446 assert!(
448 repo.upstream_oauth_provider()
449 .list(
450 UpstreamOAuthProviderFilter::new().disabled_only(),
451 Pagination::first(1)
452 )
453 .await
454 .unwrap()
455 .edges
456 .is_empty()
457 );
458 }
459
460 #[sqlx::test(migrator = "crate::MIGRATOR")]
463 async fn test_session_repository_pagination(pool: PgPool) {
464 let scope = Scope::from_iter([OPENID]);
465
466 let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
467 let clock = MockClock::default();
468 let mut repo = PgRepository::from_pool(&pool).await.unwrap();
469
470 let provider = repo
472 .upstream_oauth_provider()
473 .add(
474 &mut rng,
475 &clock,
476 UpstreamOAuthProviderParams {
477 issuer: Some("https://example.com/".to_owned()),
478 human_name: None,
479 brand_name: None,
480 scope,
481 token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
482 id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
483 fetch_userinfo: false,
484 userinfo_signed_response_alg: None,
485 token_endpoint_signing_alg: None,
486 client_id: "client-id".to_owned(),
487 encrypted_client_secret: None,
488 claims_imports: UpstreamOAuthProviderClaimsImports::default(),
489 token_endpoint_override: None,
490 authorization_endpoint_override: None,
491 userinfo_endpoint_override: None,
492 jwks_uri_override: None,
493 discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
494 pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
495 response_mode: None,
496 additional_authorization_parameters: Vec::new(),
497 forward_login_hint: false,
498 ui_order: 0,
499 on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
500 },
501 )
502 .await
503 .unwrap();
504
505 let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);
506
507 assert_eq!(
509 repo.upstream_oauth_session().count(filter).await.unwrap(),
510 0
511 );
512
513 let mut links = Vec::with_capacity(3);
514 for subject in ["alice", "bob", "charlie"] {
515 let link = repo
516 .upstream_oauth_link()
517 .add(&mut rng, &clock, &provider, subject.to_owned(), None)
518 .await
519 .unwrap();
520 links.push(link);
521 }
522
523 let mut ids = Vec::with_capacity(20);
524 let sids = ["one", "two"].into_iter().cycle();
525 for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
527 let state = format!("state-{idx}");
528 let session = repo
529 .upstream_oauth_session()
530 .add(&mut rng, &clock, &provider, state, None, None)
531 .await
532 .unwrap();
533 let id_token_claims = serde_json::json!({
534 "sub": link.subject,
535 "sid": sid,
536 "aud": provider.client_id,
537 "iss": "https://example.com/",
538 });
539 let session = repo
540 .upstream_oauth_session()
541 .complete_with_link(
542 &clock,
543 session,
544 link,
545 None,
546 Some(id_token_claims),
547 None,
548 None,
549 )
550 .await
551 .unwrap();
552 ids.push(session.id);
553 clock.advance(Duration::microseconds(10 * 1000 * 1000));
554 }
555
556 assert_eq!(
558 repo.upstream_oauth_session().count(filter).await.unwrap(),
559 20
560 );
561
562 let page = repo
564 .upstream_oauth_session()
565 .list(filter, Pagination::first(10))
566 .await
567 .unwrap();
568
569 assert!(page.has_next_page);
571 let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
572 assert_eq!(&edge_ids, &ids[..10]);
573
574 let page = repo
576 .upstream_oauth_session()
577 .list(filter, Pagination::first(10).after(ids[9]))
578 .await
579 .unwrap();
580
581 assert!(!page.has_next_page);
583 let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
584 assert_eq!(&edge_ids, &ids[10..]);
585
586 let page = repo
588 .upstream_oauth_session()
589 .list(filter, Pagination::last(10))
590 .await
591 .unwrap();
592
593 assert!(page.has_previous_page);
595 let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
596 assert_eq!(&edge_ids, &ids[10..]);
597
598 let page = repo
600 .upstream_oauth_session()
601 .list(filter, Pagination::last(10).before(ids[10]))
602 .await
603 .unwrap();
604
605 assert!(!page.has_previous_page);
607 let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
608 assert_eq!(&edge_ids, &ids[..10]);
609
610 let page = repo
612 .upstream_oauth_session()
613 .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
614 .await
615 .unwrap();
616
617 assert!(!page.has_next_page);
619 let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
620 assert_eq!(&edge_ids, &ids[6..11]);
621
622 assert_eq!(
624 repo.upstream_oauth_session()
625 .count(filter.with_sub_claim("alice").with_sid_claim("one"))
626 .await
627 .unwrap(),
628 4
629 );
630 assert_eq!(
631 repo.upstream_oauth_session()
632 .count(filter.with_sub_claim("bob").with_sid_claim("two"))
633 .await
634 .unwrap(),
635 4
636 );
637
638 let page = repo
639 .upstream_oauth_session()
640 .list(
641 filter.with_sub_claim("alice").with_sid_claim("one"),
642 Pagination::first(10),
643 )
644 .await
645 .unwrap();
646 assert_eq!(page.edges.len(), 4);
647 for edge in page.edges {
648 assert_eq!(
649 edge.node
650 .id_token_claims()
651 .unwrap()
652 .get("sub")
653 .unwrap()
654 .as_str(),
655 Some("alice")
656 );
657 assert_eq!(
658 edge.node
659 .id_token_claims()
660 .unwrap()
661 .get("sid")
662 .unwrap()
663 .as_str(),
664 Some("one")
665 );
666 }
667 }
668}