1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
13 UserRegistrationPassword, UserRegistrationToken,
14};
15use mas_storage::user::UserRegistrationRepository;
16use rand::RngCore;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
23
24pub struct PgUserRegistrationRepository<'c> {
27 conn: &'c mut PgConnection,
28}
29
30impl<'c> PgUserRegistrationRepository<'c> {
31 pub fn new(conn: &'c mut PgConnection) -> Self {
34 Self { conn }
35 }
36}
37
/// Raw row of the `user_registrations` table, as selected by the `lookup`
/// query and turned into a [`UserRegistration`] through `TryFrom`.
///
/// `sqlx::query_as!` fills this struct by column name, so the field names must
/// stay in sync with the columns selected by that query.
struct UserRegistrationLookup {
    user_registration_id: Uuid,
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
    post_auth_action: Option<serde_json::Value>,
    username: String,
    display_name: Option<String>,
    // Stored as text; parsed into a `Url` during conversion.
    terms_url: Option<String>,
    email_authentication_id: Option<Uuid>,
    user_registration_token_id: Option<Uuid>,
    // Expected to be set together with `hashed_password_version` (both or neither);
    // the conversion rejects rows where only one of them is present.
    hashed_password: Option<String>,
    // Stored as `i32`; narrowed to the model's version type during conversion.
    hashed_password_version: Option<i32>,
    upstream_oauth_authorization_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
}
54
55impl TryFrom<UserRegistrationLookup> for UserRegistration {
56 type Error = DatabaseInconsistencyError;
57
58 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
59 let id = Ulid::from(value.user_registration_id);
60
61 let password = match (value.hashed_password, value.hashed_password_version) {
62 (Some(hashed_password), Some(version)) => {
63 let version = version.try_into().map_err(|e| {
64 DatabaseInconsistencyError::on("user_registrations")
65 .column("hashed_password_version")
66 .row(id)
67 .source(e)
68 })?;
69
70 Some(UserRegistrationPassword {
71 hashed_password,
72 version,
73 })
74 }
75 (None, None) => None,
76 _ => {
77 return Err(DatabaseInconsistencyError::on("user_registrations")
78 .column("hashed_password")
79 .row(id));
80 }
81 };
82
83 let terms_url = value
84 .terms_url
85 .map(|u| u.parse())
86 .transpose()
87 .map_err(|e| {
88 DatabaseInconsistencyError::on("user_registrations")
89 .column("terms_url")
90 .row(id)
91 .source(e)
92 })?;
93
94 Ok(UserRegistration {
95 id,
96 ip_address: value.ip_address,
97 user_agent: value.user_agent,
98 post_auth_action: value.post_auth_action,
99 username: value.username,
100 display_name: value.display_name,
101 terms_url,
102 email_authentication_id: value.email_authentication_id.map(Ulid::from),
103 user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
104 password,
105 upstream_oauth_authorization_session_id: value
106 .upstream_oauth_authorization_session_id
107 .map(Ulid::from),
108 created_at: value.created_at,
109 completed_at: value.completed_at,
110 })
111 }
112}
113
114#[async_trait]
115impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
116 type Error = DatabaseError;
117
118 #[tracing::instrument(
119 name = "db.user_registration.lookup",
120 skip_all,
121 fields(
122 db.query.text,
123 user_registration.id = %id,
124 ),
125 err,
126 )]
127 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
128 let res = sqlx::query_as!(
129 UserRegistrationLookup,
130 r#"
131 SELECT user_registration_id
132 , ip_address as "ip_address: IpAddr"
133 , user_agent
134 , post_auth_action
135 , username
136 , display_name
137 , terms_url
138 , email_authentication_id
139 , user_registration_token_id
140 , hashed_password
141 , hashed_password_version
142 , upstream_oauth_authorization_session_id
143 , created_at
144 , completed_at
145 FROM user_registrations
146 WHERE user_registration_id = $1
147 "#,
148 Uuid::from(id),
149 )
150 .traced()
151 .fetch_optional(&mut *self.conn)
152 .await?;
153
154 let Some(res) = res else { return Ok(None) };
155
156 Ok(Some(res.try_into()?))
157 }
158
159 #[tracing::instrument(
160 name = "db.user_registration.add",
161 skip_all,
162 fields(
163 db.query.text,
164 user_registration.id,
165 ),
166 err,
167 )]
168 async fn add(
169 &mut self,
170 rng: &mut (dyn RngCore + Send),
171 clock: &dyn Clock,
172 username: String,
173 ip_address: Option<IpAddr>,
174 user_agent: Option<String>,
175 post_auth_action: Option<serde_json::Value>,
176 ) -> Result<UserRegistration, Self::Error> {
177 let created_at = clock.now();
178 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
179 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
180
181 sqlx::query!(
182 r#"
183 INSERT INTO user_registrations
184 ( user_registration_id
185 , ip_address
186 , user_agent
187 , post_auth_action
188 , username
189 , created_at
190 )
191 VALUES ($1, $2, $3, $4, $5, $6)
192 "#,
193 Uuid::from(id),
194 ip_address as Option<IpAddr>,
195 user_agent.as_deref(),
196 post_auth_action,
197 username,
198 created_at,
199 )
200 .traced()
201 .execute(&mut *self.conn)
202 .await?;
203
204 Ok(UserRegistration {
205 id,
206 ip_address,
207 user_agent,
208 post_auth_action,
209 created_at,
210 completed_at: None,
211 username,
212 display_name: None,
213 terms_url: None,
214 email_authentication_id: None,
215 user_registration_token_id: None,
216 password: None,
217 upstream_oauth_authorization_session_id: None,
218 })
219 }
220
221 #[tracing::instrument(
222 name = "db.user_registration.set_display_name",
223 skip_all,
224 fields(
225 db.query.text,
226 user_registration.id = %user_registration.id,
227 user_registration.display_name = display_name,
228 ),
229 err,
230 )]
231 async fn set_display_name(
232 &mut self,
233 mut user_registration: UserRegistration,
234 display_name: String,
235 ) -> Result<UserRegistration, Self::Error> {
236 let res = sqlx::query!(
237 r#"
238 UPDATE user_registrations
239 SET display_name = $2
240 WHERE user_registration_id = $1 AND completed_at IS NULL
241 "#,
242 Uuid::from(user_registration.id),
243 display_name,
244 )
245 .traced()
246 .execute(&mut *self.conn)
247 .await?;
248
249 DatabaseError::ensure_affected_rows(&res, 1)?;
250
251 user_registration.display_name = Some(display_name);
252
253 Ok(user_registration)
254 }
255
256 #[tracing::instrument(
257 name = "db.user_registration.set_terms_url",
258 skip_all,
259 fields(
260 db.query.text,
261 user_registration.id = %user_registration.id,
262 user_registration.terms_url = %terms_url,
263 ),
264 err,
265 )]
266 async fn set_terms_url(
267 &mut self,
268 mut user_registration: UserRegistration,
269 terms_url: Url,
270 ) -> Result<UserRegistration, Self::Error> {
271 let res = sqlx::query!(
272 r#"
273 UPDATE user_registrations
274 SET terms_url = $2
275 WHERE user_registration_id = $1 AND completed_at IS NULL
276 "#,
277 Uuid::from(user_registration.id),
278 terms_url.as_str(),
279 )
280 .traced()
281 .execute(&mut *self.conn)
282 .await?;
283
284 DatabaseError::ensure_affected_rows(&res, 1)?;
285
286 user_registration.terms_url = Some(terms_url);
287
288 Ok(user_registration)
289 }
290
291 #[tracing::instrument(
292 name = "db.user_registration.set_email_authentication",
293 skip_all,
294 fields(
295 db.query.text,
296 %user_registration.id,
297 %user_email_authentication.id,
298 %user_email_authentication.email,
299 ),
300 err,
301 )]
302 async fn set_email_authentication(
303 &mut self,
304 mut user_registration: UserRegistration,
305 user_email_authentication: &UserEmailAuthentication,
306 ) -> Result<UserRegistration, Self::Error> {
307 let res = sqlx::query!(
308 r#"
309 UPDATE user_registrations
310 SET email_authentication_id = $2
311 WHERE user_registration_id = $1 AND completed_at IS NULL
312 "#,
313 Uuid::from(user_registration.id),
314 Uuid::from(user_email_authentication.id),
315 )
316 .traced()
317 .execute(&mut *self.conn)
318 .await?;
319
320 DatabaseError::ensure_affected_rows(&res, 1)?;
321
322 user_registration.email_authentication_id = Some(user_email_authentication.id);
323
324 Ok(user_registration)
325 }
326
327 #[tracing::instrument(
328 name = "db.user_registration.set_password",
329 skip_all,
330 fields(
331 db.query.text,
332 user_registration.id = %user_registration.id,
333 user_registration.hashed_password = hashed_password,
334 user_registration.hashed_password_version = version,
335 ),
336 err,
337 )]
338 async fn set_password(
339 &mut self,
340 mut user_registration: UserRegistration,
341 hashed_password: String,
342 version: u16,
343 ) -> Result<UserRegistration, Self::Error> {
344 let res = sqlx::query!(
345 r#"
346 UPDATE user_registrations
347 SET hashed_password = $2, hashed_password_version = $3
348 WHERE user_registration_id = $1 AND completed_at IS NULL
349 "#,
350 Uuid::from(user_registration.id),
351 hashed_password,
352 i32::from(version),
353 )
354 .traced()
355 .execute(&mut *self.conn)
356 .await?;
357
358 DatabaseError::ensure_affected_rows(&res, 1)?;
359
360 user_registration.password = Some(UserRegistrationPassword {
361 hashed_password,
362 version,
363 });
364
365 Ok(user_registration)
366 }
367
368 #[tracing::instrument(
369 name = "db.user_registration.set_registration_token",
370 skip_all,
371 fields(
372 db.query.text,
373 %user_registration.id,
374 %user_registration_token.id,
375 ),
376 err,
377 )]
378 async fn set_registration_token(
379 &mut self,
380 mut user_registration: UserRegistration,
381 user_registration_token: &UserRegistrationToken,
382 ) -> Result<UserRegistration, Self::Error> {
383 let res = sqlx::query!(
384 r#"
385 UPDATE user_registrations
386 SET user_registration_token_id = $2
387 WHERE user_registration_id = $1 AND completed_at IS NULL
388 "#,
389 Uuid::from(user_registration.id),
390 Uuid::from(user_registration_token.id),
391 )
392 .traced()
393 .execute(&mut *self.conn)
394 .await?;
395
396 DatabaseError::ensure_affected_rows(&res, 1)?;
397
398 user_registration.user_registration_token_id = Some(user_registration_token.id);
399
400 Ok(user_registration)
401 }
402
403 #[tracing::instrument(
404 name = "db.user_registration.set_upstream_oauth_authorization_session",
405 skip_all,
406 fields(
407 db.query.text,
408 %user_registration.id,
409 %upstream_oauth_authorization_session.id,
410 ),
411 err,
412 )]
413 async fn set_upstream_oauth_authorization_session(
414 &mut self,
415 mut user_registration: UserRegistration,
416 upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
417 ) -> Result<UserRegistration, Self::Error> {
418 let res = sqlx::query!(
419 r#"
420 UPDATE user_registrations
421 SET upstream_oauth_authorization_session_id = $2
422 WHERE user_registration_id = $1 AND completed_at IS NULL
423 "#,
424 Uuid::from(user_registration.id),
425 Uuid::from(upstream_oauth_authorization_session.id),
426 )
427 .traced()
428 .execute(&mut *self.conn)
429 .await?;
430
431 DatabaseError::ensure_affected_rows(&res, 1)?;
432
433 user_registration.upstream_oauth_authorization_session_id =
434 Some(upstream_oauth_authorization_session.id);
435
436 Ok(user_registration)
437 }
438
439 #[tracing::instrument(
440 name = "db.user_registration.complete",
441 skip_all,
442 fields(
443 db.query.text,
444 user_registration.id = %user_registration.id,
445 ),
446 err,
447 )]
448 async fn complete(
449 &mut self,
450 clock: &dyn Clock,
451 mut user_registration: UserRegistration,
452 ) -> Result<UserRegistration, Self::Error> {
453 let completed_at = clock.now();
454 let res = sqlx::query!(
455 r#"
456 UPDATE user_registrations
457 SET completed_at = $2
458 WHERE user_registration_id = $1 AND completed_at IS NULL
459 "#,
460 Uuid::from(user_registration.id),
461 completed_at,
462 )
463 .traced()
464 .execute(&mut *self.conn)
465 .await?;
466
467 DatabaseError::ensure_affected_rows(&res, 1)?;
468
469 user_registration.completed_at = Some(completed_at);
470
471 Ok(user_registration)
472 }
473
474 #[tracing::instrument(
475 name = "db.user_registration.cleanup",
476 skip_all,
477 fields(
478 db.query.text,
479 ),
480 err,
481 )]
482 async fn cleanup(
483 &mut self,
484 since: Option<Ulid>,
485 until: Ulid,
486 limit: usize,
487 ) -> Result<(usize, Option<Ulid>), Self::Error> {
488 let res = sqlx::query_scalar!(
493 r#"
494 WITH to_delete AS (
495 SELECT user_registration_id
496 FROM user_registrations
497 WHERE ($1::uuid IS NULL OR user_registration_id > $1)
498 AND user_registration_id <= $2
499 ORDER BY user_registration_id
500 LIMIT $3
501 )
502 DELETE FROM user_registrations
503 USING to_delete
504 WHERE user_registrations.user_registration_id = to_delete.user_registration_id
505 RETURNING user_registrations.user_registration_id
506 "#,
507 since.map(Uuid::from),
508 Uuid::from(until),
509 i64::try_from(limit).unwrap_or(i64::MAX)
510 )
511 .traced()
512 .fetch_all(&mut *self.conn)
513 .await?;
514
515 let count = res.len();
516 let max_id = res.into_iter().max();
517
518 Ok((count, max_id.map(Ulid::from)))
519 }
520}
521
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
        UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
    use oauth2_types::scope::Scope;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Creates a registration, looks it up, and checks it can be completed
    /// exactly once.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        // Complete the registration
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        // The completion is persisted
        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Completing an already completed registration fails
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    /// The optional user agent, IP address and post-auth action are persisted.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                Some("Mozilla/5.0".to_owned()),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    /// The display name can be set and replaced until the registration is
    /// completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Setting it again replaces the previous value
        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    /// The terms URL can be set and replaced until the registration is
    /// completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Setting it again replaces the previous value
        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    /// An email authentication can be linked (repeatedly) until the
    /// registration is completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Setting the same authentication again is accepted
        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    /// The password hash and version can be set and replaced until the
    /// registration is completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Setting it again replaces both the hash and the version
        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }

    /// An upstream OAuth authorization session can be linked (repeatedly)
    /// until the registration is completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_upstream_oauth_session(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.upstream_oauth_authorization_session_id, None);

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: Some("Example Ltd.".to_owned()),
                    brand_name: None,
                    scope: Scope::from_iter([oauth2_types::scope::OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    token_endpoint_signing_alg: None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    client_id: "client".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports: UpstreamOAuthProviderClaimsImports::default(),
                    authorization_endpoint_override: None,
                    token_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    jwks_uri_override: None,
                    discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    forward_login_hint: false,
                    ui_order: 0,
                    on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
                },
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await
            .unwrap();

        assert_eq!(
            registration.upstream_oauth_authorization_session_id,
            Some(session.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.upstream_oauth_authorization_session_id,
            registration.upstream_oauth_authorization_session_id
        );

        // Setting the same session again is accepted
        let registration = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await
            .unwrap();

        assert_eq!(
            registration.upstream_oauth_authorization_session_id,
            Some(session.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.upstream_oauth_authorization_session_id,
            registration.upstream_oauth_authorization_session_id
        );

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await;
        assert!(res.is_err());
    }

    /// `cleanup` honours the exclusive `since` / inclusive `until` bounds and
    /// reports the number of deleted rows and the highest deleted ID.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_cleanup(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        // `since` is exclusive: nothing lies strictly between the ID and itself
        let (count, max_id) = repo
            .user_registration()
            .cleanup(Some(registration.id), registration.id, 10)
            .await
            .unwrap();
        assert_eq!(count, 0);
        assert_eq!(max_id, None);

        // `until` is inclusive: the registration itself gets deleted
        let (count, max_id) = repo
            .user_registration()
            .cleanup(None, registration.id, 10)
            .await
            .unwrap();
        assert_eq!(count, 1);
        assert_eq!(max_id, Some(registration.id));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap();
        assert!(lookup.is_none());

        // Nothing left to delete
        let (count, max_id) = repo
            .user_registration()
            .cleanup(None, registration.id, 10)
            .await
            .unwrap();
        assert_eq!(count, 0);
        assert_eq!(max_id, None);
    }
}