mas_storage_pg/user/
registration.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2025 New Vector Ltd.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
13    UserRegistrationPassword, UserRegistrationToken,
14};
15use mas_storage::user::UserRegistrationRepository;
16use rand::RngCore;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
23
24/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
25/// connection
26pub struct PgUserRegistrationRepository<'c> {
27    conn: &'c mut PgConnection,
28}
29
30impl<'c> PgUserRegistrationRepository<'c> {
31    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
32    /// connection
33    pub fn new(conn: &'c mut PgConnection) -> Self {
34        Self { conn }
35    }
36}
37
38struct UserRegistrationLookup {
39    user_registration_id: Uuid,
40    ip_address: Option<IpAddr>,
41    user_agent: Option<String>,
42    post_auth_action: Option<serde_json::Value>,
43    username: String,
44    display_name: Option<String>,
45    terms_url: Option<String>,
46    email_authentication_id: Option<Uuid>,
47    user_registration_token_id: Option<Uuid>,
48    hashed_password: Option<String>,
49    hashed_password_version: Option<i32>,
50    upstream_oauth_authorization_session_id: Option<Uuid>,
51    created_at: DateTime<Utc>,
52    completed_at: Option<DateTime<Utc>>,
53}
54
55impl TryFrom<UserRegistrationLookup> for UserRegistration {
56    type Error = DatabaseInconsistencyError;
57
58    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
59        let id = Ulid::from(value.user_registration_id);
60
61        let password = match (value.hashed_password, value.hashed_password_version) {
62            (Some(hashed_password), Some(version)) => {
63                let version = version.try_into().map_err(|e| {
64                    DatabaseInconsistencyError::on("user_registrations")
65                        .column("hashed_password_version")
66                        .row(id)
67                        .source(e)
68                })?;
69
70                Some(UserRegistrationPassword {
71                    hashed_password,
72                    version,
73                })
74            }
75            (None, None) => None,
76            _ => {
77                return Err(DatabaseInconsistencyError::on("user_registrations")
78                    .column("hashed_password")
79                    .row(id));
80            }
81        };
82
83        let terms_url = value
84            .terms_url
85            .map(|u| u.parse())
86            .transpose()
87            .map_err(|e| {
88                DatabaseInconsistencyError::on("user_registrations")
89                    .column("terms_url")
90                    .row(id)
91                    .source(e)
92            })?;
93
94        Ok(UserRegistration {
95            id,
96            ip_address: value.ip_address,
97            user_agent: value.user_agent,
98            post_auth_action: value.post_auth_action,
99            username: value.username,
100            display_name: value.display_name,
101            terms_url,
102            email_authentication_id: value.email_authentication_id.map(Ulid::from),
103            user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
104            password,
105            upstream_oauth_authorization_session_id: value
106                .upstream_oauth_authorization_session_id
107                .map(Ulid::from),
108            created_at: value.created_at,
109            completed_at: value.completed_at,
110        })
111    }
112}
113
114#[async_trait]
115impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
116    type Error = DatabaseError;
117
118    #[tracing::instrument(
119        name = "db.user_registration.lookup",
120        skip_all,
121        fields(
122            db.query.text,
123            user_registration.id = %id,
124        ),
125        err,
126    )]
127    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
128        let res = sqlx::query_as!(
129            UserRegistrationLookup,
130            r#"
131                SELECT user_registration_id
132                     , ip_address as "ip_address: IpAddr"
133                     , user_agent
134                     , post_auth_action
135                     , username
136                     , display_name
137                     , terms_url
138                     , email_authentication_id
139                     , user_registration_token_id
140                     , hashed_password
141                     , hashed_password_version
142                     , upstream_oauth_authorization_session_id
143                     , created_at
144                     , completed_at
145                FROM user_registrations
146                WHERE user_registration_id = $1
147            "#,
148            Uuid::from(id),
149        )
150        .traced()
151        .fetch_optional(&mut *self.conn)
152        .await?;
153
154        let Some(res) = res else { return Ok(None) };
155
156        Ok(Some(res.try_into()?))
157    }
158
159    #[tracing::instrument(
160        name = "db.user_registration.add",
161        skip_all,
162        fields(
163            db.query.text,
164            user_registration.id,
165        ),
166        err,
167    )]
168    async fn add(
169        &mut self,
170        rng: &mut (dyn RngCore + Send),
171        clock: &dyn Clock,
172        username: String,
173        ip_address: Option<IpAddr>,
174        user_agent: Option<String>,
175        post_auth_action: Option<serde_json::Value>,
176    ) -> Result<UserRegistration, Self::Error> {
177        let created_at = clock.now();
178        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
179        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
180
181        sqlx::query!(
182            r#"
183                INSERT INTO user_registrations
184                  ( user_registration_id
185                  , ip_address
186                  , user_agent
187                  , post_auth_action
188                  , username
189                  , created_at
190                  )
191                VALUES ($1, $2, $3, $4, $5, $6)
192            "#,
193            Uuid::from(id),
194            ip_address as Option<IpAddr>,
195            user_agent.as_deref(),
196            post_auth_action,
197            username,
198            created_at,
199        )
200        .traced()
201        .execute(&mut *self.conn)
202        .await?;
203
204        Ok(UserRegistration {
205            id,
206            ip_address,
207            user_agent,
208            post_auth_action,
209            created_at,
210            completed_at: None,
211            username,
212            display_name: None,
213            terms_url: None,
214            email_authentication_id: None,
215            user_registration_token_id: None,
216            password: None,
217            upstream_oauth_authorization_session_id: None,
218        })
219    }
220
221    #[tracing::instrument(
222        name = "db.user_registration.set_display_name",
223        skip_all,
224        fields(
225            db.query.text,
226            user_registration.id = %user_registration.id,
227            user_registration.display_name = display_name,
228        ),
229        err,
230    )]
231    async fn set_display_name(
232        &mut self,
233        mut user_registration: UserRegistration,
234        display_name: String,
235    ) -> Result<UserRegistration, Self::Error> {
236        let res = sqlx::query!(
237            r#"
238                UPDATE user_registrations
239                SET display_name = $2
240                WHERE user_registration_id = $1 AND completed_at IS NULL
241            "#,
242            Uuid::from(user_registration.id),
243            display_name,
244        )
245        .traced()
246        .execute(&mut *self.conn)
247        .await?;
248
249        DatabaseError::ensure_affected_rows(&res, 1)?;
250
251        user_registration.display_name = Some(display_name);
252
253        Ok(user_registration)
254    }
255
256    #[tracing::instrument(
257        name = "db.user_registration.set_terms_url",
258        skip_all,
259        fields(
260            db.query.text,
261            user_registration.id = %user_registration.id,
262            user_registration.terms_url = %terms_url,
263        ),
264        err,
265    )]
266    async fn set_terms_url(
267        &mut self,
268        mut user_registration: UserRegistration,
269        terms_url: Url,
270    ) -> Result<UserRegistration, Self::Error> {
271        let res = sqlx::query!(
272            r#"
273                UPDATE user_registrations
274                SET terms_url = $2
275                WHERE user_registration_id = $1 AND completed_at IS NULL
276            "#,
277            Uuid::from(user_registration.id),
278            terms_url.as_str(),
279        )
280        .traced()
281        .execute(&mut *self.conn)
282        .await?;
283
284        DatabaseError::ensure_affected_rows(&res, 1)?;
285
286        user_registration.terms_url = Some(terms_url);
287
288        Ok(user_registration)
289    }
290
291    #[tracing::instrument(
292        name = "db.user_registration.set_email_authentication",
293        skip_all,
294        fields(
295            db.query.text,
296            %user_registration.id,
297            %user_email_authentication.id,
298            %user_email_authentication.email,
299        ),
300        err,
301    )]
302    async fn set_email_authentication(
303        &mut self,
304        mut user_registration: UserRegistration,
305        user_email_authentication: &UserEmailAuthentication,
306    ) -> Result<UserRegistration, Self::Error> {
307        let res = sqlx::query!(
308            r#"
309                UPDATE user_registrations
310                SET email_authentication_id = $2
311                WHERE user_registration_id = $1 AND completed_at IS NULL
312            "#,
313            Uuid::from(user_registration.id),
314            Uuid::from(user_email_authentication.id),
315        )
316        .traced()
317        .execute(&mut *self.conn)
318        .await?;
319
320        DatabaseError::ensure_affected_rows(&res, 1)?;
321
322        user_registration.email_authentication_id = Some(user_email_authentication.id);
323
324        Ok(user_registration)
325    }
326
327    #[tracing::instrument(
328        name = "db.user_registration.set_password",
329        skip_all,
330        fields(
331            db.query.text,
332            user_registration.id = %user_registration.id,
333            user_registration.hashed_password = hashed_password,
334            user_registration.hashed_password_version = version,
335        ),
336        err,
337    )]
338    async fn set_password(
339        &mut self,
340        mut user_registration: UserRegistration,
341        hashed_password: String,
342        version: u16,
343    ) -> Result<UserRegistration, Self::Error> {
344        let res = sqlx::query!(
345            r#"
346                UPDATE user_registrations
347                SET hashed_password = $2, hashed_password_version = $3
348                WHERE user_registration_id = $1 AND completed_at IS NULL
349            "#,
350            Uuid::from(user_registration.id),
351            hashed_password,
352            i32::from(version),
353        )
354        .traced()
355        .execute(&mut *self.conn)
356        .await?;
357
358        DatabaseError::ensure_affected_rows(&res, 1)?;
359
360        user_registration.password = Some(UserRegistrationPassword {
361            hashed_password,
362            version,
363        });
364
365        Ok(user_registration)
366    }
367
368    #[tracing::instrument(
369        name = "db.user_registration.set_registration_token",
370        skip_all,
371        fields(
372            db.query.text,
373            %user_registration.id,
374            %user_registration_token.id,
375        ),
376        err,
377    )]
378    async fn set_registration_token(
379        &mut self,
380        mut user_registration: UserRegistration,
381        user_registration_token: &UserRegistrationToken,
382    ) -> Result<UserRegistration, Self::Error> {
383        let res = sqlx::query!(
384            r#"
385                UPDATE user_registrations
386                SET user_registration_token_id = $2
387                WHERE user_registration_id = $1 AND completed_at IS NULL
388            "#,
389            Uuid::from(user_registration.id),
390            Uuid::from(user_registration_token.id),
391        )
392        .traced()
393        .execute(&mut *self.conn)
394        .await?;
395
396        DatabaseError::ensure_affected_rows(&res, 1)?;
397
398        user_registration.user_registration_token_id = Some(user_registration_token.id);
399
400        Ok(user_registration)
401    }
402
403    #[tracing::instrument(
404        name = "db.user_registration.set_upstream_oauth_authorization_session",
405        skip_all,
406        fields(
407            db.query.text,
408            %user_registration.id,
409            %upstream_oauth_authorization_session.id,
410        ),
411        err,
412    )]
413    async fn set_upstream_oauth_authorization_session(
414        &mut self,
415        mut user_registration: UserRegistration,
416        upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
417    ) -> Result<UserRegistration, Self::Error> {
418        let res = sqlx::query!(
419            r#"
420                UPDATE user_registrations
421                SET upstream_oauth_authorization_session_id = $2
422                WHERE user_registration_id = $1 AND completed_at IS NULL
423            "#,
424            Uuid::from(user_registration.id),
425            Uuid::from(upstream_oauth_authorization_session.id),
426        )
427        .traced()
428        .execute(&mut *self.conn)
429        .await?;
430
431        DatabaseError::ensure_affected_rows(&res, 1)?;
432
433        user_registration.upstream_oauth_authorization_session_id =
434            Some(upstream_oauth_authorization_session.id);
435
436        Ok(user_registration)
437    }
438
439    #[tracing::instrument(
440        name = "db.user_registration.complete",
441        skip_all,
442        fields(
443            db.query.text,
444            user_registration.id = %user_registration.id,
445        ),
446        err,
447    )]
448    async fn complete(
449        &mut self,
450        clock: &dyn Clock,
451        mut user_registration: UserRegistration,
452    ) -> Result<UserRegistration, Self::Error> {
453        let completed_at = clock.now();
454        let res = sqlx::query!(
455            r#"
456                UPDATE user_registrations
457                SET completed_at = $2
458                WHERE user_registration_id = $1 AND completed_at IS NULL
459            "#,
460            Uuid::from(user_registration.id),
461            completed_at,
462        )
463        .traced()
464        .execute(&mut *self.conn)
465        .await?;
466
467        DatabaseError::ensure_affected_rows(&res, 1)?;
468
469        user_registration.completed_at = Some(completed_at);
470
471        Ok(user_registration)
472    }
473
474    #[tracing::instrument(
475        name = "db.user_registration.cleanup",
476        skip_all,
477        fields(
478            db.query.text,
479        ),
480        err,
481    )]
482    async fn cleanup(
483        &mut self,
484        since: Option<Ulid>,
485        until: Ulid,
486        limit: usize,
487    ) -> Result<(usize, Option<Ulid>), Self::Error> {
488        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
489        // deleted rows and do a MAX on the `user_registration_id`.
490        // Instead, we do the aggregation on the client side, which is a little
491        // less efficient, but good enough.
492        let res = sqlx::query_scalar!(
493            r#"
494                WITH to_delete AS (
495                    SELECT user_registration_id
496                    FROM user_registrations
497                    WHERE ($1::uuid IS NULL OR user_registration_id > $1)
498                    AND user_registration_id <= $2
499                    ORDER BY user_registration_id
500                    LIMIT $3
501                )
502                DELETE FROM user_registrations
503                USING to_delete
504                WHERE user_registrations.user_registration_id = to_delete.user_registration_id
505                RETURNING user_registrations.user_registration_id
506            "#,
507            since.map(Uuid::from),
508            Uuid::from(until),
509            i64::try_from(limit).unwrap_or(i64::MAX)
510        )
511        .traced()
512        .fetch_all(&mut *self.conn)
513        .await?;
514
515        let count = res.len();
516        let max_id = res.into_iter().max();
517
518        Ok((count, max_id.map(Ulid::from)))
519    }
520}
521
522#[cfg(test)]
523mod tests {
524    use std::net::{IpAddr, Ipv4Addr};
525
526    use mas_data_model::{
527        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
528        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
529        UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
530    };
531    use mas_iana::jose::JsonWebSignatureAlg;
532    use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
533    use oauth2_types::scope::Scope;
534    use rand::SeedableRng;
535    use rand_chacha::ChaChaRng;
536    use sqlx::PgPool;
537
538    use crate::PgRepository;
539
540    #[sqlx::test(migrator = "crate::MIGRATOR")]
541    async fn test_create_lookup_complete(pool: PgPool) {
542        let mut rng = ChaChaRng::seed_from_u64(42);
543        let clock = MockClock::default();
544
545        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
546
547        let registration = repo
548            .user_registration()
549            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
550            .await
551            .unwrap();
552
553        assert_eq!(registration.created_at, clock.now());
554        assert_eq!(registration.completed_at, None);
555        assert_eq!(registration.username, "alice");
556        assert_eq!(registration.display_name, None);
557        assert_eq!(registration.terms_url, None);
558        assert_eq!(registration.email_authentication_id, None);
559        assert_eq!(registration.password, None);
560        assert_eq!(registration.user_agent, None);
561        assert_eq!(registration.ip_address, None);
562        assert_eq!(registration.post_auth_action, None);
563
564        let lookup = repo
565            .user_registration()
566            .lookup(registration.id)
567            .await
568            .unwrap()
569            .unwrap();
570
571        assert_eq!(lookup.id, registration.id);
572        assert_eq!(lookup.created_at, registration.created_at);
573        assert_eq!(lookup.completed_at, registration.completed_at);
574        assert_eq!(lookup.username, registration.username);
575        assert_eq!(lookup.display_name, registration.display_name);
576        assert_eq!(lookup.terms_url, registration.terms_url);
577        assert_eq!(
578            lookup.email_authentication_id,
579            registration.email_authentication_id
580        );
581        assert_eq!(lookup.password, registration.password);
582        assert_eq!(lookup.user_agent, registration.user_agent);
583        assert_eq!(lookup.ip_address, registration.ip_address);
584        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
585
586        // Mark the registration as completed
587        let registration = repo
588            .user_registration()
589            .complete(&clock, registration)
590            .await
591            .unwrap();
592        assert_eq!(registration.completed_at, Some(clock.now()));
593
594        // Lookup the registration again
595        let lookup = repo
596            .user_registration()
597            .lookup(registration.id)
598            .await
599            .unwrap()
600            .unwrap();
601        assert_eq!(lookup.completed_at, registration.completed_at);
602
603        // Do it again, it should fail
604        let res = repo
605            .user_registration()
606            .complete(&clock, registration)
607            .await;
608        assert!(res.is_err());
609    }
610
611    #[sqlx::test(migrator = "crate::MIGRATOR")]
612    async fn test_create_useragent_ipaddress(pool: PgPool) {
613        let mut rng = ChaChaRng::seed_from_u64(42);
614        let clock = MockClock::default();
615
616        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
617
618        let registration = repo
619            .user_registration()
620            .add(
621                &mut rng,
622                &clock,
623                "alice".to_owned(),
624                Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
625                Some("Mozilla/5.0".to_owned()),
626                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
627            )
628            .await
629            .unwrap();
630
631        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
632        assert_eq!(
633            registration.ip_address,
634            Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
635        );
636        assert_eq!(
637            registration.post_auth_action,
638            Some(
639                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
640            )
641        );
642
643        let lookup = repo
644            .user_registration()
645            .lookup(registration.id)
646            .await
647            .unwrap()
648            .unwrap();
649
650        assert_eq!(lookup.user_agent, registration.user_agent);
651        assert_eq!(lookup.ip_address, registration.ip_address);
652        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
653    }
654
655    #[sqlx::test(migrator = "crate::MIGRATOR")]
656    async fn test_set_display_name(pool: PgPool) {
657        let mut rng = ChaChaRng::seed_from_u64(42);
658        let clock = MockClock::default();
659
660        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
661
662        let registration = repo
663            .user_registration()
664            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
665            .await
666            .unwrap();
667
668        assert_eq!(registration.display_name, None);
669
670        let registration = repo
671            .user_registration()
672            .set_display_name(registration, "Alice".to_owned())
673            .await
674            .unwrap();
675
676        assert_eq!(registration.display_name, Some("Alice".to_owned()));
677
678        let lookup = repo
679            .user_registration()
680            .lookup(registration.id)
681            .await
682            .unwrap()
683            .unwrap();
684
685        assert_eq!(lookup.display_name, registration.display_name);
686
687        // Setting it again should work
688        let registration = repo
689            .user_registration()
690            .set_display_name(registration, "Bob".to_owned())
691            .await
692            .unwrap();
693
694        assert_eq!(registration.display_name, Some("Bob".to_owned()));
695
696        let lookup = repo
697            .user_registration()
698            .lookup(registration.id)
699            .await
700            .unwrap()
701            .unwrap();
702
703        assert_eq!(lookup.display_name, registration.display_name);
704
705        // Can't set it once completed
706        let registration = repo
707            .user_registration()
708            .complete(&clock, registration)
709            .await
710            .unwrap();
711
712        let res = repo
713            .user_registration()
714            .set_display_name(registration, "Charlie".to_owned())
715            .await;
716        assert!(res.is_err());
717    }
718
719    #[sqlx::test(migrator = "crate::MIGRATOR")]
720    async fn test_set_terms_url(pool: PgPool) {
721        let mut rng = ChaChaRng::seed_from_u64(42);
722        let clock = MockClock::default();
723
724        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
725
726        let registration = repo
727            .user_registration()
728            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
729            .await
730            .unwrap();
731
732        assert_eq!(registration.terms_url, None);
733
734        let registration = repo
735            .user_registration()
736            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
737            .await
738            .unwrap();
739
740        assert_eq!(
741            registration.terms_url,
742            Some("https://example.com/terms".parse().unwrap())
743        );
744
745        let lookup = repo
746            .user_registration()
747            .lookup(registration.id)
748            .await
749            .unwrap()
750            .unwrap();
751
752        assert_eq!(lookup.terms_url, registration.terms_url);
753
754        // Setting it again should work
755        let registration = repo
756            .user_registration()
757            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
758            .await
759            .unwrap();
760
761        assert_eq!(
762            registration.terms_url,
763            Some("https://example.com/terms2".parse().unwrap())
764        );
765
766        let lookup = repo
767            .user_registration()
768            .lookup(registration.id)
769            .await
770            .unwrap()
771            .unwrap();
772
773        assert_eq!(lookup.terms_url, registration.terms_url);
774
775        // Can't set it once completed
776        let registration = repo
777            .user_registration()
778            .complete(&clock, registration)
779            .await
780            .unwrap();
781
782        let res = repo
783            .user_registration()
784            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
785            .await;
786        assert!(res.is_err());
787    }
788
789    #[sqlx::test(migrator = "crate::MIGRATOR")]
790    async fn test_set_email_authentication(pool: PgPool) {
791        let mut rng = ChaChaRng::seed_from_u64(42);
792        let clock = MockClock::default();
793
794        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
795
796        let registration = repo
797            .user_registration()
798            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
799            .await
800            .unwrap();
801
802        assert_eq!(registration.email_authentication_id, None);
803
804        let authentication = repo
805            .user_email()
806            .add_authentication_for_registration(
807                &mut rng,
808                &clock,
809                "alice@example.com".to_owned(),
810                &registration,
811            )
812            .await
813            .unwrap();
814
815        let registration = repo
816            .user_registration()
817            .set_email_authentication(registration, &authentication)
818            .await
819            .unwrap();
820
821        assert_eq!(
822            registration.email_authentication_id,
823            Some(authentication.id)
824        );
825
826        let lookup = repo
827            .user_registration()
828            .lookup(registration.id)
829            .await
830            .unwrap()
831            .unwrap();
832
833        assert_eq!(
834            lookup.email_authentication_id,
835            registration.email_authentication_id
836        );
837
838        // Setting it again should work
839        let registration = repo
840            .user_registration()
841            .set_email_authentication(registration, &authentication)
842            .await
843            .unwrap();
844
845        assert_eq!(
846            registration.email_authentication_id,
847            Some(authentication.id)
848        );
849
850        let lookup = repo
851            .user_registration()
852            .lookup(registration.id)
853            .await
854            .unwrap()
855            .unwrap();
856
857        assert_eq!(
858            lookup.email_authentication_id,
859            registration.email_authentication_id
860        );
861
862        // Can't set it once completed
863        let registration = repo
864            .user_registration()
865            .complete(&clock, registration)
866            .await
867            .unwrap();
868
869        let res = repo
870            .user_registration()
871            .set_email_authentication(registration, &authentication)
872            .await;
873        assert!(res.is_err());
874    }
875
876    #[sqlx::test(migrator = "crate::MIGRATOR")]
877    async fn test_set_password(pool: PgPool) {
878        let mut rng = ChaChaRng::seed_from_u64(42);
879        let clock = MockClock::default();
880
881        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
882
883        let registration = repo
884            .user_registration()
885            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
886            .await
887            .unwrap();
888
889        assert_eq!(registration.password, None);
890
891        let registration = repo
892            .user_registration()
893            .set_password(registration, "fakehashedpassword".to_owned(), 1)
894            .await
895            .unwrap();
896
897        assert_eq!(
898            registration.password,
899            Some(UserRegistrationPassword {
900                hashed_password: "fakehashedpassword".to_owned(),
901                version: 1,
902            })
903        );
904
905        let lookup = repo
906            .user_registration()
907            .lookup(registration.id)
908            .await
909            .unwrap()
910            .unwrap();
911
912        assert_eq!(lookup.password, registration.password);
913
914        // Setting it again should work
915        let registration = repo
916            .user_registration()
917            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
918            .await
919            .unwrap();
920
921        assert_eq!(
922            registration.password,
923            Some(UserRegistrationPassword {
924                hashed_password: "fakehashedpassword2".to_owned(),
925                version: 2,
926            })
927        );
928
929        let lookup = repo
930            .user_registration()
931            .lookup(registration.id)
932            .await
933            .unwrap()
934            .unwrap();
935
936        assert_eq!(lookup.password, registration.password);
937
938        // Can't set it once completed
939        let registration = repo
940            .user_registration()
941            .complete(&clock, registration)
942            .await
943            .unwrap();
944
945        let res = repo
946            .user_registration()
947            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
948            .await;
949        assert!(res.is_err());
950    }
951
952    #[sqlx::test(migrator = "crate::MIGRATOR")]
953    async fn test_set_upstream_oauth_session(pool: PgPool) {
954        let mut rng = ChaChaRng::seed_from_u64(42);
955        let clock = MockClock::default();
956
957        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
958
959        let registration = repo
960            .user_registration()
961            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
962            .await
963            .unwrap();
964
965        assert_eq!(registration.upstream_oauth_authorization_session_id, None);
966
967        let provider = repo
968            .upstream_oauth_provider()
969            .add(
970                &mut rng,
971                &clock,
972                UpstreamOAuthProviderParams {
973                    issuer: Some("https://example.com/".to_owned()),
974                    human_name: Some("Example Ltd.".to_owned()),
975                    brand_name: None,
976                    scope: Scope::from_iter([oauth2_types::scope::OPENID]),
977                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
978                    token_endpoint_signing_alg: None,
979                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
980                    client_id: "client".to_owned(),
981                    encrypted_client_secret: None,
982                    claims_imports: UpstreamOAuthProviderClaimsImports::default(),
983                    authorization_endpoint_override: None,
984                    token_endpoint_override: None,
985                    userinfo_endpoint_override: None,
986                    fetch_userinfo: false,
987                    userinfo_signed_response_alg: None,
988                    jwks_uri_override: None,
989                    discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
990                    pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
991                    response_mode: None,
992                    additional_authorization_parameters: Vec::new(),
993                    forward_login_hint: false,
994                    ui_order: 0,
995                    on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
996                },
997            )
998            .await
999            .unwrap();
1000
1001        let session = repo
1002            .upstream_oauth_session()
1003            .add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
1004            .await
1005            .unwrap();
1006
1007        let registration = repo
1008            .user_registration()
1009            .set_upstream_oauth_authorization_session(registration, &session)
1010            .await
1011            .unwrap();
1012
1013        assert_eq!(
1014            registration.upstream_oauth_authorization_session_id,
1015            Some(session.id)
1016        );
1017
1018        let lookup = repo
1019            .user_registration()
1020            .lookup(registration.id)
1021            .await
1022            .unwrap()
1023            .unwrap();
1024
1025        assert_eq!(
1026            lookup.upstream_oauth_authorization_session_id,
1027            registration.upstream_oauth_authorization_session_id
1028        );
1029
1030        // Setting it again should work
1031        let registration = repo
1032            .user_registration()
1033            .set_upstream_oauth_authorization_session(registration, &session)
1034            .await
1035            .unwrap();
1036
1037        assert_eq!(
1038            registration.upstream_oauth_authorization_session_id,
1039            Some(session.id)
1040        );
1041
1042        let lookup = repo
1043            .user_registration()
1044            .lookup(registration.id)
1045            .await
1046            .unwrap()
1047            .unwrap();
1048
1049        assert_eq!(
1050            lookup.upstream_oauth_authorization_session_id,
1051            registration.upstream_oauth_authorization_session_id
1052        );
1053
1054        // Can't set it once completed
1055        let registration = repo
1056            .user_registration()
1057            .complete(&clock, registration)
1058            .await
1059            .unwrap();
1060
1061        let res = repo
1062            .user_registration()
1063            .set_upstream_oauth_authorization_session(registration, &session)
1064            .await;
1065        assert!(res.is_err());
1066    }
1067}