1use std::{str::FromStr, sync::Arc};
8
9use axum::{
10 extract::{Form, Query, State},
11 response::{Html, IntoResponse, Response},
12};
13use axum_extra::typed_header::TypedHeader;
14use hyper::StatusCode;
15use lettre::Address;
16use mas_axum_utils::{
17 FancyError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, CsrfToken, ProtectedForm},
20};
21use mas_data_model::{CaptchaConfig, UserAgent};
22use mas_i18n::DataLocale;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28 queue::{QueueJobRepositoryExt as _, SendEmailAuthenticationCodeJob},
29 user::{UserEmailRepository, UserRepository},
30};
31use mas_templates::{
32 FieldError, FormError, FormState, PasswordRegisterContext, RegisterFormField, TemplateContext,
33 Templates, ToFormState,
34};
35use serde::{Deserialize, Serialize};
36use zeroize::Zeroizing;
37
38use super::cookie::UserRegistrationSessions;
39use crate::{
40 BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint, SiteConfig,
41 captcha::Form as CaptchaForm, passwords::PasswordManager,
42 views::shared::OptionalPostAuthAction,
43};
44
/// Fields submitted by the password registration form.
///
/// Deserialized from the request body, and serialized back into a form state
/// (through [`ToFormState`]) when the form has to be rendered again with
/// errors.
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct RegisterForm {
    /// The requested username
    username: String,
    /// The email address to verify for the new account
    email: String,
    /// The chosen password
    password: String,
    /// Repetition of the password, which must match `password`
    password_confirm: String,
    /// Terms-of-service checkbox value; an empty string when the field is
    /// absent from the submission
    #[serde(default)]
    accept_terms: String,

    /// Captcha response fields, read from the same flat form body and never
    /// serialized back into the form state
    #[serde(flatten, skip_serializing)]
    captcha: CaptchaForm,
}
57
// Allows a submitted `RegisterForm` to be turned into a form state whose
// errors and values are keyed by `RegisterFormField`.
impl ToFormState for RegisterForm {
    type Field = RegisterFormField;
}
61
/// Query parameters accepted when displaying the registration form.
#[derive(Deserialize)]
pub struct QueryParams {
    /// Optional value used to pre-fill the username field
    username: Option<String>,
    /// Post-authentication action, read from the same query string
    #[serde(flatten)]
    action: OptionalPostAuthAction,
}
68
69#[tracing::instrument(name = "handlers.views.password_register.get", skip_all, err)]
70pub(crate) async fn get(
71 mut rng: BoxRng,
72 clock: BoxClock,
73 PreferredLanguage(locale): PreferredLanguage,
74 State(templates): State<Templates>,
75 State(url_builder): State<UrlBuilder>,
76 State(site_config): State<SiteConfig>,
77 mut repo: BoxRepository,
78 Query(query): Query<QueryParams>,
79 cookie_jar: CookieJar,
80) -> Result<Response, FancyError> {
81 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
82 let (session_info, cookie_jar) = cookie_jar.session_info();
83
84 let maybe_session = session_info.load_active_session(&mut repo).await?;
85
86 if maybe_session.is_some() {
87 let reply = query.action.go_next(&url_builder);
88 return Ok((cookie_jar, reply).into_response());
89 }
90
91 if !site_config.password_registration_enabled {
92 return Ok(url_builder
94 .redirect(&mas_router::Login::from(query.action.post_auth_action))
95 .into_response());
96 }
97
98 let mut ctx = PasswordRegisterContext::default();
99
100 if let Some(username) = query.username {
102 let mut form_state = FormState::default();
103 form_state.set_value(RegisterFormField::Username, Some(username));
104 ctx = ctx.with_form_state(form_state);
105 }
106
107 let content = render(
108 locale,
109 ctx,
110 query.action,
111 csrf_token,
112 &mut repo,
113 &templates,
114 site_config.captcha.clone(),
115 )
116 .await?;
117
118 Ok((cookie_jar, Html(content)).into_response())
119}
120
121#[tracing::instrument(name = "handlers.views.password_register.post", skip_all, err)]
122#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
123pub(crate) async fn post(
124 mut rng: BoxRng,
125 clock: BoxClock,
126 PreferredLanguage(locale): PreferredLanguage,
127 State(password_manager): State<PasswordManager>,
128 State(templates): State<Templates>,
129 State(url_builder): State<UrlBuilder>,
130 State(site_config): State<SiteConfig>,
131 State(homeserver): State<Arc<dyn HomeserverConnection>>,
132 State(http_client): State<reqwest::Client>,
133 (State(limiter), requester): (State<Limiter>, RequesterFingerprint),
134 mut policy: Policy,
135 mut repo: BoxRepository,
136 (user_agent, activity_tracker): (
137 Option<TypedHeader<headers::UserAgent>>,
138 BoundActivityTracker,
139 ),
140 Query(query): Query<OptionalPostAuthAction>,
141 cookie_jar: CookieJar,
142 Form(form): Form<ProtectedForm<RegisterForm>>,
143) -> Result<Response, FancyError> {
144 let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
145
146 let ip_address = activity_tracker.ip();
147 if !site_config.password_registration_enabled {
148 return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response());
149 }
150
151 let form = cookie_jar.verify_form(&clock, form)?;
152
153 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
154
155 let passed_captcha = form
158 .captcha
159 .verify(
160 &activity_tracker,
161 &http_client,
162 url_builder.public_hostname(),
163 site_config.captcha.as_ref(),
164 )
165 .await
166 .is_ok();
167
168 let state = {
170 let mut state = form.to_form_state();
171
172 if !passed_captcha {
173 state.add_error_on_form(FormError::Captcha);
174 }
175
176 let mut homeserver_denied_username = false;
177 if form.username.is_empty() {
178 state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
179 } else if repo.user().exists(&form.username).await? {
180 state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
182 } else if !homeserver.is_localpart_available(&form.username).await? {
183 tracing::warn!(
185 username = &form.username,
186 "Homeserver denied username provided by user"
187 );
188
189 homeserver_denied_username = true;
192 }
193
194 if form.email.is_empty() {
198 state.add_error_on_field(RegisterFormField::Email, FieldError::Required);
199 } else if Address::from_str(&form.email).is_err() {
200 state.add_error_on_field(RegisterFormField::Email, FieldError::Invalid);
201 }
202
203 if form.password.is_empty() {
204 state.add_error_on_field(RegisterFormField::Password, FieldError::Required);
205 }
206
207 if form.password_confirm.is_empty() {
208 state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Required);
209 }
210
211 if form.password != form.password_confirm {
212 state.add_error_on_field(RegisterFormField::Password, FieldError::Unspecified);
213 state.add_error_on_field(
214 RegisterFormField::PasswordConfirm,
215 FieldError::PasswordMismatch,
216 );
217 }
218
219 if !password_manager.is_password_complex_enough(&form.password)? {
220 state.add_error_on_field(
222 RegisterFormField::Password,
223 FieldError::Policy {
224 code: None,
225 message: "Password is too weak".to_owned(),
226 },
227 );
228 }
229
230 if site_config.tos_uri.is_some() && form.accept_terms != "on" {
232 state.add_error_on_field(RegisterFormField::AcceptTerms, FieldError::Required);
233 }
234
235 let res = policy
236 .evaluate_register(mas_policy::RegisterInput {
237 registration_method: mas_policy::RegistrationMethod::Password,
238 username: &form.username,
239 email: Some(&form.email),
240 requester: mas_policy::Requester {
241 ip_address: activity_tracker.ip(),
242 user_agent: user_agent.clone().map(|ua| ua.raw),
243 },
244 })
245 .await?;
246
247 for violation in res.violations {
248 match violation.field.as_deref() {
249 Some("email") => state.add_error_on_field(
250 RegisterFormField::Email,
251 FieldError::Policy {
252 code: violation.code.map(|c| c.as_str()),
253 message: violation.msg,
254 },
255 ),
256 Some("username") => {
257 homeserver_denied_username = false;
260 state.add_error_on_field(
261 RegisterFormField::Username,
262 FieldError::Policy {
263 code: violation.code.map(|c| c.as_str()),
264 message: violation.msg,
265 },
266 );
267 }
268 Some("password") => state.add_error_on_field(
269 RegisterFormField::Password,
270 FieldError::Policy {
271 code: violation.code.map(|c| c.as_str()),
272 message: violation.msg,
273 },
274 ),
275 _ => state.add_error_on_form(FormError::Policy {
276 code: violation.code.map(|c| c.as_str()),
277 message: violation.msg,
278 }),
279 }
280 }
281
282 if homeserver_denied_username {
283 state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
285 }
286
287 if state.is_valid() {
288 if let Err(e) = limiter.check_registration(requester) {
290 tracing::warn!(error = &e as &dyn std::error::Error);
291 state.add_error_on_form(FormError::RateLimitExceeded);
292 }
293
294 if let Err(e) = limiter.check_email_authentication_email(requester, &form.email) {
295 tracing::warn!(error = &e as &dyn std::error::Error);
296 state.add_error_on_form(FormError::RateLimitExceeded);
297 }
298 }
299
300 state
301 };
302
303 if !state.is_valid() {
304 let content = render(
305 locale,
306 PasswordRegisterContext::default().with_form_state(state),
307 query,
308 csrf_token,
309 &mut repo,
310 &templates,
311 site_config.captcha.clone(),
312 )
313 .await?;
314
315 return Ok((cookie_jar, Html(content)).into_response());
316 }
317
318 let post_auth_action = query
319 .post_auth_action
320 .map(serde_json::to_value)
321 .transpose()?;
322 let registration = repo
323 .user_registration()
324 .add(
325 &mut rng,
326 &clock,
327 form.username,
328 ip_address,
329 user_agent,
330 post_auth_action,
331 )
332 .await?;
333
334 let registration = if let Some(tos_uri) = &site_config.tos_uri {
335 repo.user_registration()
336 .set_terms_url(registration, tos_uri.clone())
337 .await?
338 } else {
339 registration
340 };
341
342 let user_email_authentication = repo
344 .user_email()
345 .add_authentication_for_registration(&mut rng, &clock, form.email, ®istration)
346 .await?;
347
348 repo.queue_job()
350 .schedule_job(
351 &mut rng,
352 &clock,
353 SendEmailAuthenticationCodeJob::new(&user_email_authentication, locale.to_string()),
354 )
355 .await?;
356
357 let registration = repo
358 .user_registration()
359 .set_email_authentication(registration, &user_email_authentication)
360 .await?;
361
362 let password = Zeroizing::new(form.password.into_bytes());
364 let (version, hashed_password) = password_manager.hash(&mut rng, password).await?;
365
366 let registration = repo
368 .user_registration()
369 .set_password(registration, hashed_password, version)
370 .await?;
371
372 repo.save().await?;
373
374 let cookie_jar = UserRegistrationSessions::load(&cookie_jar)
375 .add(®istration)
376 .save(cookie_jar, &clock);
377
378 Ok((
379 cookie_jar,
380 url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
381 )
382 .into_response())
383}
384
385async fn render(
386 locale: DataLocale,
387 ctx: PasswordRegisterContext,
388 action: OptionalPostAuthAction,
389 csrf_token: CsrfToken,
390 repo: &mut impl RepositoryAccess,
391 templates: &Templates,
392 captcha_config: Option<CaptchaConfig>,
393) -> Result<String, FancyError> {
394 let next = action.load_context(repo).await?;
395 let ctx = if let Some(next) = next {
396 ctx.with_post_action(next)
397 } else {
398 ctx
399 };
400 let ctx = ctx
401 .with_captcha(captcha_config)
402 .with_csrf(csrf_token.form_value())
403 .with_language(locale);
404
405 let content = templates.render_password_register(&ctx)?;
406 Ok(content)
407}
408
#[cfg(test)]
mod tests {
    use hyper::{
        Request, StatusCode,
        header::{CONTENT_TYPE, LOCATION},
    };
    use mas_router::Route;
    use sqlx::PgPool;

    use crate::{
        SiteConfig,
        test_utils::{
            CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup, test_site_config,
        },
    };

    /// Load the registration page and return the CSRF token embedded in its
    /// form.
    ///
    /// Asserts that the page rendered as HTML. The response cookies are saved
    /// into `cookies`, so that a following form submission carries them.
    async fn get_csrf_token(state: &TestState, cookies: &CookieHelper) -> String {
        let request =
            Request::get(&*mas_router::PasswordRegister::default().path_and_query()).empty();
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
        response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .expect("the registration form should contain a CSRF field")
            .split('"')
            .next()
            .expect("split always yields at least one item")
            .to_owned()
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_password_disabled(pool: PgPool) {
        setup();
        let state = TestState::from_pool_with_site_config(
            pool,
            SiteConfig {
                password_login_enabled: false,
                password_registration_enabled: false,
                ..test_site_config()
            },
        )
        .await
        .unwrap();

        let request =
            Request::get(&*mas_router::PasswordRegister::default().path_and_query()).empty();
        let response = state.request(request).await;
        response.assert_status(StatusCode::SEE_OTHER);
        response.assert_header_value(LOCATION, "/login");

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": "abc",
                "username": "john",
                "email": "john@example.com",
                "password": "hunter2",
                "password_confirm": "hunter2",
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();

        let csrf_token = get_csrf_token(&state, &cookies).await;

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": csrf_token,
                "username": "john",
                "email": "john@example.com",
                "password": "correcthorsebatterystaple",
                "password_confirm": "correcthorsebatterystaple",
                "accept_terms": "on",
            }));
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);
        let location = response.headers().get(LOCATION).unwrap();

        // The registration ID is the second-to-last segment of the redirect path
        let id = location
            .to_str()
            .unwrap()
            .rsplit('/')
            .nth(1)
            .unwrap()
            .parse()
            .unwrap();

        let mut repo = state.repository().await.unwrap();
        let registration = repo.user_registration().lookup(id).await.unwrap().unwrap();
        assert_eq!(registration.username, "john".to_owned());
        assert!(registration.password.is_some());

        let email_authentication = repo
            .user_email()
            .lookup_authentication(registration.email_authentication_id.unwrap())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(email_authentication.email, "john@example.com");
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_password_mismatch(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();

        let csrf_token = get_csrf_token(&state, &cookies).await;

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": csrf_token,
                "username": "john",
                "email": "john@example.com",
                "password": "hunter2",
                "password_confirm": "mismatch",
                "accept_terms": "on",
            }));
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        assert!(response.body().contains("Password fields don't match"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_username_too_long(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();

        let csrf_token = get_csrf_token(&state, &cookies).await;

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": csrf_token,
                "username": "a".repeat(256),
                "email": "john@example.com",
                "password": "hunter2",
                "password_confirm": "hunter2",
                "accept_terms": "on",
            }));
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        assert!(
            response.body().contains("Username is too long"),
            "response body: {}",
            response.body()
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_user_exists(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Create the conflicting user before loading the form
        let mut repo = state.repository().await.unwrap();
        repo.user()
            .add(&mut rng, &state.clock, "john".to_owned())
            .await
            .unwrap();
        repo.save().await.unwrap();

        let csrf_token = get_csrf_token(&state, &cookies).await;

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": csrf_token,
                "username": "john",
                "email": "john@example.com",
                "password": "hunter2",
                "password_confirm": "hunter2",
                "accept_terms": "on",
            }));
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        assert!(response.body().contains("This username is already taken"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_user_reserved(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();

        let csrf_token = get_csrf_token(&state, &cookies).await;

        // Only the homeserver knows about this username
        state.homeserver_connection.reserve_localpart("john").await;

        let request = Request::post(&*mas_router::PasswordRegister::default().path_and_query())
            .form(serde_json::json!({
                "csrf": csrf_token,
                "username": "john",
                "email": "john@example.com",
                "password": "hunter2",
                "password_confirm": "hunter2",
                "accept_terms": "on",
            }));
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        assert!(response.body().contains("This username is already taken"));
    }
}