mas_handlers/views/register/
mod.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use axum::{
7    extract::{Query, State},
8    response::{Html, IntoResponse, Response},
9};
10use mas_axum_utils::{FancyError, SessionInfoExt, cookies::CookieJar, csrf::CsrfExt as _};
11use mas_data_model::SiteConfig;
12use mas_router::{PasswordRegister, UpstreamOAuth2Authorize, UrlBuilder};
13use mas_storage::{BoxClock, BoxRepository, BoxRng};
14use mas_templates::{RegisterContext, TemplateContext, Templates};
15
16use super::shared::OptionalPostAuthAction;
17use crate::{BoundActivityTracker, PreferredLanguage};
18
19mod cookie;
20pub(crate) mod password;
21pub(crate) mod steps;
22
23#[tracing::instrument(name = "handlers.views.register.get", skip_all, err)]
24pub(crate) async fn get(
25    mut rng: BoxRng,
26    clock: BoxClock,
27    PreferredLanguage(locale): PreferredLanguage,
28    State(templates): State<Templates>,
29    State(url_builder): State<UrlBuilder>,
30    State(site_config): State<SiteConfig>,
31    mut repo: BoxRepository,
32    activity_tracker: BoundActivityTracker,
33    Query(query): Query<OptionalPostAuthAction>,
34    cookie_jar: CookieJar,
35) -> Result<Response, FancyError> {
36    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
37    let (session_info, cookie_jar) = cookie_jar.session_info();
38
39    let maybe_session = session_info.load_active_session(&mut repo).await?;
40
41    if let Some(session) = maybe_session {
42        activity_tracker
43            .record_browser_session(&clock, &session)
44            .await;
45
46        let reply = query.go_next(&url_builder);
47        return Ok((cookie_jar, reply).into_response());
48    }
49
50    let providers = repo.upstream_oauth_provider().all_enabled().await?;
51
52    // If password-based login is disabled, and there is only one upstream provider,
53    // we can directly start an authorization flow
54    if !site_config.password_registration_enabled && providers.len() == 1 {
55        let provider = providers.into_iter().next().unwrap();
56
57        let mut destination = UpstreamOAuth2Authorize::new(provider.id);
58
59        if let Some(action) = query.post_auth_action {
60            destination = destination.and_then(action);
61        }
62
63        return Ok((cookie_jar, url_builder.redirect(&destination)).into_response());
64    }
65
66    // If password-based registration is enabled and there are no upstream
67    // providers, we redirect to the password registration page
68    if site_config.password_registration_enabled && providers.is_empty() {
69        let mut destination = PasswordRegister::default();
70
71        if let Some(action) = query.post_auth_action {
72            destination = destination.and_then(action);
73        }
74
75        return Ok((cookie_jar, url_builder.redirect(&destination)).into_response());
76    }
77
78    let mut ctx = RegisterContext::new(providers);
79    let post_action = query.load_context(&mut repo).await?;
80    if let Some(action) = post_action {
81        ctx = ctx.with_post_action(action);
82    }
83
84    let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
85
86    let content = templates.render_register(&ctx)?;
87
88    Ok((cookie_jar, Html(content)).into_response())
89}