mas_storage_pg/upstream_oauth2/
mod.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! A module containing the PostgreSQL implementation of the repositories
//! related to the upstream OAuth 2.0 providers.

10mod link;
11mod provider;
12mod session;
13
14pub use self::{
15    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16    session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the provider parameters used by the tests in this module.
    ///
    /// Only `issuer` and `client_id` vary between tests. Every other field is
    /// fixed: `openid` scope, no token endpoint authentication, RS256-signed
    /// ID tokens, no userinfo fetching, OIDC discovery, automatic PKCE, no
    /// endpoint overrides and a `ui_order` of 0.
    fn provider_params(issuer: Option<&str>, client_id: &str) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(ToOwned::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.to_owned(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        }
    }

    /// Exercise the provider, session and link repositories end to end.
    ///
    /// The test creates a provider, walks an authorization session through the
    /// pending -> completed -> consumed states, and links it to a user. It then
    /// checks link filters, disables the provider and finally deletes it.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                "some-nonce".to_owned(),
            )
            .await
            .unwrap();

        // Look it up in the database
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID...
        let link_by_id = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link_by_id.id, link.id);
        assert_eq!(link_by_id.subject, "a-subject");

        // ...or by its subject, and both lookups return the same link
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.id, link_by_id.id);
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // A filter combining every criterion should match the link
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Each criterion on its own should also match the link...
        let by_user = UpstreamOAuthLinkFilter::new().for_user(&user);
        assert_eq!(repo.upstream_oauth_link().count(by_user).await.unwrap(), 1);
        let by_provider = UpstreamOAuthLinkFilter::new().for_provider(&provider);
        assert_eq!(repo.upstream_oauth_link().count(by_provider).await.unwrap(), 1);

        // ...while a subject nobody uses should not match anything
        let by_other_subject = UpstreamOAuthLinkFilter::new().for_subject("another-subject");
        assert_eq!(
            repo.upstream_oauth_link()
                .count(by_other_subject)
                .await
                .unwrap(),
            0
        );

        // There should be exactly one enabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should be exactly one disabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Once the provider is disabled, `enabled_providers_only` should exclude
        // its link, while a filter without that restriction still finds it
        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 0);
        assert_eq!(repo.upstream_oauth_link().count(by_provider).await.unwrap(), 1);

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers
        for idx in 0..20 {
            let client_id = format!("client-{idx}");
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, &client_id))
                .await
                .unwrap();
            ids.push(provider.id);
            // Move the mock clock forward 10s between insertions. The assertions
            // below compare listing order against `ids` (creation order), which
            // presumably holds because IDs are derived from the clock. The value is
            // written in microseconds, which is an infallible chrono constructor.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup 10 items between two IDs
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }
}