mas_storage_pg/user/
registration.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    UserAgent, UserEmailAuthentication, UserRegistration, UserRegistrationPassword,
12};
13use mas_storage::{Clock, user::UserRegistrationRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use url::Url;
18use uuid::Uuid;
19
20use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
21
22/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
23/// connection
24pub struct PgUserRegistrationRepository<'c> {
25    conn: &'c mut PgConnection,
26}
27
28impl<'c> PgUserRegistrationRepository<'c> {
29    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
30    /// connection
31    pub fn new(conn: &'c mut PgConnection) -> Self {
32        Self { conn }
33    }
34}
35
36struct UserRegistrationLookup {
37    user_registration_id: Uuid,
38    ip_address: Option<IpAddr>,
39    user_agent: Option<String>,
40    post_auth_action: Option<serde_json::Value>,
41    username: String,
42    display_name: Option<String>,
43    terms_url: Option<String>,
44    email_authentication_id: Option<Uuid>,
45    hashed_password: Option<String>,
46    hashed_password_version: Option<i32>,
47    created_at: DateTime<Utc>,
48    completed_at: Option<DateTime<Utc>>,
49}
50
51impl TryFrom<UserRegistrationLookup> for UserRegistration {
52    type Error = DatabaseInconsistencyError;
53
54    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
55        let id = Ulid::from(value.user_registration_id);
56        let user_agent = value.user_agent.map(UserAgent::parse);
57
58        let password = match (value.hashed_password, value.hashed_password_version) {
59            (Some(hashed_password), Some(version)) => {
60                let version = version.try_into().map_err(|e| {
61                    DatabaseInconsistencyError::on("user_registrations")
62                        .column("hashed_password_version")
63                        .row(id)
64                        .source(e)
65                })?;
66
67                Some(UserRegistrationPassword {
68                    hashed_password,
69                    version,
70                })
71            }
72            (None, None) => None,
73            _ => {
74                return Err(DatabaseInconsistencyError::on("user_registrations")
75                    .column("hashed_password")
76                    .row(id));
77            }
78        };
79
80        let terms_url = value
81            .terms_url
82            .map(|u| u.parse())
83            .transpose()
84            .map_err(|e| {
85                DatabaseInconsistencyError::on("user_registrations")
86                    .column("terms_url")
87                    .row(id)
88                    .source(e)
89            })?;
90
91        Ok(UserRegistration {
92            id,
93            ip_address: value.ip_address,
94            user_agent,
95            post_auth_action: value.post_auth_action,
96            username: value.username,
97            display_name: value.display_name,
98            terms_url,
99            email_authentication_id: value.email_authentication_id.map(Ulid::from),
100            password,
101            created_at: value.created_at,
102            completed_at: value.completed_at,
103        })
104    }
105}
106
107#[async_trait]
108impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
109    type Error = DatabaseError;
110
111    #[tracing::instrument(
112        name = "db.user_registration.lookup",
113        skip_all,
114        fields(
115            db.query.text,
116            user_registration.id = %id,
117        ),
118        err,
119    )]
120    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
121        let res = sqlx::query_as!(
122            UserRegistrationLookup,
123            r#"
124                SELECT user_registration_id
125                     , ip_address as "ip_address: IpAddr"
126                     , user_agent
127                     , post_auth_action
128                     , username
129                     , display_name
130                     , terms_url
131                     , email_authentication_id
132                     , hashed_password
133                     , hashed_password_version
134                     , created_at
135                     , completed_at
136                FROM user_registrations
137                WHERE user_registration_id = $1
138            "#,
139            Uuid::from(id),
140        )
141        .traced()
142        .fetch_optional(&mut *self.conn)
143        .await?;
144
145        let Some(res) = res else { return Ok(None) };
146
147        Ok(Some(res.try_into()?))
148    }
149
150    #[tracing::instrument(
151        name = "db.user_registration.add",
152        skip_all,
153        fields(
154            db.query.text,
155            user_registration.id,
156        ),
157        err,
158    )]
159    async fn add(
160        &mut self,
161        rng: &mut (dyn RngCore + Send),
162        clock: &dyn Clock,
163        username: String,
164        ip_address: Option<IpAddr>,
165        user_agent: Option<UserAgent>,
166        post_auth_action: Option<serde_json::Value>,
167    ) -> Result<UserRegistration, Self::Error> {
168        let created_at = clock.now();
169        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
170        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
171
172        sqlx::query!(
173            r#"
174                INSERT INTO user_registrations
175                  ( user_registration_id
176                  , ip_address
177                  , user_agent
178                  , post_auth_action
179                  , username
180                  , created_at
181                  )
182                VALUES ($1, $2, $3, $4, $5, $6)
183            "#,
184            Uuid::from(id),
185            ip_address as Option<IpAddr>,
186            user_agent.as_deref(),
187            post_auth_action,
188            username,
189            created_at,
190        )
191        .traced()
192        .execute(&mut *self.conn)
193        .await?;
194
195        Ok(UserRegistration {
196            id,
197            ip_address,
198            user_agent,
199            post_auth_action,
200            created_at,
201            completed_at: None,
202            username,
203            display_name: None,
204            terms_url: None,
205            email_authentication_id: None,
206            password: None,
207        })
208    }
209
210    #[tracing::instrument(
211        name = "db.user_registration.set_display_name",
212        skip_all,
213        fields(
214            db.query.text,
215            user_registration.id = %user_registration.id,
216            user_registration.display_name = display_name,
217        ),
218        err,
219    )]
220    async fn set_display_name(
221        &mut self,
222        mut user_registration: UserRegistration,
223        display_name: String,
224    ) -> Result<UserRegistration, Self::Error> {
225        let res = sqlx::query!(
226            r#"
227                UPDATE user_registrations
228                SET display_name = $2
229                WHERE user_registration_id = $1 AND completed_at IS NULL
230            "#,
231            Uuid::from(user_registration.id),
232            display_name,
233        )
234        .traced()
235        .execute(&mut *self.conn)
236        .await?;
237
238        DatabaseError::ensure_affected_rows(&res, 1)?;
239
240        user_registration.display_name = Some(display_name);
241
242        Ok(user_registration)
243    }
244
245    #[tracing::instrument(
246        name = "db.user_registration.set_terms_url",
247        skip_all,
248        fields(
249            db.query.text,
250            user_registration.id = %user_registration.id,
251            user_registration.terms_url = %terms_url,
252        ),
253        err,
254    )]
255    async fn set_terms_url(
256        &mut self,
257        mut user_registration: UserRegistration,
258        terms_url: Url,
259    ) -> Result<UserRegistration, Self::Error> {
260        let res = sqlx::query!(
261            r#"
262                UPDATE user_registrations
263                SET terms_url = $2
264                WHERE user_registration_id = $1 AND completed_at IS NULL
265            "#,
266            Uuid::from(user_registration.id),
267            terms_url.as_str(),
268        )
269        .traced()
270        .execute(&mut *self.conn)
271        .await?;
272
273        DatabaseError::ensure_affected_rows(&res, 1)?;
274
275        user_registration.terms_url = Some(terms_url);
276
277        Ok(user_registration)
278    }
279
280    #[tracing::instrument(
281        name = "db.user_registration.set_email_authentication",
282        skip_all,
283        fields(
284            db.query.text,
285            %user_registration.id,
286            %user_email_authentication.id,
287            %user_email_authentication.email,
288        ),
289        err,
290    )]
291    async fn set_email_authentication(
292        &mut self,
293        mut user_registration: UserRegistration,
294        user_email_authentication: &UserEmailAuthentication,
295    ) -> Result<UserRegistration, Self::Error> {
296        let res = sqlx::query!(
297            r#"
298                UPDATE user_registrations
299                SET email_authentication_id = $2
300                WHERE user_registration_id = $1 AND completed_at IS NULL
301            "#,
302            Uuid::from(user_registration.id),
303            Uuid::from(user_email_authentication.id),
304        )
305        .traced()
306        .execute(&mut *self.conn)
307        .await?;
308
309        DatabaseError::ensure_affected_rows(&res, 1)?;
310
311        user_registration.email_authentication_id = Some(user_email_authentication.id);
312
313        Ok(user_registration)
314    }
315
316    #[tracing::instrument(
317        name = "db.user_registration.set_password",
318        skip_all,
319        fields(
320            db.query.text,
321            user_registration.id = %user_registration.id,
322            user_registration.hashed_password = hashed_password,
323            user_registration.hashed_password_version = version,
324        ),
325        err,
326    )]
327    async fn set_password(
328        &mut self,
329        mut user_registration: UserRegistration,
330        hashed_password: String,
331        version: u16,
332    ) -> Result<UserRegistration, Self::Error> {
333        let res = sqlx::query!(
334            r#"
335                UPDATE user_registrations
336                SET hashed_password = $2, hashed_password_version = $3
337                WHERE user_registration_id = $1 AND completed_at IS NULL
338            "#,
339            Uuid::from(user_registration.id),
340            hashed_password,
341            i32::from(version),
342        )
343        .traced()
344        .execute(&mut *self.conn)
345        .await?;
346
347        DatabaseError::ensure_affected_rows(&res, 1)?;
348
349        user_registration.password = Some(UserRegistrationPassword {
350            hashed_password,
351            version,
352        });
353
354        Ok(user_registration)
355    }
356
357    #[tracing::instrument(
358        name = "db.user_registration.complete",
359        skip_all,
360        fields(
361            db.query.text,
362            user_registration.id = %user_registration.id,
363        ),
364        err,
365    )]
366    async fn complete(
367        &mut self,
368        clock: &dyn Clock,
369        mut user_registration: UserRegistration,
370    ) -> Result<UserRegistration, Self::Error> {
371        let completed_at = clock.now();
372        let res = sqlx::query!(
373            r#"
374                UPDATE user_registrations
375                SET completed_at = $2
376                WHERE user_registration_id = $1 AND completed_at IS NULL
377            "#,
378            Uuid::from(user_registration.id),
379            completed_at,
380        )
381        .traced()
382        .execute(&mut *self.conn)
383        .await?;
384
385        DatabaseError::ensure_affected_rows(&res, 1)?;
386
387        user_registration.completed_at = Some(completed_at);
388
389        Ok(user_registration)
390    }
391}
392
393#[cfg(test)]
394mod tests {
395    use std::net::{IpAddr, Ipv4Addr};
396
397    use mas_data_model::{UserAgent, UserRegistrationPassword};
398    use mas_storage::{Clock, clock::MockClock};
399    use rand::SeedableRng;
400    use rand_chacha::ChaChaRng;
401    use sqlx::PgPool;
402
403    use crate::PgRepository;
404
405    #[sqlx::test(migrator = "crate::MIGRATOR")]
406    async fn test_create_lookup_complete(pool: PgPool) {
407        let mut rng = ChaChaRng::seed_from_u64(42);
408        let clock = MockClock::default();
409
410        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
411
412        let registration = repo
413            .user_registration()
414            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
415            .await
416            .unwrap();
417
418        assert_eq!(registration.created_at, clock.now());
419        assert_eq!(registration.completed_at, None);
420        assert_eq!(registration.username, "alice");
421        assert_eq!(registration.display_name, None);
422        assert_eq!(registration.terms_url, None);
423        assert_eq!(registration.email_authentication_id, None);
424        assert_eq!(registration.password, None);
425        assert_eq!(registration.user_agent, None);
426        assert_eq!(registration.ip_address, None);
427        assert_eq!(registration.post_auth_action, None);
428
429        let lookup = repo
430            .user_registration()
431            .lookup(registration.id)
432            .await
433            .unwrap()
434            .unwrap();
435
436        assert_eq!(lookup.id, registration.id);
437        assert_eq!(lookup.created_at, registration.created_at);
438        assert_eq!(lookup.completed_at, registration.completed_at);
439        assert_eq!(lookup.username, registration.username);
440        assert_eq!(lookup.display_name, registration.display_name);
441        assert_eq!(lookup.terms_url, registration.terms_url);
442        assert_eq!(
443            lookup.email_authentication_id,
444            registration.email_authentication_id
445        );
446        assert_eq!(lookup.password, registration.password);
447        assert_eq!(lookup.user_agent, registration.user_agent);
448        assert_eq!(lookup.ip_address, registration.ip_address);
449        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
450
451        // Mark the registration as completed
452        let registration = repo
453            .user_registration()
454            .complete(&clock, registration)
455            .await
456            .unwrap();
457        assert_eq!(registration.completed_at, Some(clock.now()));
458
459        // Lookup the registration again
460        let lookup = repo
461            .user_registration()
462            .lookup(registration.id)
463            .await
464            .unwrap()
465            .unwrap();
466        assert_eq!(lookup.completed_at, registration.completed_at);
467
468        // Do it again, it should fail
469        let res = repo
470            .user_registration()
471            .complete(&clock, registration)
472            .await;
473        assert!(res.is_err());
474    }
475
476    #[sqlx::test(migrator = "crate::MIGRATOR")]
477    async fn test_create_useragent_ipaddress(pool: PgPool) {
478        let mut rng = ChaChaRng::seed_from_u64(42);
479        let clock = MockClock::default();
480
481        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
482
483        let registration = repo
484            .user_registration()
485            .add(
486                &mut rng,
487                &clock,
488                "alice".to_owned(),
489                Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
490                Some(UserAgent::parse("Mozilla/5.0".to_owned())),
491                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
492            )
493            .await
494            .unwrap();
495
496        assert_eq!(
497            registration.user_agent,
498            Some(UserAgent::parse("Mozilla/5.0".to_owned()))
499        );
500        assert_eq!(
501            registration.ip_address,
502            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
503        );
504        assert_eq!(
505            registration.post_auth_action,
506            Some(
507                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
508            )
509        );
510
511        let lookup = repo
512            .user_registration()
513            .lookup(registration.id)
514            .await
515            .unwrap()
516            .unwrap();
517
518        assert_eq!(lookup.user_agent, registration.user_agent);
519        assert_eq!(lookup.ip_address, registration.ip_address);
520        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
521    }
522
523    #[sqlx::test(migrator = "crate::MIGRATOR")]
524    async fn test_set_display_name(pool: PgPool) {
525        let mut rng = ChaChaRng::seed_from_u64(42);
526        let clock = MockClock::default();
527
528        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
529
530        let registration = repo
531            .user_registration()
532            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
533            .await
534            .unwrap();
535
536        assert_eq!(registration.display_name, None);
537
538        let registration = repo
539            .user_registration()
540            .set_display_name(registration, "Alice".to_owned())
541            .await
542            .unwrap();
543
544        assert_eq!(registration.display_name, Some("Alice".to_owned()));
545
546        let lookup = repo
547            .user_registration()
548            .lookup(registration.id)
549            .await
550            .unwrap()
551            .unwrap();
552
553        assert_eq!(lookup.display_name, registration.display_name);
554
555        // Setting it again should work
556        let registration = repo
557            .user_registration()
558            .set_display_name(registration, "Bob".to_owned())
559            .await
560            .unwrap();
561
562        assert_eq!(registration.display_name, Some("Bob".to_owned()));
563
564        let lookup = repo
565            .user_registration()
566            .lookup(registration.id)
567            .await
568            .unwrap()
569            .unwrap();
570
571        assert_eq!(lookup.display_name, registration.display_name);
572
573        // Can't set it once completed
574        let registration = repo
575            .user_registration()
576            .complete(&clock, registration)
577            .await
578            .unwrap();
579
580        let res = repo
581            .user_registration()
582            .set_display_name(registration, "Charlie".to_owned())
583            .await;
584        assert!(res.is_err());
585    }
586
587    #[sqlx::test(migrator = "crate::MIGRATOR")]
588    async fn test_set_terms_url(pool: PgPool) {
589        let mut rng = ChaChaRng::seed_from_u64(42);
590        let clock = MockClock::default();
591
592        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
593
594        let registration = repo
595            .user_registration()
596            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
597            .await
598            .unwrap();
599
600        assert_eq!(registration.terms_url, None);
601
602        let registration = repo
603            .user_registration()
604            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
605            .await
606            .unwrap();
607
608        assert_eq!(
609            registration.terms_url,
610            Some("https://example.com/terms".parse().unwrap())
611        );
612
613        let lookup = repo
614            .user_registration()
615            .lookup(registration.id)
616            .await
617            .unwrap()
618            .unwrap();
619
620        assert_eq!(lookup.terms_url, registration.terms_url);
621
622        // Setting it again should work
623        let registration = repo
624            .user_registration()
625            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
626            .await
627            .unwrap();
628
629        assert_eq!(
630            registration.terms_url,
631            Some("https://example.com/terms2".parse().unwrap())
632        );
633
634        let lookup = repo
635            .user_registration()
636            .lookup(registration.id)
637            .await
638            .unwrap()
639            .unwrap();
640
641        assert_eq!(lookup.terms_url, registration.terms_url);
642
643        // Can't set it once completed
644        let registration = repo
645            .user_registration()
646            .complete(&clock, registration)
647            .await
648            .unwrap();
649
650        let res = repo
651            .user_registration()
652            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
653            .await;
654        assert!(res.is_err());
655    }
656
657    #[sqlx::test(migrator = "crate::MIGRATOR")]
658    async fn test_set_email_authentication(pool: PgPool) {
659        let mut rng = ChaChaRng::seed_from_u64(42);
660        let clock = MockClock::default();
661
662        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
663
664        let registration = repo
665            .user_registration()
666            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
667            .await
668            .unwrap();
669
670        assert_eq!(registration.email_authentication_id, None);
671
672        let authentication = repo
673            .user_email()
674            .add_authentication_for_registration(
675                &mut rng,
676                &clock,
677                "alice@example.com".to_owned(),
678                &registration,
679            )
680            .await
681            .unwrap();
682
683        let registration = repo
684            .user_registration()
685            .set_email_authentication(registration, &authentication)
686            .await
687            .unwrap();
688
689        assert_eq!(
690            registration.email_authentication_id,
691            Some(authentication.id)
692        );
693
694        let lookup = repo
695            .user_registration()
696            .lookup(registration.id)
697            .await
698            .unwrap()
699            .unwrap();
700
701        assert_eq!(
702            lookup.email_authentication_id,
703            registration.email_authentication_id
704        );
705
706        // Setting it again should work
707        let registration = repo
708            .user_registration()
709            .set_email_authentication(registration, &authentication)
710            .await
711            .unwrap();
712
713        assert_eq!(
714            registration.email_authentication_id,
715            Some(authentication.id)
716        );
717
718        let lookup = repo
719            .user_registration()
720            .lookup(registration.id)
721            .await
722            .unwrap()
723            .unwrap();
724
725        assert_eq!(
726            lookup.email_authentication_id,
727            registration.email_authentication_id
728        );
729
730        // Can't set it once completed
731        let registration = repo
732            .user_registration()
733            .complete(&clock, registration)
734            .await
735            .unwrap();
736
737        let res = repo
738            .user_registration()
739            .set_email_authentication(registration, &authentication)
740            .await;
741        assert!(res.is_err());
742    }
743
744    #[sqlx::test(migrator = "crate::MIGRATOR")]
745    async fn test_set_password(pool: PgPool) {
746        let mut rng = ChaChaRng::seed_from_u64(42);
747        let clock = MockClock::default();
748
749        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
750
751        let registration = repo
752            .user_registration()
753            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
754            .await
755            .unwrap();
756
757        assert_eq!(registration.password, None);
758
759        let registration = repo
760            .user_registration()
761            .set_password(registration, "fakehashedpassword".to_owned(), 1)
762            .await
763            .unwrap();
764
765        assert_eq!(
766            registration.password,
767            Some(UserRegistrationPassword {
768                hashed_password: "fakehashedpassword".to_owned(),
769                version: 1,
770            })
771        );
772
773        let lookup = repo
774            .user_registration()
775            .lookup(registration.id)
776            .await
777            .unwrap()
778            .unwrap();
779
780        assert_eq!(lookup.password, registration.password);
781
782        // Setting it again should work
783        let registration = repo
784            .user_registration()
785            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
786            .await
787            .unwrap();
788
789        assert_eq!(
790            registration.password,
791            Some(UserRegistrationPassword {
792                hashed_password: "fakehashedpassword2".to_owned(),
793                version: 2,
794            })
795        );
796
797        let lookup = repo
798            .user_registration()
799            .lookup(registration.id)
800            .await
801            .unwrap()
802            .unwrap();
803
804        assert_eq!(lookup.password, registration.password);
805
806        // Can't set it once completed
807        let registration = repo
808            .user_registration()
809            .complete(&clock, registration)
810            .await
811            .unwrap();
812
813        let res = repo
814            .user_registration()
815            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
816            .await;
817        assert!(res.is_err());
818    }
819}