mas_storage_pg/upstream_oauth2/mod.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementation of the repositories
//! related to the upstream OAuth 2.0 providers.

10mod link;
11mod provider;
12mod session;
13
14pub use self::{
15    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16    session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the provider parameters shared by every test in this module.
    ///
    /// Only the issuer and the client ID vary between tests. Everything else
    /// is a minimal OIDC configuration: `openid` scope, no client secret, RS256
    /// ID tokens, no endpoint overrides, automatic PKCE, and no action on
    /// back-channel logout.
    fn provider_params(
        issuer: Option<&str>,
        client_id: impl Into<String>,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.into(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        }
    }

    /// Exercise the basic lifecycle of the upstream OAuth repositories: create
    /// a provider, start a session, create a link for a subject, complete and
    /// consume the session, associate the link with a user, then disable and
    /// finally delete the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID
        repo.upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");

        // or by its subject
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // We need to create a user and start a browser session to consume the session
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session, &browser_session)
            .await
            .unwrap();

        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // XXX: we should also try other combinations of the filter
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].node.id, link.id);
        assert_eq!(links.edges[0].node.user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // There should be exactly one enabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should be exactly one disabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Test listing and counting sessions
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the sessions for the provider
        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        // List the sessions for the provider
        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].node.id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers, advancing the mock clock by 10 seconds after
        // each one (presumably so their IDs sort in creation order — confirm
        // against the ID generation scheme)
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(
                    &mut rng,
                    &clock,
                    provider_params(None, format!("client-{idx}")),
                )
                .await
                .unwrap();
            ids.push(provider.id);
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the items strictly between two IDs: only ids[6] and ids[7]
        // qualify, even though up to 10 items were requested
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// session repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the number of sessions before we start
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        // Create 20 sessions. Subjects cycle with period 3 and `sid`s with
        // period 2, so each (sub, sid) pair repeats every 6 sessions.
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 sessions
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the items strictly between two IDs: ids[6] through ids[10],
        // i.e. 5 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // Check the sub/sid filters: ("alice", "one") is sessions 0, 6, 12, 18
        // and ("bob", "two") is sessions 1, 7, 13, 19
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            assert_eq!(
                edge.node
                    .id_token_claims()
                    .unwrap()
                    .get("sub")
                    .unwrap()
                    .as_str(),
                Some("alice")
            );
            assert_eq!(
                edge.node
                    .id_token_claims()
                    .unwrap()
                    .get("sid")
                    .unwrap()
                    .as_str(),
                Some("one")
            );
        }
    }
}