mas_handlers/upstream_oauth2/
link.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::Arc;
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17    FancyError, SessionInfoExt,
18    cookies::CookieJar,
19    csrf::{CsrfExt, ProtectedForm},
20    sentry::SentryEventID,
21};
22use mas_data_model::UserAgent;
23use mas_jose::jwt::Jwt;
24use mas_matrix::HomeserverConnection;
25use mas_policy::Policy;
26use mas_router::UrlBuilder;
27use mas_storage::{
28    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
29    queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
30    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
31    user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
32};
33use mas_templates::{
34    AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
35    ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
36};
37use minijinja::Environment;
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44    UpstreamSessionsCookie,
45    template::{AttributeMappingContext, environment},
46};
47use crate::{
48    BoundActivityTracker, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49    views::shared::OptionalPostAuthAction,
50};
51
/// Template used to derive the localpart when the provider's `localpart`
/// claim import does not define one.
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
/// Template used to derive the display name when the provider's
/// `displayname` claim import does not define one.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Template used to derive the email address when the provider's `email`
/// claim import does not define one.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
55
56#[derive(Debug, Error)]
57pub(crate) enum RouteError {
58    /// Couldn't find the link specified in the URL
59    #[error("Link not found")]
60    LinkNotFound,
61
62    /// Couldn't find the session on the link
63    #[error("Session not found")]
64    SessionNotFound,
65
66    /// Couldn't find the user
67    #[error("User not found")]
68    UserNotFound,
69
70    /// Couldn't find upstream provider
71    #[error("Upstream provider not found")]
72    ProviderNotFound,
73
74    /// Required attribute rendered to an empty string
75    #[error("Template {template:?} rendered to an empty string")]
76    RequiredAttributeEmpty { template: String },
77
78    /// Required claim was missing in `id_token`
79    #[error(
80        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
81    )]
82    RequiredAttributeRender {
83        template: String,
84
85        #[source]
86        source: minijinja::Error,
87    },
88
89    /// Session was already consumed
90    #[error("Session already consumed")]
91    SessionConsumed,
92
93    #[error("Missing session cookie")]
94    MissingCookie,
95
96    #[error("Invalid form action")]
97    InvalidFormAction,
98
99    #[error("Homeserver connection error")]
100    HomeserverConnection(#[source] anyhow::Error),
101
102    #[error(transparent)]
103    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
104}
105
106impl_from_error_for_route!(mas_templates::TemplateError);
107impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
108impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
109impl_from_error_for_route!(mas_storage::RepositoryError);
110impl_from_error_for_route!(mas_policy::EvaluationError);
111impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
112
113impl IntoResponse for RouteError {
114    fn into_response(self) -> axum::response::Response {
115        let event_id = sentry::capture_error(&self);
116        let response = match self {
117            Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
118            Self::Internal(e) => FancyError::from(e).into_response(),
119            e => FancyError::from(e).into_response(),
120        };
121
122        (SentryEventID::from(event_id), response).into_response()
123    }
124}
125
126/// Utility function to render an attribute template.
127///
128/// # Parameters
129///
130/// * `environment` - The minijinja environment to use to render the template
131/// * `template` - The template to use to render the claim
132/// * `required` - Whether the attribute is required or not
133///
134/// # Errors
135///
136/// Returns an error if the attribute is required but fails to render or is
137/// empty
138fn render_attribute_template(
139    environment: &Environment,
140    template: &str,
141    context: &minijinja::Value,
142    required: bool,
143) -> Result<Option<String>, RouteError> {
144    match environment.render_str(template, context) {
145        Ok(value) if value.is_empty() => {
146            if required {
147                return Err(RouteError::RequiredAttributeEmpty {
148                    template: template.to_owned(),
149                });
150            }
151
152            Ok(None)
153        }
154
155        Ok(value) => Ok(Some(value)),
156
157        Err(source) => {
158            if required {
159                return Err(RouteError::RequiredAttributeRender {
160                    template: template.to_owned(),
161                    source,
162                });
163            }
164
165            tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
166            Ok(None)
167        }
168    }
169}
170
171#[derive(Deserialize, Serialize)]
172#[serde(rename_all = "lowercase", tag = "action")]
173pub(crate) enum FormData {
174    Register {
175        #[serde(default)]
176        username: Option<String>,
177        #[serde(default)]
178        import_email: Option<String>,
179        #[serde(default)]
180        import_display_name: Option<String>,
181        #[serde(default)]
182        accept_terms: Option<String>,
183    },
184    Link,
185}
186
187impl ToFormState for FormData {
188    type Field = mas_templates::UpstreamRegisterFormField;
189}
190
191#[tracing::instrument(
192    name = "handlers.upstream_oauth2.link.get",
193    fields(upstream_oauth_link.id = %link_id),
194    skip_all,
195    err,
196)]
197pub(crate) async fn get(
198    mut rng: BoxRng,
199    clock: BoxClock,
200    mut repo: BoxRepository,
201    mut policy: Policy,
202    PreferredLanguage(locale): PreferredLanguage,
203    State(templates): State<Templates>,
204    State(url_builder): State<UrlBuilder>,
205    State(homeserver): State<Arc<dyn HomeserverConnection>>,
206    cookie_jar: CookieJar,
207    activity_tracker: BoundActivityTracker,
208    user_agent: Option<TypedHeader<headers::UserAgent>>,
209    Path(link_id): Path<Ulid>,
210) -> Result<impl IntoResponse, RouteError> {
211    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
212    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
213    let (session_id, post_auth_action) = sessions_cookie
214        .lookup_link(link_id)
215        .map_err(|_| RouteError::MissingCookie)?;
216
217    let post_auth_action = OptionalPostAuthAction {
218        post_auth_action: post_auth_action.cloned(),
219    };
220
221    let link = repo
222        .upstream_oauth_link()
223        .lookup(link_id)
224        .await?
225        .ok_or(RouteError::LinkNotFound)?;
226
227    let upstream_session = repo
228        .upstream_oauth_session()
229        .lookup(session_id)
230        .await?
231        .ok_or(RouteError::SessionNotFound)?;
232
233    // This checks that we're in a browser session which is allowed to consume this
234    // link: the upstream auth session should have been started in this browser.
235    if upstream_session.link_id() != Some(link.id) {
236        return Err(RouteError::SessionNotFound);
237    }
238
239    if upstream_session.is_consumed() {
240        return Err(RouteError::SessionConsumed);
241    }
242
243    let (user_session_info, cookie_jar) = cookie_jar.session_info();
244    let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
245    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
246
247    let response = match (maybe_user_session, link.user_id) {
248        (Some(session), Some(user_id)) if session.user.id == user_id => {
249            // Session already linked, and link matches the currently logged
250            // user. Mark the session as consumed and renew the authentication.
251            let upstream_session = repo
252                .upstream_oauth_session()
253                .consume(&clock, upstream_session)
254                .await?;
255
256            repo.browser_session()
257                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
258                .await?;
259
260            cookie_jar = cookie_jar.set_session(&session);
261
262            repo.save().await?;
263
264            post_auth_action.go_next(&url_builder).into_response()
265        }
266
267        (Some(user_session), Some(user_id)) => {
268            // Session already linked, but link doesn't match the currently
269            // logged user. Suggest logging out of the current user
270            // and logging in with the new one
271            let user = repo
272                .user()
273                .lookup(user_id)
274                .await?
275                .ok_or(RouteError::UserNotFound)?;
276
277            let ctx = UpstreamExistingLinkContext::new(user)
278                .with_session(user_session)
279                .with_csrf(csrf_token.form_value())
280                .with_language(locale);
281
282            Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
283        }
284
285        (Some(user_session), None) => {
286            // Session not linked, but user logged in: suggest linking account
287            let ctx = UpstreamSuggestLink::new(&link)
288                .with_session(user_session)
289                .with_csrf(csrf_token.form_value())
290                .with_language(locale);
291
292            Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
293        }
294
295        (None, Some(user_id)) => {
296            // Session linked, but user not logged in: do the login
297            let user = repo
298                .user()
299                .lookup(user_id)
300                .await?
301                .ok_or(RouteError::UserNotFound)?;
302
303            // Check that the user is not locked or deactivated
304            if user.deactivated_at.is_some() {
305                // The account is deactivated, show the 'account deactivated' fallback
306                let ctx = AccountInactiveContext::new(user)
307                    .with_csrf(csrf_token.form_value())
308                    .with_language(locale);
309                let fallback = templates.render_account_deactivated(&ctx)?;
310                return Ok((cookie_jar, Html(fallback).into_response()));
311            }
312
313            if user.locked_at.is_some() {
314                // The account is locked, show the 'account locked' fallback
315                let ctx = AccountInactiveContext::new(user)
316                    .with_csrf(csrf_token.form_value())
317                    .with_language(locale);
318                let fallback = templates.render_account_locked(&ctx)?;
319                return Ok((cookie_jar, Html(fallback).into_response()));
320            }
321
322            let session = repo
323                .browser_session()
324                .add(&mut rng, &clock, &user, user_agent)
325                .await?;
326
327            let upstream_session = repo
328                .upstream_oauth_session()
329                .consume(&clock, upstream_session)
330                .await?;
331
332            repo.browser_session()
333                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
334                .await?;
335
336            cookie_jar = sessions_cookie
337                .consume_link(link_id)?
338                .save(cookie_jar, &clock);
339            cookie_jar = cookie_jar.set_session(&session);
340
341            repo.save().await?;
342
343            post_auth_action.go_next(&url_builder).into_response()
344        }
345
346        (None, None) => {
347            // Session not linked and used not logged in: suggest creating an
348            // account or logging in an existing user
349            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
350
351            let provider = repo
352                .upstream_oauth_provider()
353                .lookup(link.provider_id)
354                .await?
355                .ok_or(RouteError::ProviderNotFound)?;
356
357            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
358
359            let env = environment();
360
361            let mut context = AttributeMappingContext::new();
362            if let Some(id_token) = id_token {
363                let (_, payload) = id_token.into_parts();
364                context = context.with_id_token_claims(payload);
365            }
366            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
367                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
368            }
369            if let Some(userinfo) = upstream_session.userinfo() {
370                context = context.with_userinfo_claims(userinfo.clone());
371            }
372            let context = context.build();
373
374            let ctx = if provider.claims_imports.displayname.ignore() {
375                ctx
376            } else {
377                let template = provider
378                    .claims_imports
379                    .displayname
380                    .template
381                    .as_deref()
382                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
383
384                match render_attribute_template(
385                    &env,
386                    template,
387                    &context,
388                    provider.claims_imports.displayname.is_required(),
389                )? {
390                    Some(value) => ctx
391                        .with_display_name(value, provider.claims_imports.displayname.is_forced()),
392                    None => ctx,
393                }
394            };
395
396            let ctx = if provider.claims_imports.email.ignore() {
397                ctx
398            } else {
399                let template = provider
400                    .claims_imports
401                    .email
402                    .template
403                    .as_deref()
404                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
405
406                match render_attribute_template(
407                    &env,
408                    template,
409                    &context,
410                    provider.claims_imports.email.is_required(),
411                )? {
412                    Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
413                    None => ctx,
414                }
415            };
416
417            let ctx = if provider.claims_imports.localpart.ignore() {
418                ctx
419            } else {
420                let template = provider
421                    .claims_imports
422                    .localpart
423                    .template
424                    .as_deref()
425                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
426
427                match render_attribute_template(
428                    &env,
429                    template,
430                    &context,
431                    provider.claims_imports.localpart.is_required(),
432                )? {
433                    Some(localpart) => {
434                        // We could run policy & existing user checks when the user submits the
435                        // form, but this lead to poor UX. This is why we do
436                        // it ahead of time here.
437                        let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
438                        let is_available = homeserver
439                            .is_localpart_available(&localpart)
440                            .await
441                            .map_err(RouteError::HomeserverConnection)?;
442
443                        if maybe_existing_user.is_some() || !is_available {
444                            if let Some(existing_user) = maybe_existing_user {
445                                // The mapper returned a username which already exists, but isn't
446                                // linked to this upstream user.
447                                warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
448                            }
449
450                            // TODO: translate
451                            let ctx = ErrorContext::new()
452                                .with_code("User exists")
453                                .with_description(format!(
454                                    r"Upstream account provider returned {localpart:?} as username,
455                                    which is not linked to that upstream account"
456                                ))
457                                .with_language(&locale);
458
459                            return Ok((
460                                cookie_jar,
461                                Html(templates.render_error(&ctx)?).into_response(),
462                            ));
463                        }
464
465                        let res = policy
466                            .evaluate_register(mas_policy::RegisterInput {
467                                registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
468                                username: &localpart,
469                                email: None,
470                                requester: mas_policy::Requester {
471                                    ip_address: activity_tracker.ip(),
472                                    user_agent: user_agent.clone().map(|ua| ua.raw),
473                                },
474                            })
475                            .await?;
476
477                        if res.valid() {
478                            // The username passes the policy check, add it to the context
479                            ctx.with_localpart(
480                                localpart,
481                                provider.claims_imports.localpart.is_forced(),
482                            )
483                        } else if provider.claims_imports.localpart.is_forced() {
484                            // If the username claim is 'forced' but doesn't pass the policy check,
485                            // we display an error message.
486                            // TODO: translate
487                            let ctx = ErrorContext::new()
488                                .with_code("Policy error")
489                                .with_description(format!(
490                                    r"Upstream account provider returned {localpart:?} as username,
491                                    which does not pass the policy check: {res}"
492                                ))
493                                .with_language(&locale);
494
495                            return Ok((
496                                cookie_jar,
497                                Html(templates.render_error(&ctx)?).into_response(),
498                            ));
499                        } else {
500                            // Else, we just ignore it when it doesn't pass the policy check.
501                            ctx
502                        }
503                    }
504                    None => ctx,
505                }
506            };
507
508            let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
509
510            Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
511        }
512    };
513
514    Ok((cookie_jar, response))
515}
516
517#[tracing::instrument(
518    name = "handlers.upstream_oauth2.link.post",
519    fields(upstream_oauth_link.id = %link_id),
520    skip_all,
521    err,
522)]
523pub(crate) async fn post(
524    mut rng: BoxRng,
525    clock: BoxClock,
526    mut repo: BoxRepository,
527    cookie_jar: CookieJar,
528    user_agent: Option<TypedHeader<headers::UserAgent>>,
529    mut policy: Policy,
530    PreferredLanguage(locale): PreferredLanguage,
531    activity_tracker: BoundActivityTracker,
532    State(templates): State<Templates>,
533    State(homeserver): State<Arc<dyn HomeserverConnection>>,
534    State(url_builder): State<UrlBuilder>,
535    State(site_config): State<SiteConfig>,
536    Path(link_id): Path<Ulid>,
537    Form(form): Form<ProtectedForm<FormData>>,
538) -> Result<Response, RouteError> {
539    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
540    let form = cookie_jar.verify_form(&clock, form)?;
541
542    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
543    let (session_id, post_auth_action) = sessions_cookie
544        .lookup_link(link_id)
545        .map_err(|_| RouteError::MissingCookie)?;
546
547    let post_auth_action = OptionalPostAuthAction {
548        post_auth_action: post_auth_action.cloned(),
549    };
550
551    let link = repo
552        .upstream_oauth_link()
553        .lookup(link_id)
554        .await?
555        .ok_or(RouteError::LinkNotFound)?;
556
557    let upstream_session = repo
558        .upstream_oauth_session()
559        .lookup(session_id)
560        .await?
561        .ok_or(RouteError::SessionNotFound)?;
562
563    // This checks that we're in a browser session which is allowed to consume this
564    // link: the upstream auth session should have been started in this browser.
565    if upstream_session.link_id() != Some(link.id) {
566        return Err(RouteError::SessionNotFound);
567    }
568
569    if upstream_session.is_consumed() {
570        return Err(RouteError::SessionConsumed);
571    }
572
573    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
574    let (user_session_info, cookie_jar) = cookie_jar.session_info();
575    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
576    let form_state = form.to_form_state();
577
578    let session = match (maybe_user_session, link.user_id, form) {
579        (Some(session), None, FormData::Link) => {
580            // The user is already logged in, the link is not linked to any user, and the
581            // user asked to link their account.
582            repo.upstream_oauth_link()
583                .associate_to_user(&link, &session.user)
584                .await?;
585
586            session
587        }
588
589        (
590            None,
591            None,
592            FormData::Register {
593                username,
594                import_email,
595                import_display_name,
596                accept_terms,
597            },
598        ) => {
599            // The user got the form to register a new account, and is not logged in.
600            // Depending on the claims_imports, we've let the user choose their username,
601            // choose whether they want to import the email and display name, or
602            // not.
603
604            // Those fields are Some("on") if the checkbox is checked
605            let import_email = import_email.is_some();
606            let import_display_name = import_display_name.is_some();
607            let accept_terms = accept_terms.is_some();
608
609            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
610
611            let provider = repo
612                .upstream_oauth_provider()
613                .lookup(link.provider_id)
614                .await?
615                .ok_or(RouteError::ProviderNotFound)?;
616
617            // Let's try to import the claims from the ID token
618            let env = environment();
619
620            let mut context = AttributeMappingContext::new();
621            if let Some(id_token) = id_token {
622                let (_, payload) = id_token.into_parts();
623                context = context.with_id_token_claims(payload);
624            }
625            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
626                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
627            }
628            if let Some(userinfo) = upstream_session.userinfo() {
629                context = context.with_userinfo_claims(userinfo.clone());
630            }
631            let context = context.build();
632
633            // Create a template context in case we need to re-render because of an error
634            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
635
636            let display_name = if provider
637                .claims_imports
638                .displayname
639                .should_import(import_display_name)
640            {
641                let template = provider
642                    .claims_imports
643                    .displayname
644                    .template
645                    .as_deref()
646                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
647
648                render_attribute_template(
649                    &env,
650                    template,
651                    &context,
652                    provider.claims_imports.displayname.is_required(),
653                )?
654            } else {
655                None
656            };
657
658            let ctx = if let Some(ref display_name) = display_name {
659                ctx.with_display_name(
660                    display_name.clone(),
661                    provider.claims_imports.email.is_forced(),
662                )
663            } else {
664                ctx
665            };
666
667            let email = if provider.claims_imports.email.should_import(import_email) {
668                let template = provider
669                    .claims_imports
670                    .email
671                    .template
672                    .as_deref()
673                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
674
675                render_attribute_template(
676                    &env,
677                    template,
678                    &context,
679                    provider.claims_imports.email.is_required(),
680                )?
681            } else {
682                None
683            };
684
685            let ctx = if let Some(ref email) = email {
686                ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
687            } else {
688                ctx
689            };
690
691            let username = if provider.claims_imports.localpart.is_forced() {
692                let template = provider
693                    .claims_imports
694                    .localpart
695                    .template
696                    .as_deref()
697                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
698
699                render_attribute_template(&env, template, &context, true)?
700            } else {
701                // If there is no forced username, we can use the one the user entered
702                username
703            }
704            .unwrap_or_default();
705
706            let ctx = ctx.with_localpart(
707                username.clone(),
708                provider.claims_imports.localpart.is_forced(),
709            );
710
711            // Validate the form
712            let form_state = {
713                let mut form_state = form_state;
714                let mut homeserver_denied_username = false;
715                if username.is_empty() {
716                    form_state.add_error_on_field(
717                        mas_templates::UpstreamRegisterFormField::Username,
718                        FieldError::Required,
719                    );
720                } else if repo.user().exists(&username).await? {
721                    form_state.add_error_on_field(
722                        mas_templates::UpstreamRegisterFormField::Username,
723                        FieldError::Exists,
724                    );
725                } else if !homeserver
726                    .is_localpart_available(&username)
727                    .await
728                    .map_err(RouteError::HomeserverConnection)?
729                {
730                    // The user already exists on the homeserver
731                    tracing::warn!(
732                        %username,
733                        "Homeserver denied username provided by user"
734                    );
735
736                    // We defer adding the error on the field, until we know whether we had another
737                    // error from the policy, to avoid showing both
738                    homeserver_denied_username = true;
739                }
740
741                // If we have a TOS in the config, make sure the user has accepted it
742                if site_config.tos_uri.is_some() && !accept_terms {
743                    form_state.add_error_on_field(
744                        mas_templates::UpstreamRegisterFormField::AcceptTerms,
745                        FieldError::Required,
746                    );
747                }
748
749                // Policy check
750                let res = policy
751                    .evaluate_register(mas_policy::RegisterInput {
752                        registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
753                        username: &username,
754                        email: email.as_deref(),
755                        requester: mas_policy::Requester {
756                            ip_address: activity_tracker.ip(),
757                            user_agent: user_agent.clone().map(|ua| ua.raw),
758                        },
759                    })
760                    .await?;
761
762                for violation in res.violations {
763                    match violation.field.as_deref() {
764                        Some("username") => {
765                            // If the homeserver denied the username, but we also had an error on
766                            // the policy side, we don't want to show
767                            // both, so we reset the state here
768                            homeserver_denied_username = false;
769                            form_state.add_error_on_field(
770                                mas_templates::UpstreamRegisterFormField::Username,
771                                FieldError::Policy {
772                                    code: violation.code.map(|c| c.as_str()),
773                                    message: violation.msg,
774                                },
775                            );
776                        }
777                        _ => form_state.add_error_on_form(FormError::Policy {
778                            code: violation.code.map(|c| c.as_str()),
779                            message: violation.msg,
780                        }),
781                    }
782                }
783
784                if homeserver_denied_username {
785                    // XXX: we may want to return different errors like "this username is reserved"
786                    form_state.add_error_on_field(
787                        mas_templates::UpstreamRegisterFormField::Username,
788                        FieldError::Exists,
789                    );
790                }
791
792                form_state
793            };
794
795            if !form_state.is_valid() {
796                let ctx = ctx
797                    .with_form_state(form_state)
798                    .with_csrf(csrf_token.form_value())
799                    .with_language(locale);
800
801                return Ok((
802                    cookie_jar,
803                    Html(templates.render_upstream_oauth2_do_register(&ctx)?),
804                )
805                    .into_response());
806            }
807
808            // Now we can create the user
809            let user = repo.user().add(&mut rng, &clock, username).await?;
810
811            if let Some(terms_url) = &site_config.tos_uri {
812                repo.user_terms()
813                    .accept_terms(&mut rng, &clock, &user, terms_url.clone())
814                    .await?;
815            }
816
817            // And schedule the job to provision it
818            let mut job = ProvisionUserJob::new(&user);
819
820            // If we have a display name, set it during provisioning
821            if let Some(name) = display_name {
822                job = job.set_display_name(name);
823            }
824
825            repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
826
827            // If we have an email, add it to the user
828            if let Some(email) = email {
829                repo.user_email()
830                    .add(&mut rng, &clock, &user, email)
831                    .await?;
832            }
833
834            repo.upstream_oauth_link()
835                .associate_to_user(&link, &user)
836                .await?;
837
838            repo.browser_session()
839                .add(&mut rng, &clock, &user, user_agent)
840                .await?
841        }
842
843        _ => return Err(RouteError::InvalidFormAction),
844    };
845
846    let upstream_session = repo
847        .upstream_oauth_session()
848        .consume(&clock, upstream_session)
849        .await?;
850
851    repo.browser_session()
852        .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
853        .await?;
854
855    let cookie_jar = sessions_cookie
856        .consume_link(link_id)?
857        .save(cookie_jar, &clock);
858    let cookie_jar = cookie_jar.set_session(&session);
859
860    repo.save().await?;
861
862    Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
863}
864
#[cfg(test)]
mod tests {
    //! Integration test for the upstream OAuth 2.0 link page: registering a
    //! new local user from an upstream identity.

    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderImportAction, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::{Route, UpstreamOAuth2Link};
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Text that immediately precedes the CSRF token value in the rendered
    /// form.
    const CSRF_FIELD_MARKER: &str = "name=\"csrf\" value=\"";

    /// Pulls the CSRF token out of a rendered HTML page.
    ///
    /// Takes the text following the first occurrence of the CSRF field marker
    /// and cuts it at the next double quote. Panics if the marker is absent.
    fn csrf_token_from_html(html: &str) -> &str {
        let after_marker = html.split(CSRF_FIELD_MARKER).nth(1).unwrap();
        after_marker.split('"').next().unwrap()
    }

    /// Loads the link page for a fresh upstream link, submits the `register`
    /// action, then checks that a user named `john` exists, that the link is
    /// associated with that user and that the email from the ID token was
    /// stored for them.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let test_state = TestState::from_pool(pool).await.unwrap();
        let mut test_rng = test_state.rng();
        let browser_cookies = CookieHelper::new();

        // Both the localpart and the email use the `Force` import action, with
        // no custom template
        let forced_import = || UpstreamOAuthProviderImportPreference {
            action: UpstreamOAuthProviderImportAction::Force,
            template: None,
        };
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: forced_import(),
            email: forced_import(),
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        let id_token_claims = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Sign the ID token with a key borrowed from the test key store, rather
        // than generating a dedicated one
        let signing_key = test_state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let signer = signing_key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let signed_id_token = Jwt::sign_with_rng(
            &mut test_rng,
            JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256),
            id_token_claims,
            &signer,
        )
        .unwrap();

        // Seed the database: a provider, an upstream session and a link
        let mut seed_repo = test_state.repository().await.unwrap();

        let provider_params = UpstreamOAuthProviderParams {
            issuer: Some("https://example.com/".to_owned()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            token_endpoint_signing_alg: None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            client_id: "client".to_owned(),
            encrypted_client_secret: None,
            claims_imports,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            userinfo_endpoint_override: None,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        };
        let provider = seed_repo
            .upstream_oauth_provider()
            .add(&mut test_rng, &test_state.clock, provider_params)
            .await
            .unwrap();

        let pending_session = seed_repo
            .upstream_oauth_session()
            .add(
                &mut test_rng,
                &test_state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let upstream_link = seed_repo
            .upstream_oauth_link()
            .add(
                &mut test_rng,
                &test_state.clock,
                &provider,
                "subject".to_owned(),
                None,
            )
            .await
            .unwrap();

        // Attach the link and the signed ID token to the upstream session
        let upstream_session = seed_repo
            .upstream_oauth_session()
            .complete_with_link(
                &test_state.clock,
                pending_session,
                &upstream_link,
                Some(signed_id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        seed_repo.save().await.unwrap();

        // Give the simulated browser an upstream sessions cookie that
        // references the session and its link
        let sessions_cookie = UpstreamSessionsCookie::default()
            .add(
                upstream_session.id,
                provider.id,
                "state".to_owned(),
                None,
            )
            .add_link_to_session(upstream_session.id, upstream_link.id)
            .unwrap();
        browser_cookies.import(sessions_cookie.save(test_state.cookie_jar(), &test_state.clock));

        // Both requests below target the same route
        let link_path = UpstreamOAuth2Link::new(upstream_link.id).path();

        // First load the page, which renders the registration form
        let page_request = browser_cookies.with_cookies(Request::get(&*link_path).empty());
        let page_response = test_state.request(page_request).await;
        browser_cookies.save_cookies(&page_response);
        page_response.assert_status(StatusCode::OK);
        page_response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        // The form submission needs the CSRF token embedded in that page
        let csrf_token = csrf_token_from_html(page_response.body());

        // Then submit the form, asking to register a new user
        let register_form = serde_json::json!({
            "csrf": csrf_token,
            "action": "register",
            "import_email": "on",
            "accept_terms": "on",
        });
        let submit_request =
            browser_cookies.with_cookies(Request::post(&*link_path).form(register_form));
        let submit_response = test_state.request(submit_request).await;
        browser_cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::SEE_OTHER);

        // Finally, look at what ended up in the database
        let mut check_repo = test_state.repository().await.unwrap();

        let registered_user = check_repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        let stored_link = check_repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");

        // The link must now point at the freshly registered user
        assert_eq!(stored_link.user_id, Some(registered_user.id));

        // The email from the ID token must have been imported
        let email_page = check_repo
            .user_email()
            .list(
                UserEmailFilter::new().for_user(&registered_user),
                Pagination::first(1),
            )
            .await
            .unwrap();
        let imported_email = email_page.edges.first().expect("email exists");

        assert_eq!(imported_email.email, "john@example.com");
    }
}