mas_handlers/views/register/steps/display_name.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use anyhow::Context as _;
7use axum::{
8    Form,
9    extract::{Path, State},
10    response::{Html, IntoResponse, Response},
11};
12use mas_axum_utils::{
13    FancyError,
14    cookies::CookieJar,
15    csrf::{CsrfExt as _, ProtectedForm},
16};
17use mas_router::{PostAuthAction, UrlBuilder};
18use mas_storage::{BoxClock, BoxRepository, BoxRng};
19use mas_templates::{
20    FieldError, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
21    TemplateContext as _, Templates, ToFormState,
22};
23use serde::{Deserialize, Serialize};
24use ulid::Ulid;
25
26use crate::{PreferredLanguage, views::shared::OptionalPostAuthAction};
27
28#[derive(Deserialize, Default)]
29#[serde(rename_all = "snake_case")]
30enum FormAction {
31    #[default]
32    Set,
33    Skip,
34}
35
36#[derive(Deserialize, Serialize)]
37pub(crate) struct DisplayNameForm {
38    #[serde(skip_serializing, default)]
39    action: FormAction,
40    #[serde(default)]
41    display_name: String,
42}
43
44impl ToFormState for DisplayNameForm {
45    type Field = mas_templates::RegisterStepsDisplayNameFormField;
46}
47
48#[tracing::instrument(
49    name = "handlers.views.register.steps.display_name.get",
50    fields(user_registration.id = %id),
51    skip_all,
52    err,
53)]
54pub(crate) async fn get(
55    mut rng: BoxRng,
56    clock: BoxClock,
57    PreferredLanguage(locale): PreferredLanguage,
58    State(templates): State<Templates>,
59    State(url_builder): State<UrlBuilder>,
60    mut repo: BoxRepository,
61    Path(id): Path<Ulid>,
62    cookie_jar: CookieJar,
63) -> Result<Response, FancyError> {
64    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
65
66    let registration = repo
67        .user_registration()
68        .lookup(id)
69        .await?
70        .context("Could not find user registration")?;
71
72    // If the registration is completed, we can go to the registration destination
73    // XXX: this might not be the right thing to do? Maybe an error page would be
74    // better?
75    if registration.completed_at.is_some() {
76        let post_auth_action: Option<PostAuthAction> = registration
77            .post_auth_action
78            .map(serde_json::from_value)
79            .transpose()?;
80
81        return Ok((
82            cookie_jar,
83            OptionalPostAuthAction::from(post_auth_action)
84                .go_next(&url_builder)
85                .into_response(),
86        )
87            .into_response());
88    }
89
90    let ctx = RegisterStepsDisplayNameContext::new()
91        .with_csrf(csrf_token.form_value())
92        .with_language(locale);
93
94    let content = templates.render_register_steps_display_name(&ctx)?;
95
96    Ok((cookie_jar, Html(content)).into_response())
97}
98
99#[tracing::instrument(
100    name = "handlers.views.register.steps.display_name.post",
101    fields(user_registration.id = %id),
102    skip_all,
103    err,
104)]
105pub(crate) async fn post(
106    mut rng: BoxRng,
107    clock: BoxClock,
108    PreferredLanguage(locale): PreferredLanguage,
109    State(templates): State<Templates>,
110    State(url_builder): State<UrlBuilder>,
111    mut repo: BoxRepository,
112    Path(id): Path<Ulid>,
113    cookie_jar: CookieJar,
114    Form(form): Form<ProtectedForm<DisplayNameForm>>,
115) -> Result<Response, FancyError> {
116    let registration = repo
117        .user_registration()
118        .lookup(id)
119        .await?
120        .context("Could not find user registration")?;
121
122    // If the registration is completed, we can go to the registration destination
123    // XXX: this might not be the right thing to do? Maybe an error page would be
124    // better?
125    if registration.completed_at.is_some() {
126        let post_auth_action: Option<PostAuthAction> = registration
127            .post_auth_action
128            .map(serde_json::from_value)
129            .transpose()?;
130
131        return Ok((
132            cookie_jar,
133            OptionalPostAuthAction::from(post_auth_action)
134                .go_next(&url_builder)
135                .into_response(),
136        )
137            .into_response());
138    }
139
140    let form = cookie_jar.verify_form(&clock, form)?;
141
142    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
143
144    let display_name = match form.action {
145        FormAction::Set => {
146            let display_name = form.display_name.trim();
147
148            if display_name.is_empty() || display_name.len() > 255 {
149                let ctx = RegisterStepsDisplayNameContext::new()
150                    .with_form_state(form.to_form_state().with_error_on_field(
151                        RegisterStepsDisplayNameFormField::DisplayName,
152                        FieldError::Invalid,
153                    ))
154                    .with_csrf(csrf_token.form_value())
155                    .with_language(locale);
156
157                return Ok((
158                    cookie_jar,
159                    Html(templates.render_register_steps_display_name(&ctx)?),
160                )
161                    .into_response());
162            }
163
164            display_name.to_owned()
165        }
166        FormAction::Skip => {
167            // If the user chose to skip, we do the same as Synapse and use the localpart as
168            // default display name
169            registration.username.clone()
170        }
171    };
172
173    let registration = repo
174        .user_registration()
175        .set_display_name(registration, display_name)
176        .await?;
177
178    repo.save().await?;
179
180    let destination = mas_router::RegisterFinish::new(registration.id);
181    return Ok((cookie_jar, url_builder.redirect(&destination)).into_response());
182}