1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 UserAgent, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
12};
13use mas_storage::{Clock, user::UserRegistrationRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use url::Url;
18use uuid::Uuid;
19
20use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
21
22pub struct PgUserRegistrationRepository<'c> {
25 conn: &'c mut PgConnection,
26}
27
28impl<'c> PgUserRegistrationRepository<'c> {
29 pub fn new(conn: &'c mut PgConnection) -> Self {
32 Self { conn }
33 }
34}
35
36struct UserRegistrationLookup {
37 user_registration_id: Uuid,
38 ip_address: Option<IpAddr>,
39 user_agent: Option<String>,
40 post_auth_action: Option<serde_json::Value>,
41 username: String,
42 display_name: Option<String>,
43 terms_url: Option<String>,
44 email_authentication_id: Option<Uuid>,
45 hashed_password: Option<String>,
46 hashed_password_version: Option<i32>,
47 created_at: DateTime<Utc>,
48 completed_at: Option<DateTime<Utc>>,
49}
50
51impl TryFrom<UserRegistrationLookup> for UserRegistration {
52 type Error = DatabaseInconsistencyError;
53
54 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
55 let id = Ulid::from(value.user_registration_id);
56 let user_agent = value.user_agent.map(UserAgent::parse);
57
58 let password = match (value.hashed_password, value.hashed_password_version) {
59 (Some(hashed_password), Some(version)) => {
60 let version = version.try_into().map_err(|e| {
61 DatabaseInconsistencyError::on("user_registrations")
62 .column("hashed_password_version")
63 .row(id)
64 .source(e)
65 })?;
66
67 Some(UserRegistrationPassword {
68 hashed_password,
69 version,
70 })
71 }
72 (None, None) => None,
73 _ => {
74 return Err(DatabaseInconsistencyError::on("user_registrations")
75 .column("hashed_password")
76 .row(id));
77 }
78 };
79
80 let terms_url = value
81 .terms_url
82 .map(|u| u.parse())
83 .transpose()
84 .map_err(|e| {
85 DatabaseInconsistencyError::on("user_registrations")
86 .column("terms_url")
87 .row(id)
88 .source(e)
89 })?;
90
91 Ok(UserRegistration {
92 id,
93 ip_address: value.ip_address,
94 user_agent,
95 post_auth_action: value.post_auth_action,
96 username: value.username,
97 display_name: value.display_name,
98 terms_url,
99 email_authentication_id: value.email_authentication_id.map(Ulid::from),
100 password,
101 created_at: value.created_at,
102 completed_at: value.completed_at,
103 })
104 }
105}
106
107#[async_trait]
108impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
109 type Error = DatabaseError;
110
111 #[tracing::instrument(
112 name = "db.user_registration.lookup",
113 skip_all,
114 fields(
115 db.query.text,
116 user_registration.id = %id,
117 ),
118 err,
119 )]
120 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
121 let res = sqlx::query_as!(
122 UserRegistrationLookup,
123 r#"
124 SELECT user_registration_id
125 , ip_address as "ip_address: IpAddr"
126 , user_agent
127 , post_auth_action
128 , username
129 , display_name
130 , terms_url
131 , email_authentication_id
132 , hashed_password
133 , hashed_password_version
134 , created_at
135 , completed_at
136 FROM user_registrations
137 WHERE user_registration_id = $1
138 "#,
139 Uuid::from(id),
140 )
141 .traced()
142 .fetch_optional(&mut *self.conn)
143 .await?;
144
145 let Some(res) = res else { return Ok(None) };
146
147 Ok(Some(res.try_into()?))
148 }
149
150 #[tracing::instrument(
151 name = "db.user_registration.add",
152 skip_all,
153 fields(
154 db.query.text,
155 user_registration.id,
156 ),
157 err,
158 )]
159 async fn add(
160 &mut self,
161 rng: &mut (dyn RngCore + Send),
162 clock: &dyn Clock,
163 username: String,
164 ip_address: Option<IpAddr>,
165 user_agent: Option<UserAgent>,
166 post_auth_action: Option<serde_json::Value>,
167 ) -> Result<UserRegistration, Self::Error> {
168 let created_at = clock.now();
169 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
170 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
171
172 sqlx::query!(
173 r#"
174 INSERT INTO user_registrations
175 ( user_registration_id
176 , ip_address
177 , user_agent
178 , post_auth_action
179 , username
180 , created_at
181 )
182 VALUES ($1, $2, $3, $4, $5, $6)
183 "#,
184 Uuid::from(id),
185 ip_address as Option<IpAddr>,
186 user_agent.as_deref(),
187 post_auth_action,
188 username,
189 created_at,
190 )
191 .traced()
192 .execute(&mut *self.conn)
193 .await?;
194
195 Ok(UserRegistration {
196 id,
197 ip_address,
198 user_agent,
199 post_auth_action,
200 created_at,
201 completed_at: None,
202 username,
203 display_name: None,
204 terms_url: None,
205 email_authentication_id: None,
206 password: None,
207 })
208 }
209
210 #[tracing::instrument(
211 name = "db.user_registration.set_display_name",
212 skip_all,
213 fields(
214 db.query.text,
215 user_registration.id = %user_registration.id,
216 user_registration.display_name = display_name,
217 ),
218 err,
219 )]
220 async fn set_display_name(
221 &mut self,
222 mut user_registration: UserRegistration,
223 display_name: String,
224 ) -> Result<UserRegistration, Self::Error> {
225 let res = sqlx::query!(
226 r#"
227 UPDATE user_registrations
228 SET display_name = $2
229 WHERE user_registration_id = $1 AND completed_at IS NULL
230 "#,
231 Uuid::from(user_registration.id),
232 display_name,
233 )
234 .traced()
235 .execute(&mut *self.conn)
236 .await?;
237
238 DatabaseError::ensure_affected_rows(&res, 1)?;
239
240 user_registration.display_name = Some(display_name);
241
242 Ok(user_registration)
243 }
244
245 #[tracing::instrument(
246 name = "db.user_registration.set_terms_url",
247 skip_all,
248 fields(
249 db.query.text,
250 user_registration.id = %user_registration.id,
251 user_registration.terms_url = %terms_url,
252 ),
253 err,
254 )]
255 async fn set_terms_url(
256 &mut self,
257 mut user_registration: UserRegistration,
258 terms_url: Url,
259 ) -> Result<UserRegistration, Self::Error> {
260 let res = sqlx::query!(
261 r#"
262 UPDATE user_registrations
263 SET terms_url = $2
264 WHERE user_registration_id = $1 AND completed_at IS NULL
265 "#,
266 Uuid::from(user_registration.id),
267 terms_url.as_str(),
268 )
269 .traced()
270 .execute(&mut *self.conn)
271 .await?;
272
273 DatabaseError::ensure_affected_rows(&res, 1)?;
274
275 user_registration.terms_url = Some(terms_url);
276
277 Ok(user_registration)
278 }
279
280 #[tracing::instrument(
281 name = "db.user_registration.set_email_authentication",
282 skip_all,
283 fields(
284 db.query.text,
285 %user_registration.id,
286 %user_email_authentication.id,
287 %user_email_authentication.email,
288 ),
289 err,
290 )]
291 async fn set_email_authentication(
292 &mut self,
293 mut user_registration: UserRegistration,
294 user_email_authentication: &UserEmailAuthentication,
295 ) -> Result<UserRegistration, Self::Error> {
296 let res = sqlx::query!(
297 r#"
298 UPDATE user_registrations
299 SET email_authentication_id = $2
300 WHERE user_registration_id = $1 AND completed_at IS NULL
301 "#,
302 Uuid::from(user_registration.id),
303 Uuid::from(user_email_authentication.id),
304 )
305 .traced()
306 .execute(&mut *self.conn)
307 .await?;
308
309 DatabaseError::ensure_affected_rows(&res, 1)?;
310
311 user_registration.email_authentication_id = Some(user_email_authentication.id);
312
313 Ok(user_registration)
314 }
315
316 #[tracing::instrument(
317 name = "db.user_registration.set_password",
318 skip_all,
319 fields(
320 db.query.text,
321 user_registration.id = %user_registration.id,
322 user_registration.hashed_password = hashed_password,
323 user_registration.hashed_password_version = version,
324 ),
325 err,
326 )]
327 async fn set_password(
328 &mut self,
329 mut user_registration: UserRegistration,
330 hashed_password: String,
331 version: u16,
332 ) -> Result<UserRegistration, Self::Error> {
333 let res = sqlx::query!(
334 r#"
335 UPDATE user_registrations
336 SET hashed_password = $2, hashed_password_version = $3
337 WHERE user_registration_id = $1 AND completed_at IS NULL
338 "#,
339 Uuid::from(user_registration.id),
340 hashed_password,
341 i32::from(version),
342 )
343 .traced()
344 .execute(&mut *self.conn)
345 .await?;
346
347 DatabaseError::ensure_affected_rows(&res, 1)?;
348
349 user_registration.password = Some(UserRegistrationPassword {
350 hashed_password,
351 version,
352 });
353
354 Ok(user_registration)
355 }
356
357 #[tracing::instrument(
358 name = "db.user_registration.complete",
359 skip_all,
360 fields(
361 db.query.text,
362 user_registration.id = %user_registration.id,
363 ),
364 err,
365 )]
366 async fn complete(
367 &mut self,
368 clock: &dyn Clock,
369 mut user_registration: UserRegistration,
370 ) -> Result<UserRegistration, Self::Error> {
371 let completed_at = clock.now();
372 let res = sqlx::query!(
373 r#"
374 UPDATE user_registrations
375 SET completed_at = $2
376 WHERE user_registration_id = $1 AND completed_at IS NULL
377 "#,
378 Uuid::from(user_registration.id),
379 completed_at,
380 )
381 .traced()
382 .execute(&mut *self.conn)
383 .await?;
384
385 DatabaseError::ensure_affected_rows(&res, 1)?;
386
387 user_registration.completed_at = Some(completed_at);
388
389 Ok(user_registration)
390 }
391}
392
#[cfg(test)]
mod tests {
    // Integration tests for `PgUserRegistrationRepository`, run against a
    // freshly migrated database by `sqlx::test`.

    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::{UserAgent, UserRegistrationPassword};
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        // Mark the registration as completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        // The completion timestamp is persisted
        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Completing an already-completed registration must fail
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
                Some(UserAgent::parse("Mozilla/5.0".to_owned())),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(
            registration.user_agent,
            Some(UserAgent::parse("Mozilla/5.0".to_owned()))
        );
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // The display name can be overwritten before completion
        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // The terms URL can be overwritten before completion
        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Setting the same authentication again is allowed before completion
        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // The password can be overwritten before completion
        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }
}