mas_storage_pg/compat/mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! A module containing PostgreSQL implementation of repositories for the
8//! compatibility layer
9
10mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16    access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17    session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
#[cfg(test)]
mod tests {
    //! Database-backed tests for the compatibility-layer repositories.
    //!
    //! Each test receives a `PgPool` from `#[sqlx::test]`, configured with
    //! `crate::MIGRATOR`, and drives the repositories through `PgRepository`.

    use chrono::Duration;
    use mas_data_model::{Device, UserAgent};
    use mas_storage::{
        Clock, Pagination, RepositoryAccess,
        clock::MockClock,
        compat::{
            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
            CompatSessionRepository, CompatSsoLoginFilter,
        },
        user::UserRepository,
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Exercises the compat session repository.
    ///
    /// Covered behavior:
    /// - counting and listing with the active/finished filters;
    /// - lookup by ID;
    /// - recording a user agent;
    /// - filtering by device;
    /// - finishing a session;
    /// - the SSO-login/unknown session-type filters, alone and combined with
    ///   the state filters;
    /// - bulk finishing.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = CompatSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let device_str = device.as_str().to_owned();
        let session = repo
            .compat_session()
            // `device` is cloned because it is reused below for the device filter
            .add(&mut rng, &clock, &user, device.clone(), None, false)
            .await
            .unwrap();
        assert_eq!(session.user_id, user.id);
        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
        assert!(session.is_valid());
        assert!(!session.is_finished());

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(active_list.edges[0].0.id, session.id);
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Lookup the session and check it didn't change
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Check the device as read back from the database, not the in-memory copy
        assert_eq!(
            session_lookup.device.as_ref().unwrap().as_str(),
            device_str
        );
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // Record a user-agent for the session
        assert!(session_lookup.user_agent.is_none());
        let session = repo
            .compat_session()
            .record_user_agent(session_lookup, UserAgent::parse("Mozilla/5.0".to_owned()))
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Reload the session and check again
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Look up the session by device
        let list = repo
            .compat_session()
            .list(
                CompatSessionFilter::new()
                    .for_user(&user)
                    .for_device(&device),
                pagination,
            )
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        let session_lookup = &list.edges[0].0;
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(session_lookup.user_id, user.id);
        // Check the device of the session returned by the device filter
        assert_eq!(
            session_lookup.device.as_ref().unwrap().as_str(),
            device_str
        );
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_finished());

        // Finish the session
        let session = repo.compat_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
        assert!(session.is_finished());

        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);

        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].0.id, session.id);
        let active_list = repo
            .compat_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .compat_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(finished_list.edges[0].0.id, session.id);

        // Reload the session and check again
        let session_lookup = repo
            .compat_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("compat session not found");
        assert!(!session_lookup.is_valid());
        assert!(session_lookup.is_finished());

        // Now add another session, with an SSO login this time
        let unknown_session = session;
        // Start a new SSO login
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                "login-token".to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        // Start a browser session for the user
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let sso_login_session = repo
            .compat_session()
            .add(
                &mut rng,
                &clock,
                &user,
                device,
                Some(&browser_session),
                false,
            )
            .await
            .unwrap();

        // Associate the login with the session
        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &sso_login_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        // Now query the session list with both the unknown and SSO login session type
        // filter
        let all = CompatSessionFilter::new().for_user(&user);
        let sso_login = all.sso_login_only();
        let unknown = all.unknown_only();
        assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
        assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);

        let list = repo
            .compat_session()
            .list(sso_login, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, sso_login_session.id);
        let list = repo
            .compat_session()
            .list(unknown, pagination)
            .await
            .unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0].0.id, unknown_session.id);

        // Check that combining the two filters works
        // At this point, there is one active SSO login session and one finished unknown
        // session
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().active_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.compat_session()
                .count(all.sso_login_only().finished_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().active_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.compat_session()
                .count(all.unknown_only().finished_only())
                .await
                .unwrap(),
            1
        );

        // Check that we can batch finish sessions
        let affected = repo
            .compat_session()
            .finish_bulk(&clock, all.sso_login_only().active_only())
            .await
            .unwrap();
        assert_eq!(affected, 1);
        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
    }

    /// Exercises the compat access token repository.
    ///
    /// Covered behavior:
    /// - adding tokens with and without an expiry;
    /// - rejecting a duplicate token value in a later transaction;
    /// - lookup by ID and by token value;
    /// - clock-driven expiry;
    /// - explicit expiration.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_access_token_repository(pool: PgPool) {
        const FIRST_TOKEN: &str = "first_access_token";
        const SECOND_TOKEN: &str = "second_access_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false)
            .await
            .unwrap();

        // Add an access token to that session
        let token = repo
            .compat_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                FIRST_TOKEN.to_owned(),
                Some(Duration::try_minutes(1).unwrap()),
            )
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, FIRST_TOKEN);

        // Commit the txn and grab a new transaction, to test a conflict
        repo.save().await.unwrap();

        {
            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
            // Adding the same token a second time should conflict
            assert!(
                repo.compat_access_token()
                    .add(
                        &mut rng,
                        &clock,
                        &session,
                        FIRST_TOKEN.to_owned(),
                        Some(Duration::try_minutes(1).unwrap()),
                    )
                    .await
                    .is_err()
            );
            // The failed insert is rolled back so the outer test can continue
            repo.cancel().await.unwrap();
        }

        // Grab a new repo
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Looking up via ID works
        let token_lookup = repo
            .compat_access_token()
            .lookup(token.id)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Looking up via the token value works
        let token_lookup = repo
            .compat_access_token()
            .find_by_token(FIRST_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        clock.advance(Duration::try_minutes(1).unwrap());
        // Token should have expired
        assert!(!token.is_valid(clock.now()));

        // Add a second access token, this time without expiration
        let token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);
        assert_eq!(token.token, SECOND_TOKEN);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        // Make it expire
        repo.compat_access_token()
            .expire(&clock, token)
            .await
            .unwrap();

        // Reload it
        let token = repo
            .compat_access_token()
            .find_by_token(SECOND_TOKEN)
            .await
            .unwrap()
            .expect("compat access token not found");

        // Token is not valid anymore
        assert!(!token.is_valid(clock.now()));

        repo.save().await.unwrap();
    }

    /// Exercises the compat refresh token repository.
    ///
    /// Covered behavior:
    /// - adding a token tied to a session and an access token;
    /// - lookup by ID and by token value;
    /// - consuming it;
    /// - rejecting a second consumption.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_refresh_token_repository(pool: PgPool) {
        const ACCESS_TOKEN: &str = "access_token";
        const REFRESH_TOKEN: &str = "refresh_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false)
            .await
            .unwrap();

        // Add an access token to that session
        let access_token = repo
            .compat_access_token()
            .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
            .await
            .unwrap();

        let refresh_token = repo
            .compat_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                REFRESH_TOKEN.to_owned(),
            )
            .await
            .unwrap();
        assert_eq!(refresh_token.session_id, session.id);
        assert_eq!(refresh_token.access_token_id, access_token.id);
        assert_eq!(refresh_token.token, REFRESH_TOKEN);
        assert!(refresh_token.is_valid());
        assert!(!refresh_token.is_consumed());

        // Look it up by ID and check everything matches
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Look it up by token and check everything matches
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token_lookup.id, refresh_token.id);
        assert_eq!(refresh_token_lookup.session_id, session.id);
        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
        assert!(refresh_token_lookup.is_valid());
        assert!(!refresh_token_lookup.is_consumed());

        // Consume it
        let refresh_token = repo
            .compat_refresh_token()
            .consume(&clock, refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());
        assert!(refresh_token.is_consumed());

        // Reload it and check again
        let refresh_token_lookup = repo
            .compat_refresh_token()
            .find_by_token(REFRESH_TOKEN)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert!(!refresh_token_lookup.is_valid());
        assert!(refresh_token_lookup.is_consumed());

        // Consuming it again should not work
        assert!(
            repo.compat_refresh_token()
                .consume(&clock, refresh_token)
                .await
                .is_err()
        );

        repo.save().await.unwrap();
    }

    /// Exercises the compat SSO login repository through its lifecycle:
    /// pending, then fulfilled, then exchanged.
    ///
    /// Counts and listings are checked for each state filter. Out-of-order
    /// transitions must fail, and the repository must stay usable after each
    /// failure.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_compat_sso_login_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Lookup an unknown SSO login
        let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(login, None);

        let all = CompatSsoLoginFilter::new();
        let for_user = all.for_user(&user);
        let pending = all.pending_only();
        let fulfilled = all.fulfilled_only();
        let exchanged = all.exchanged_only();

        // Check the initial counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Lookup an unknown login token
        let login = repo
            .compat_sso_login()
            .find_by_token("login-token")
            .await
            .unwrap();
        assert_eq!(login, None);

        // Start a new SSO login
        let login = repo
            .compat_sso_login()
            .add(
                &mut rng,
                &clock,
                "login-token".to_owned(),
                "https://example.com/callback".parse().unwrap(),
            )
            .await
            .unwrap();
        assert!(login.is_pending());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Lookup the login by ID
        let login_lookup = repo
            .compat_sso_login()
            .lookup(login.id)
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Find the login by token
        let login_lookup = repo
            .compat_sso_login()
            .find_by_token("login-token")
            .await
            .unwrap()
            .expect("login not found");
        assert_eq!(login_lookup, login);

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false)
            .await
            .unwrap();

        // Exchanging before fulfilling should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        // Start a browser session for that user
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Associate the login with the session
        let login = repo
            .compat_sso_login()
            .fulfill(&clock, login, &browser_session)
            .await
            .unwrap();
        assert!(login.is_fulfilled());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);

        // Fulfilling again should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        // Exchange that login
        let login = repo
            .compat_sso_login()
            .exchange(&clock, login, &compat_session)
            .await
            .unwrap();
        assert!(login.is_exchanged());

        // Check the counts
        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);

        // Exchange again should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .exchange(&clock, login.clone(), &compat_session)
            .await;
        assert!(res.is_err());

        // Fulfilling after exchanging should not work
        // Note: It should also not poison the SQL transaction
        let res = repo
            .compat_sso_login()
            .fulfill(&clock, login.clone(), &browser_session)
            .await;
        assert!(res.is_err());

        let pagination = Pagination::first(10);

        // List all logins
        let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, &[login.clone()]);

        // List the logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user, pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, &[login.clone()]);

        // List only the pending logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.pending_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        // List only the fulfilled logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.fulfilled_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert!(logins.edges.is_empty());

        // List only the exchanged logins for the user
        let logins = repo
            .compat_sso_login()
            .list(for_user.exchanged_only(), pagination)
            .await
            .unwrap();
        assert!(!logins.has_next_page);
        assert_eq!(logins.edges, &[login]);
    }
}