mas_storage_pg/oauth2/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! A module containing the PostgreSQL implementations of the OAuth2-related
8//! repositories
9
10mod access_token;
11mod authorization_grant;
12mod client;
13mod device_code_grant;
14mod refresh_token;
15mod session;
16
17pub use self::{
18    access_token::PgOAuth2AccessTokenRepository,
19    authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
20    device_code_grant::PgOAuth2DeviceCodeGrantRepository,
21    refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
22};
23
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{AuthorizationCode, UserAgent};
    use mas_storage::{
        Clock, Pagination,
        clock::MockClock,
        oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
    };
    use oauth2_types::{
        requests::{GrantType, ResponseMode},
        scope::{EMAIL, OPENID, PROFILE, Scope},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// End-to-end walk through the OAuth 2.0 client, authorization grant,
    /// consent, session, access token and refresh token repositories.
    ///
    /// Each step relies on the state left behind by the previous ones, so the
    /// statements must stay in this order.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repositories(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Lookup a non-existing client
        let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(client, None);

        // Find a non-existing client by client id
        let client = repo
            .oauth2_client()
            .find_by_client_id("some-client-id")
            .await
            .unwrap();
        assert_eq!(client, None);

        // Create a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Test client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // Lookup the same client by id
        let client_lookup = repo
            .oauth2_client()
            .lookup(client.id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        // Find the same client by client id
        let client_lookup = repo
            .oauth2_client()
            .find_by_client_id(&client.client_id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        // Lookup a non-existing grant
        let grant = repo
            .oauth2_authorization_grant()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(grant, None);

        // Find a non-existing grant by code
        let grant = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap();
        assert_eq!(grant, None);

        // Create an authorization grant
        let grant = repo
            .oauth2_authorization_grant()
            .add(
                &mut rng,
                &clock,
                &client,
                "https://example.com/redirect".parse().unwrap(),
                Scope::from_iter([OPENID]),
                Some(AuthorizationCode {
                    code: "code".to_owned(),
                    pkce: None,
                }),
                Some("state".to_owned()),
                Some("nonce".to_owned()),
                None,
                ResponseMode::Query,
                true,
                false,
                None,
            )
            .await
            .unwrap();
        assert!(grant.is_pending());

        // Lookup the same grant by id
        let grant_lookup = repo
            .oauth2_authorization_grant()
            .lookup(grant.id)
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        // Find the same grant by code
        let grant_lookup = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        // Create a user and start a browser session for them
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let user_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // The user has not given any consent to the client yet
        let consent = repo
            .oauth2_client()
            .get_consent_for_user(&client, &user)
            .await
            .unwrap();
        assert!(consent.is_empty());

        // Give consent to the client
        let scope = Scope::from_iter([OPENID]);
        repo.oauth2_client()
            .give_consent_for_user(&mut rng, &clock, &client, &user, &scope)
            .await
            .unwrap();

        // The recorded consent now matches the scope that was granted
        let consent = repo
            .oauth2_client()
            .get_consent_for_user(&client, &user)
            .await
            .unwrap();
        assert_eq!(scope, consent);

        // Lookup a non-existing session
        let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(session, None);

        // Create an OAuth session
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut rng,
                &clock,
                &client,
                &user_session,
                grant.scope.clone(),
            )
            .await
            .unwrap();

        // Mark the grant as fulfilled
        let grant = repo
            .oauth2_authorization_grant()
            .fulfill(&clock, &session, grant)
            .await
            .unwrap();
        assert!(grant.is_fulfilled());

        // Lookup the same session by id
        let session_lookup = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session, session_lookup);

        // Mark the grant as exchanged
        let grant = repo
            .oauth2_authorization_grant()
            .exchange(&clock, grant)
            .await
            .unwrap();
        assert!(grant.is_exchanged());

        // Lookup a non-existing token
        let token = repo
            .oauth2_access_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(token, None);

        // Find a non-existing token
        let token = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(token, None);

        // Create an access token with a 5-minute lifetime
        let access_token = repo
            .oauth2_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                "aabbcc".to_owned(),
                Some(Duration::try_minutes(5).unwrap()),
            )
            .await
            .unwrap();

        // Lookup the same token by id
        let access_token_lookup = repo
            .oauth2_access_token()
            .lookup(access_token.id)
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        // Find the same token by token
        let access_token_lookup = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        // Lookup a non-existing refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        // Find a non-existing refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        // Create a refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                "aabbcc".to_owned(),
            )
            .await
            .unwrap();

        // Lookup the same refresh token by id
        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        // Find the same refresh token by token
        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        // The access token was created with a 5-minute lifetime, so it is no
        // longer valid 6 minutes later
        assert!(access_token.is_valid(clock.now()));
        clock.advance(Duration::try_minutes(6).unwrap());
        assert!(!access_token.is_valid(clock.now()));

        // Rewind the mock clock so the token is valid again for the remaining
        // steps.
        // XXX: we might want to create a new access token instead of going
        // back in time
        clock.advance(Duration::try_minutes(-6).unwrap());
        assert!(access_token.is_valid(clock.now()));

        // Create a new refresh token to be able to consume the old one
        let new_refresh_token = repo
            .oauth2_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                "ddeeff".to_owned(),
            )
            .await
            .unwrap();

        // Mark the access token as revoked
        let access_token = repo
            .oauth2_access_token()
            .revoke(&clock, access_token)
            .await
            .unwrap();
        assert!(!access_token.is_valid(clock.now()));

        // Mark the refresh token as consumed, replaced by the new one
        assert!(refresh_token.is_valid());
        let refresh_token = repo
            .oauth2_refresh_token()
            .consume(&clock, refresh_token, &new_refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());

        // Record the user-agent on the session
        assert!(session.user_agent.is_none());
        let session = repo
            .oauth2_session()
            .record_user_agent(session, UserAgent::parse("Mozilla/5.0".to_owned()))
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Reload the session and check the user-agent was persisted
        let session = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Mark the session as finished
        assert!(session.is_valid());
        let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
    }

    /// Test the [`OAuth2SessionRepository::list`],
    /// [`OAuth2SessionRepository::count`] and
    /// [`OAuth2SessionRepository::finish_bulk`] methods.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_list_sessions(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create two users and their corresponding browser sessions
        let user1 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();
        let user1_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user1, None)
            .await
            .unwrap();

        let user2 = repo
            .user()
            .add(&mut rng, &clock, "bob".to_owned())
            .await
            .unwrap();
        let user2_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user2, None)
            .await
            .unwrap();

        // Create two clients
        let client1 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://first.example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://first.example.com/logo.png".parse().unwrap()),
                Some("https://first.example.com/".parse().unwrap()),
                Some("https://first.example.com/policy".parse().unwrap()),
                Some("https://first.example.com/tos".parse().unwrap()),
                Some("https://first.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://first.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();
        let client2 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://second.example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Second client".to_owned()),
                Some("https://second.example.com/logo.png".parse().unwrap()),
                Some("https://second.example.com/".parse().unwrap()),
                Some("https://second.example.com/policy".parse().unwrap()),
                Some("https://second.example.com/tos".parse().unwrap()),
                Some("https://second.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://second.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let scope = Scope::from_iter([OPENID, EMAIL]);
        let scope2 = Scope::from_iter([OPENID, PROFILE]);

        // Create two sessions for each user, one with each client
        // We're moving the clock forward by 1 minute between each session to ensure
        // we're getting consistent ordering in lists.
        let session11 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user1_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session12 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user2_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session21 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user1_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session22 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user2_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        // We're also finishing two of the sessions
        let session11 = repo
            .oauth2_session()
            .finish(&clock, session11)
            .await
            .unwrap();
        let session22 = repo
            .oauth2_session()
            .finish(&clock, session22)
            .await
            .unwrap();

        // Large enough to fit every session created above on a single page
        let pagination = Pagination::first(10);

        // First, list all the sessions
        let filter = OAuth2SessionFilter::new().for_any_user();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // Now filter for only one user
        let filter = OAuth2SessionFilter::new().for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for only one client
        let filter = OAuth2SessionFilter::new().for_client(&client1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for both a user and a client
        let filter = OAuth2SessionFilter::new()
            .for_user(&user2)
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Filter for active sessions
        let filter = OAuth2SessionFilter::new().active_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session12);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for finished sessions
        let filter = OAuth2SessionFilter::new().finished_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Combine the finished filter with the user filter
        let filter = OAuth2SessionFilter::new().finished_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the finished filter with the client filter
        let filter = OAuth2SessionFilter::new()
            .finished_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the active filter with the user filter
        let filter = OAuth2SessionFilter::new().active_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the active filter with the client filter
        let filter = OAuth2SessionFilter::new()
            .active_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Try the scope filter. We should get all sessions with the "openid" scope
        let scope = Scope::from_iter([OPENID]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // We should get all sessions with both the "openid" and "email" scopes
        let scope = Scope::from_iter([OPENID, EMAIL]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Try combining the scope filter with the user filter
        let filter = OAuth2SessionFilter::new()
            .with_scope(&scope)
            .for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session11);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Finish all the active sessions of a client in batch. Only session12
        // is still active for client1 at this point.
        let affected = repo
            .oauth2_session()
            .finish_bulk(
                &clock,
                OAuth2SessionFilter::new()
                    .for_client(&client1)
                    .active_only(),
            )
            .await
            .unwrap();
        assert_eq!(affected, 1);

        // We should have 3 finished sessions
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().finished_only())
                .await
                .unwrap(),
            3
        );

        // We should have 1 active session
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().active_only())
                .await
                .unwrap(),
            1
        );
    }

    /// Test the
    /// [`OAuth2DeviceCodeGrantRepository`](mas_storage::oauth2::OAuth2DeviceCodeGrantRepository)
    /// implementation, including the allowed and forbidden state transitions.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_device_code_grant_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Provision a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Example".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // Provision a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Provision a browser session
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let user_code = "usercode";
        let device_code = "devicecode";
        let scope = Scope::from_iter([OPENID, EMAIL]);

        // Create a device code grant
        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: device_code.to_owned(),
                    user_code: user_code.to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        assert!(grant.is_pending());

        // Check that we can find the grant by ID
        let id = grant.id;
        let lookup = repo.oauth2_device_code_grant().lookup(id).await.unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Check that we can find the grant by device code
        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_device_code(device_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Check that we can find the grant by user code
        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_user_code(user_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Let's mark it as fulfilled
        let grant = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(grant.is_fulfilled());

        // Check that we can't mark it as rejected now
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as fulfilled again
        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // Create an OAuth 2.0 session
        let session = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client, &browser_session, scope.clone())
            .await
            .unwrap();

        // We can mark it as exchanged
        let grant = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(!grant.is_fulfilled());
        assert!(grant.is_exchanged());

        // We can't mark it as exchanged again
        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());

        // Do a new grant to reject it
        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: "second_devicecode".to_owned(),
                    user_code: "second_usercode".to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        let id = grant.id;

        // We can mark it as rejected
        let grant = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(grant.is_rejected());

        // We can't mark it as rejected again
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as fulfilled
        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as exchanged
        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());
    }
}