1use std::sync::Arc;
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17 FancyError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, ProtectedForm},
20 sentry::SentryEventID,
21};
22use mas_data_model::UserAgent;
23use mas_jose::jwt::Jwt;
24use mas_matrix::HomeserverConnection;
25use mas_policy::Policy;
26use mas_router::UrlBuilder;
27use mas_storage::{
28 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
29 queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
30 upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
31 user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
32};
33use mas_templates::{
34 AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
35 ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
36};
37use minijinja::Environment;
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44 UpstreamSessionsCookie,
45 template::{AttributeMappingContext, environment},
46};
47use crate::{
48 BoundActivityTracker, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49 views::shared::OptionalPostAuthAction,
50};
51
/// Fallback template for the email claim import, used when the provider has no custom template.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
/// Fallback template for the display name claim import, used when the provider has no custom
/// template.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Fallback template for the localpart (username) claim import, used when the provider has no
/// custom template.
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
55
56#[derive(Debug, Error)]
57pub(crate) enum RouteError {
58 #[error("Link not found")]
60 LinkNotFound,
61
62 #[error("Session not found")]
64 SessionNotFound,
65
66 #[error("User not found")]
68 UserNotFound,
69
70 #[error("Upstream provider not found")]
72 ProviderNotFound,
73
74 #[error("Template {template:?} rendered to an empty string")]
76 RequiredAttributeEmpty { template: String },
77
78 #[error(
80 "Template {template:?} could not be rendered from the upstream provider's response for required claim"
81 )]
82 RequiredAttributeRender {
83 template: String,
84
85 #[source]
86 source: minijinja::Error,
87 },
88
89 #[error("Session already consumed")]
91 SessionConsumed,
92
93 #[error("Missing session cookie")]
94 MissingCookie,
95
96 #[error("Invalid form action")]
97 InvalidFormAction,
98
99 #[error("Homeserver connection error")]
100 HomeserverConnection(#[source] anyhow::Error),
101
102 #[error(transparent)]
103 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
104}
105
106impl_from_error_for_route!(mas_templates::TemplateError);
107impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
108impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
109impl_from_error_for_route!(mas_storage::RepositoryError);
110impl_from_error_for_route!(mas_policy::EvaluationError);
111impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
112
113impl IntoResponse for RouteError {
114 fn into_response(self) -> axum::response::Response {
115 let event_id = sentry::capture_error(&self);
116 let response = match self {
117 Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
118 Self::Internal(e) => FancyError::from(e).into_response(),
119 e => FancyError::from(e).into_response(),
120 };
121
122 (SentryEventID::from(event_id), response).into_response()
123 }
124}
125
126fn render_attribute_template(
139 environment: &Environment,
140 template: &str,
141 context: &minijinja::Value,
142 required: bool,
143) -> Result<Option<String>, RouteError> {
144 match environment.render_str(template, context) {
145 Ok(value) if value.is_empty() => {
146 if required {
147 return Err(RouteError::RequiredAttributeEmpty {
148 template: template.to_owned(),
149 });
150 }
151
152 Ok(None)
153 }
154
155 Ok(value) => Ok(Some(value)),
156
157 Err(source) => {
158 if required {
159 return Err(RouteError::RequiredAttributeRender {
160 template: template.to_owned(),
161 source,
162 });
163 }
164
165 tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
166 Ok(None)
167 }
168 }
169}
170
171#[derive(Deserialize, Serialize)]
172#[serde(rename_all = "lowercase", tag = "action")]
173pub(crate) enum FormData {
174 Register {
175 #[serde(default)]
176 username: Option<String>,
177 #[serde(default)]
178 import_email: Option<String>,
179 #[serde(default)]
180 import_display_name: Option<String>,
181 #[serde(default)]
182 accept_terms: Option<String>,
183 },
184 Link,
185}
186
187impl ToFormState for FormData {
188 type Field = mas_templates::UpstreamRegisterFormField;
189}
190
191#[tracing::instrument(
192 name = "handlers.upstream_oauth2.link.get",
193 fields(upstream_oauth_link.id = %link_id),
194 skip_all,
195 err,
196)]
197pub(crate) async fn get(
198 mut rng: BoxRng,
199 clock: BoxClock,
200 mut repo: BoxRepository,
201 mut policy: Policy,
202 PreferredLanguage(locale): PreferredLanguage,
203 State(templates): State<Templates>,
204 State(url_builder): State<UrlBuilder>,
205 State(homeserver): State<Arc<dyn HomeserverConnection>>,
206 cookie_jar: CookieJar,
207 activity_tracker: BoundActivityTracker,
208 user_agent: Option<TypedHeader<headers::UserAgent>>,
209 Path(link_id): Path<Ulid>,
210) -> Result<impl IntoResponse, RouteError> {
211 let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
212 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
213 let (session_id, post_auth_action) = sessions_cookie
214 .lookup_link(link_id)
215 .map_err(|_| RouteError::MissingCookie)?;
216
217 let post_auth_action = OptionalPostAuthAction {
218 post_auth_action: post_auth_action.cloned(),
219 };
220
221 let link = repo
222 .upstream_oauth_link()
223 .lookup(link_id)
224 .await?
225 .ok_or(RouteError::LinkNotFound)?;
226
227 let upstream_session = repo
228 .upstream_oauth_session()
229 .lookup(session_id)
230 .await?
231 .ok_or(RouteError::SessionNotFound)?;
232
233 if upstream_session.link_id() != Some(link.id) {
236 return Err(RouteError::SessionNotFound);
237 }
238
239 if upstream_session.is_consumed() {
240 return Err(RouteError::SessionConsumed);
241 }
242
243 let (user_session_info, cookie_jar) = cookie_jar.session_info();
244 let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
245 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
246
247 let response = match (maybe_user_session, link.user_id) {
248 (Some(session), Some(user_id)) if session.user.id == user_id => {
249 let upstream_session = repo
252 .upstream_oauth_session()
253 .consume(&clock, upstream_session)
254 .await?;
255
256 repo.browser_session()
257 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
258 .await?;
259
260 cookie_jar = cookie_jar.set_session(&session);
261
262 repo.save().await?;
263
264 post_auth_action.go_next(&url_builder).into_response()
265 }
266
267 (Some(user_session), Some(user_id)) => {
268 let user = repo
272 .user()
273 .lookup(user_id)
274 .await?
275 .ok_or(RouteError::UserNotFound)?;
276
277 let ctx = UpstreamExistingLinkContext::new(user)
278 .with_session(user_session)
279 .with_csrf(csrf_token.form_value())
280 .with_language(locale);
281
282 Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
283 }
284
285 (Some(user_session), None) => {
286 let ctx = UpstreamSuggestLink::new(&link)
288 .with_session(user_session)
289 .with_csrf(csrf_token.form_value())
290 .with_language(locale);
291
292 Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
293 }
294
295 (None, Some(user_id)) => {
296 let user = repo
298 .user()
299 .lookup(user_id)
300 .await?
301 .ok_or(RouteError::UserNotFound)?;
302
303 if user.deactivated_at.is_some() {
305 let ctx = AccountInactiveContext::new(user)
307 .with_csrf(csrf_token.form_value())
308 .with_language(locale);
309 let fallback = templates.render_account_deactivated(&ctx)?;
310 return Ok((cookie_jar, Html(fallback).into_response()));
311 }
312
313 if user.locked_at.is_some() {
314 let ctx = AccountInactiveContext::new(user)
316 .with_csrf(csrf_token.form_value())
317 .with_language(locale);
318 let fallback = templates.render_account_locked(&ctx)?;
319 return Ok((cookie_jar, Html(fallback).into_response()));
320 }
321
322 let session = repo
323 .browser_session()
324 .add(&mut rng, &clock, &user, user_agent)
325 .await?;
326
327 let upstream_session = repo
328 .upstream_oauth_session()
329 .consume(&clock, upstream_session)
330 .await?;
331
332 repo.browser_session()
333 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
334 .await?;
335
336 cookie_jar = sessions_cookie
337 .consume_link(link_id)?
338 .save(cookie_jar, &clock);
339 cookie_jar = cookie_jar.set_session(&session);
340
341 repo.save().await?;
342
343 post_auth_action.go_next(&url_builder).into_response()
344 }
345
346 (None, None) => {
347 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
350
351 let provider = repo
352 .upstream_oauth_provider()
353 .lookup(link.provider_id)
354 .await?
355 .ok_or(RouteError::ProviderNotFound)?;
356
357 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
358
359 let env = environment();
360
361 let mut context = AttributeMappingContext::new();
362 if let Some(id_token) = id_token {
363 let (_, payload) = id_token.into_parts();
364 context = context.with_id_token_claims(payload);
365 }
366 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
367 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
368 }
369 if let Some(userinfo) = upstream_session.userinfo() {
370 context = context.with_userinfo_claims(userinfo.clone());
371 }
372 let context = context.build();
373
374 let ctx = if provider.claims_imports.displayname.ignore() {
375 ctx
376 } else {
377 let template = provider
378 .claims_imports
379 .displayname
380 .template
381 .as_deref()
382 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
383
384 match render_attribute_template(
385 &env,
386 template,
387 &context,
388 provider.claims_imports.displayname.is_required(),
389 )? {
390 Some(value) => ctx
391 .with_display_name(value, provider.claims_imports.displayname.is_forced()),
392 None => ctx,
393 }
394 };
395
396 let ctx = if provider.claims_imports.email.ignore() {
397 ctx
398 } else {
399 let template = provider
400 .claims_imports
401 .email
402 .template
403 .as_deref()
404 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
405
406 match render_attribute_template(
407 &env,
408 template,
409 &context,
410 provider.claims_imports.email.is_required(),
411 )? {
412 Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
413 None => ctx,
414 }
415 };
416
417 let ctx = if provider.claims_imports.localpart.ignore() {
418 ctx
419 } else {
420 let template = provider
421 .claims_imports
422 .localpart
423 .template
424 .as_deref()
425 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
426
427 match render_attribute_template(
428 &env,
429 template,
430 &context,
431 provider.claims_imports.localpart.is_required(),
432 )? {
433 Some(localpart) => {
434 let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
438 let is_available = homeserver
439 .is_localpart_available(&localpart)
440 .await
441 .map_err(RouteError::HomeserverConnection)?;
442
443 if maybe_existing_user.is_some() || !is_available {
444 if let Some(existing_user) = maybe_existing_user {
445 warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
448 }
449
450 let ctx = ErrorContext::new()
452 .with_code("User exists")
453 .with_description(format!(
454 r"Upstream account provider returned {localpart:?} as username,
455 which is not linked to that upstream account"
456 ))
457 .with_language(&locale);
458
459 return Ok((
460 cookie_jar,
461 Html(templates.render_error(&ctx)?).into_response(),
462 ));
463 }
464
465 let res = policy
466 .evaluate_register(mas_policy::RegisterInput {
467 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
468 username: &localpart,
469 email: None,
470 requester: mas_policy::Requester {
471 ip_address: activity_tracker.ip(),
472 user_agent: user_agent.clone().map(|ua| ua.raw),
473 },
474 })
475 .await?;
476
477 if res.valid() {
478 ctx.with_localpart(
480 localpart,
481 provider.claims_imports.localpart.is_forced(),
482 )
483 } else if provider.claims_imports.localpart.is_forced() {
484 let ctx = ErrorContext::new()
488 .with_code("Policy error")
489 .with_description(format!(
490 r"Upstream account provider returned {localpart:?} as username,
491 which does not pass the policy check: {res}"
492 ))
493 .with_language(&locale);
494
495 return Ok((
496 cookie_jar,
497 Html(templates.render_error(&ctx)?).into_response(),
498 ));
499 } else {
500 ctx
502 }
503 }
504 None => ctx,
505 }
506 };
507
508 let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
509
510 Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
511 }
512 };
513
514 Ok((cookie_jar, response))
515}
516
517#[tracing::instrument(
518 name = "handlers.upstream_oauth2.link.post",
519 fields(upstream_oauth_link.id = %link_id),
520 skip_all,
521 err,
522)]
523pub(crate) async fn post(
524 mut rng: BoxRng,
525 clock: BoxClock,
526 mut repo: BoxRepository,
527 cookie_jar: CookieJar,
528 user_agent: Option<TypedHeader<headers::UserAgent>>,
529 mut policy: Policy,
530 PreferredLanguage(locale): PreferredLanguage,
531 activity_tracker: BoundActivityTracker,
532 State(templates): State<Templates>,
533 State(homeserver): State<Arc<dyn HomeserverConnection>>,
534 State(url_builder): State<UrlBuilder>,
535 State(site_config): State<SiteConfig>,
536 Path(link_id): Path<Ulid>,
537 Form(form): Form<ProtectedForm<FormData>>,
538) -> Result<Response, RouteError> {
539 let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
540 let form = cookie_jar.verify_form(&clock, form)?;
541
542 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
543 let (session_id, post_auth_action) = sessions_cookie
544 .lookup_link(link_id)
545 .map_err(|_| RouteError::MissingCookie)?;
546
547 let post_auth_action = OptionalPostAuthAction {
548 post_auth_action: post_auth_action.cloned(),
549 };
550
551 let link = repo
552 .upstream_oauth_link()
553 .lookup(link_id)
554 .await?
555 .ok_or(RouteError::LinkNotFound)?;
556
557 let upstream_session = repo
558 .upstream_oauth_session()
559 .lookup(session_id)
560 .await?
561 .ok_or(RouteError::SessionNotFound)?;
562
563 if upstream_session.link_id() != Some(link.id) {
566 return Err(RouteError::SessionNotFound);
567 }
568
569 if upstream_session.is_consumed() {
570 return Err(RouteError::SessionConsumed);
571 }
572
573 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
574 let (user_session_info, cookie_jar) = cookie_jar.session_info();
575 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
576 let form_state = form.to_form_state();
577
578 let session = match (maybe_user_session, link.user_id, form) {
579 (Some(session), None, FormData::Link) => {
580 repo.upstream_oauth_link()
583 .associate_to_user(&link, &session.user)
584 .await?;
585
586 session
587 }
588
589 (
590 None,
591 None,
592 FormData::Register {
593 username,
594 import_email,
595 import_display_name,
596 accept_terms,
597 },
598 ) => {
599 let import_email = import_email.is_some();
606 let import_display_name = import_display_name.is_some();
607 let accept_terms = accept_terms.is_some();
608
609 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
610
611 let provider = repo
612 .upstream_oauth_provider()
613 .lookup(link.provider_id)
614 .await?
615 .ok_or(RouteError::ProviderNotFound)?;
616
617 let env = environment();
619
620 let mut context = AttributeMappingContext::new();
621 if let Some(id_token) = id_token {
622 let (_, payload) = id_token.into_parts();
623 context = context.with_id_token_claims(payload);
624 }
625 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
626 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
627 }
628 if let Some(userinfo) = upstream_session.userinfo() {
629 context = context.with_userinfo_claims(userinfo.clone());
630 }
631 let context = context.build();
632
633 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
635
636 let display_name = if provider
637 .claims_imports
638 .displayname
639 .should_import(import_display_name)
640 {
641 let template = provider
642 .claims_imports
643 .displayname
644 .template
645 .as_deref()
646 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
647
648 render_attribute_template(
649 &env,
650 template,
651 &context,
652 provider.claims_imports.displayname.is_required(),
653 )?
654 } else {
655 None
656 };
657
658 let ctx = if let Some(ref display_name) = display_name {
659 ctx.with_display_name(
660 display_name.clone(),
661 provider.claims_imports.email.is_forced(),
662 )
663 } else {
664 ctx
665 };
666
667 let email = if provider.claims_imports.email.should_import(import_email) {
668 let template = provider
669 .claims_imports
670 .email
671 .template
672 .as_deref()
673 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
674
675 render_attribute_template(
676 &env,
677 template,
678 &context,
679 provider.claims_imports.email.is_required(),
680 )?
681 } else {
682 None
683 };
684
685 let ctx = if let Some(ref email) = email {
686 ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
687 } else {
688 ctx
689 };
690
691 let username = if provider.claims_imports.localpart.is_forced() {
692 let template = provider
693 .claims_imports
694 .localpart
695 .template
696 .as_deref()
697 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
698
699 render_attribute_template(&env, template, &context, true)?
700 } else {
701 username
703 }
704 .unwrap_or_default();
705
706 let ctx = ctx.with_localpart(
707 username.clone(),
708 provider.claims_imports.localpart.is_forced(),
709 );
710
711 let form_state = {
713 let mut form_state = form_state;
714 let mut homeserver_denied_username = false;
715 if username.is_empty() {
716 form_state.add_error_on_field(
717 mas_templates::UpstreamRegisterFormField::Username,
718 FieldError::Required,
719 );
720 } else if repo.user().exists(&username).await? {
721 form_state.add_error_on_field(
722 mas_templates::UpstreamRegisterFormField::Username,
723 FieldError::Exists,
724 );
725 } else if !homeserver
726 .is_localpart_available(&username)
727 .await
728 .map_err(RouteError::HomeserverConnection)?
729 {
730 tracing::warn!(
732 %username,
733 "Homeserver denied username provided by user"
734 );
735
736 homeserver_denied_username = true;
739 }
740
741 if site_config.tos_uri.is_some() && !accept_terms {
743 form_state.add_error_on_field(
744 mas_templates::UpstreamRegisterFormField::AcceptTerms,
745 FieldError::Required,
746 );
747 }
748
749 let res = policy
751 .evaluate_register(mas_policy::RegisterInput {
752 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
753 username: &username,
754 email: email.as_deref(),
755 requester: mas_policy::Requester {
756 ip_address: activity_tracker.ip(),
757 user_agent: user_agent.clone().map(|ua| ua.raw),
758 },
759 })
760 .await?;
761
762 for violation in res.violations {
763 match violation.field.as_deref() {
764 Some("username") => {
765 homeserver_denied_username = false;
769 form_state.add_error_on_field(
770 mas_templates::UpstreamRegisterFormField::Username,
771 FieldError::Policy {
772 code: violation.code.map(|c| c.as_str()),
773 message: violation.msg,
774 },
775 );
776 }
777 _ => form_state.add_error_on_form(FormError::Policy {
778 code: violation.code.map(|c| c.as_str()),
779 message: violation.msg,
780 }),
781 }
782 }
783
784 if homeserver_denied_username {
785 form_state.add_error_on_field(
787 mas_templates::UpstreamRegisterFormField::Username,
788 FieldError::Exists,
789 );
790 }
791
792 form_state
793 };
794
795 if !form_state.is_valid() {
796 let ctx = ctx
797 .with_form_state(form_state)
798 .with_csrf(csrf_token.form_value())
799 .with_language(locale);
800
801 return Ok((
802 cookie_jar,
803 Html(templates.render_upstream_oauth2_do_register(&ctx)?),
804 )
805 .into_response());
806 }
807
808 let user = repo.user().add(&mut rng, &clock, username).await?;
810
811 if let Some(terms_url) = &site_config.tos_uri {
812 repo.user_terms()
813 .accept_terms(&mut rng, &clock, &user, terms_url.clone())
814 .await?;
815 }
816
817 let mut job = ProvisionUserJob::new(&user);
819
820 if let Some(name) = display_name {
822 job = job.set_display_name(name);
823 }
824
825 repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
826
827 if let Some(email) = email {
829 repo.user_email()
830 .add(&mut rng, &clock, &user, email)
831 .await?;
832 }
833
834 repo.upstream_oauth_link()
835 .associate_to_user(&link, &user)
836 .await?;
837
838 repo.browser_session()
839 .add(&mut rng, &clock, &user, user_agent)
840 .await?
841 }
842
843 _ => return Err(RouteError::InvalidFormAction),
844 };
845
846 let upstream_session = repo
847 .upstream_oauth_session()
848 .consume(&clock, upstream_session)
849 .await?;
850
851 repo.browser_session()
852 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
853 .await?;
854
855 let cookie_jar = sessions_cookie
856 .consume_link(link_id)?
857 .save(cookie_jar, &clock);
858 let cookie_jar = cookie_jar.set_session(&session);
859
860 repo.save().await?;
861
862 Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
863}
864
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Registers a new account through a provider that forces both the localpart and the
    /// email import. Then checks that the user exists, the link points at it, and the email
    /// was recorded.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Force both imports, using the default templates.
        let forced = || UpstreamOAuthProviderImportPreference {
            action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
            template: None,
        };
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: forced(),
            email: forced(),
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        // A signed ID token carrying the claims read by the default templates.
        let claims = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });
        let signing_key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let signer = signing_key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let id_token = Jwt::sign_with_rng(
            &mut rng,
            JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256),
            claims,
            &signer,
        )
        .unwrap();

        // Persist a provider, a completed upstream session and a not-yet-associated link.
        let mut repo = state.repository().await.unwrap();
        let provider_params = UpstreamOAuthProviderParams {
            issuer: Some("https://example.com/".to_owned()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            token_endpoint_signing_alg: None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            client_id: "client".to_owned(),
            encrypted_client_secret: None,
            claims_imports,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            userinfo_endpoint_override: None,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        };
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &state.clock, provider_params)
            .await
            .unwrap();

        let pending_session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &state.clock, &provider, "subject".to_owned(), None)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                pending_session,
                &link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Seed the upstream sessions cookie with the session and its link.
        let jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(session.id, link.id)
            .unwrap();
        cookies.import(upstream_sessions.save(jar, &state.clock));

        let link_path = mas_router::UpstreamOAuth2Link::new(link.id).path();

        // Load the registration form and scrape its CSRF token.
        let response = state
            .request(cookies.with_cookies(Request::get(&*link_path).empty()))
            .await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        let csrf_token = response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .and_then(|rest| rest.split('"').next())
            .unwrap();

        // Submit the form, importing the email and accepting the terms.
        let form = serde_json::json!({
            "csrf": csrf_token,
            "action": "register",
            "import_email": "on",
            "accept_terms": "on",
        });
        let response = state
            .request(cookies.with_cookies(Request::post(&*link_path).form(form)))
            .await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // Inspect what the registration created.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        let stored_link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");
        assert_eq!(stored_link.user_id, Some(user.id));

        let page = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = page.edges.first().expect("email exists");
        assert_eq!(email.email, "john@example.com");
    }
}