1use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::record_error;
13use mas_data_model::{BoxClock, BoxRng, UlidExt as _};
14use mas_iana::oauth::OAuthClientAuthenticationMethod;
15use mas_keystore::Encrypter;
16use mas_policy::{EvaluationResult, Policy};
17use mas_storage::{BoxRepository, oauth2::OAuth2ClientRepository};
18use oauth2_types::{
19 errors::{ClientError, ClientErrorCode},
20 registration::{
21 ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
22 VerifiedClientMetadata,
23 },
24 requests::GrantType,
25};
26use opentelemetry::{Key, KeyValue, metrics::Counter};
27use psl::Psl;
28use rand::distributions::{Alphanumeric, DistString};
29use serde::Serialize;
30use sha2::Digest as _;
31use thiserror::Error;
32use tracing::info;
33use url::Url;
34
35use crate::{BoundActivityTracker, METER, SiteConfig, impl_from_error_for_route};
36
/// Counter of the client registration requests handled by this module.
///
/// Data points are labelled with the [`RESULT`] attribute; the values recorded
/// in this file are `"denied"`, `"reused"` and `"created"`.
static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.oauth2.registration_request")
        .with_description("Number of OAuth2 registration requests")
        .with_unit("{request}")
        .build()
});

/// Attribute key used to label [`REGISTRATION_COUNTER`] data points with the
/// outcome of the request.
const RESULT: Key = Key::from_static_str("result");
45
/// Errors which can happen while handling a client registration request.
///
/// The [`IntoResponse`] implementation in this file maps each variant to an
/// HTTP status code and an OAuth 2.0 error document.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// An unexpected failure; answered with a `500` and a generic
    /// `server_error` document.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON client metadata.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The client metadata failed verification.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// The host of a URL in the metadata is a public suffix. The payload is
    /// the name of the offending metadata field.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The policy rejected the registration. The payload is the evaluation
    /// result, whose violations are reported back to the client.
    #[error("client registration denied by the policy: {0}")]
    PolicyDenied(EvaluationResult),
}
63
// Allow `?` to be used on these error types inside the handler.
// NOTE(review): `impl_from_error_for_route!` is defined elsewhere in the
// crate; presumably it generates `From<E> for RouteError` by boxing the error
// into `RouteError::Internal` — confirm against the macro definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
69
70impl IntoResponse for RouteError {
71 fn into_response(self) -> axum::response::Response {
72 let sentry_event_id = record_error!(self, Self::Internal(_));
73
74 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
75
76 let response = match self {
77 Self::Internal(_) => (
78 StatusCode::INTERNAL_SERVER_ERROR,
79 Json(ClientError::from(ClientErrorCode::ServerError)),
80 )
81 .into_response(),
82
83 Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
87 StatusCode::BAD_REQUEST,
88 Json(
89 ClientError::from(ClientErrorCode::InvalidClientMetadata)
90 .with_description(e.to_string()),
91 ),
92 )
93 .into_response(),
94
95 Self::JsonExtract(_) => (
98 StatusCode::BAD_REQUEST,
99 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
100 )
101 .into_response(),
102
103 Self::InvalidClientMetadata(
107 ClientMetadataVerificationError::MissingRedirectUris
108 | ClientMetadataVerificationError::RedirectUriWithFragment(_),
109 ) => (
110 StatusCode::BAD_REQUEST,
111 Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
112 )
113 .into_response(),
114
115 Self::InvalidClientMetadata(e) => (
116 StatusCode::BAD_REQUEST,
117 Json(
118 ClientError::from(ClientErrorCode::InvalidClientMetadata)
119 .with_description(e.to_string()),
120 ),
121 )
122 .into_response(),
123
124 Self::UrlIsPublicSuffix("redirect_uri") => (
128 StatusCode::BAD_REQUEST,
129 Json(
130 ClientError::from(ClientErrorCode::InvalidRedirectUri)
131 .with_description("redirect_uri is not using a valid domain".to_owned()),
132 ),
133 )
134 .into_response(),
135
136 Self::UrlIsPublicSuffix(field) => (
137 StatusCode::BAD_REQUEST,
138 Json(
139 ClientError::from(ClientErrorCode::InvalidClientMetadata)
140 .with_description(format!("{field} is not using a valid domain")),
141 ),
142 )
143 .into_response(),
144
145 Self::PolicyDenied(evaluation) => {
149 let code = if evaluation
151 .violations
152 .iter()
153 .any(|v| v.msg.contains("redirect_uri"))
154 {
155 ClientErrorCode::InvalidRedirectUri
156 } else {
157 ClientErrorCode::InvalidClientMetadata
158 };
159
160 let collected = &evaluation
161 .violations
162 .iter()
163 .map(|v| v.msg.clone())
164 .collect::<Vec<String>>();
165 let joined = collected.join("; ");
166
167 (
168 StatusCode::BAD_REQUEST,
169 Json(ClientError::from(code).with_description(joined)),
170 )
171 .into_response()
172 }
173 };
174
175 (sentry_event_id, response).into_response()
176 }
177}
178
/// Body of a successful registration response.
///
/// Both fields are flattened, so the registration response fields and the
/// client metadata are serialised side by side in a single JSON object.
#[derive(Serialize)]
struct RouteResponse {
    /// Fields issued by the registration itself (client ID, secret, ...).
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    /// The verified metadata of the registered client.
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
186
187fn host_is_public_suffix(url: &Url) -> bool {
189 let host = url.host_str().unwrap_or_default().as_bytes();
190 let Some(suffix) = psl::List.suffix(host) else {
191 return false;
194 };
195
196 if !suffix.is_known() {
197 return false;
199 }
200
201 if host.len() <= suffix.as_bytes().len() + 1 {
205 return true;
207 }
208
209 false
210}
211
212fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
214 url.iter().any(|(_lang, url)| host_is_public_suffix(url))
215}
216
217#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
218pub(crate) async fn post(
219 mut rng: BoxRng,
220 clock: BoxClock,
221 mut repo: BoxRepository,
222 mut policy: Policy,
223 activity_tracker: BoundActivityTracker,
224 user_agent: Option<TypedHeader<headers::UserAgent>>,
225 State(encrypter): State<Encrypter>,
226 State(site_config): State<SiteConfig>,
227 body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
228) -> Result<impl IntoResponse, RouteError> {
229 let Json(body) = body?;
231
232 let mut body = body.sorted();
234
235 let body_json = serde_json::to_string(&body)?;
237
238 info!(body = body_json, "Client registration");
239
240 if !site_config.device_code_grant_enabled
242 && let Some(grant_types) = &mut body.grant_types
243 && grant_types.contains(&GrantType::DeviceCode)
244 {
245 tracing::warn!(
246 "A client requested the device_code grant type but it's disabled, dropping from the grant types"
247 );
248 grant_types.retain(|t| t != &GrantType::DeviceCode);
249 }
250
251 let user_agent = user_agent.map(|ua| ua.to_string());
252
253 let metadata = body.validate()?;
255
256 if let Some(client_uri) = &metadata.client_uri
259 && localised_url_has_public_suffix(client_uri)
260 {
261 return Err(RouteError::UrlIsPublicSuffix("client_uri"));
262 }
263
264 if let Some(logo_uri) = &metadata.logo_uri
265 && localised_url_has_public_suffix(logo_uri)
266 {
267 return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
268 }
269
270 if let Some(policy_uri) = &metadata.policy_uri
271 && localised_url_has_public_suffix(policy_uri)
272 {
273 return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
274 }
275
276 if let Some(tos_uri) = &metadata.tos_uri
277 && localised_url_has_public_suffix(tos_uri)
278 {
279 return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
280 }
281
282 if let Some(initiate_login_uri) = &metadata.initiate_login_uri
283 && host_is_public_suffix(initiate_login_uri)
284 {
285 return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
286 }
287
288 for redirect_uri in metadata.redirect_uris() {
289 if host_is_public_suffix(redirect_uri) {
290 return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
291 }
292 }
293
294 let res = policy
295 .evaluate_client_registration(mas_policy::ClientRegistrationInput {
296 client_metadata: &metadata,
297 requester: mas_policy::Requester {
298 ip_address: activity_tracker.ip(),
299 user_agent,
300 },
301 })
302 .await?;
303 if !res.valid() {
304 return Err(RouteError::PolicyDenied(res));
305 }
306
307 let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
308 Some(
309 OAuthClientAuthenticationMethod::ClientSecretJwt
310 | OAuthClientAuthenticationMethod::ClientSecretPost
311 | OAuthClientAuthenticationMethod::ClientSecretBasic,
312 ) => {
313 let client_secret = Alphanumeric.sample_string(&mut rng, 20);
315 let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
316 (Some(client_secret), Some(encrypted_client_secret))
317 }
318 _ => (None, None),
319 };
320
321 let (digest_hash, existing_client) = if client_secret.is_none() {
324 let hash = sha2::Sha256::digest(body_json);
331 let hash = hex::encode(hash);
332 let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
333 (Some(hash), client)
334 } else {
335 (None, None)
336 };
337
338 let client = if let Some(client) = existing_client {
339 tracing::info!(%client.id, "Reusing existing client");
340 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
341 client
342 } else {
343 let client = repo
344 .oauth2_client()
345 .add(
346 &mut rng,
347 &clock,
348 metadata.redirect_uris().to_vec(),
349 digest_hash,
350 encrypted_client_secret,
351 metadata.application_type.clone(),
352 metadata.grant_types().to_vec(),
354 metadata
355 .client_name
356 .clone()
357 .map(Localized::to_non_localized),
358 metadata.logo_uri.clone().map(Localized::to_non_localized),
359 metadata.client_uri.clone().map(Localized::to_non_localized),
360 metadata.policy_uri.clone().map(Localized::to_non_localized),
361 metadata.tos_uri.clone().map(Localized::to_non_localized),
362 metadata.jwks_uri.clone(),
363 metadata.jwks.clone(),
364 metadata.id_token_signed_response_alg.clone(),
366 metadata.userinfo_signed_response_alg.clone(),
367 metadata.token_endpoint_auth_method.clone(),
368 metadata.token_endpoint_auth_signing_alg.clone(),
369 metadata.initiate_login_uri.clone(),
370 )
371 .await?;
372 tracing::info!(%client.id, "Registered new client");
373 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
374 client
375 };
376
377 let response = ClientRegistrationResponse {
378 client_id: client.client_id.clone(),
379 client_secret,
380 client_id_issued_at: Some(client.id.datetime_utc()),
382 client_secret_expires_at: None,
383 };
384
385 let metadata = client.into_metadata().validate()?;
388
389 repo.save().await?;
390
391 let response = RouteResponse { response, metadata };
392
393 Ok((StatusCode::CREATED, Json(response)))
394}
395
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use insta::assert_json_snapshot;
    use mas_data_model::SiteConfig;
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup, test_site_config},
    };

    /// Checks `host_is_public_suffix` against bare suffixes, regular domains,
    /// hosts with a trailing dot, single-label hosts and a custom scheme.
    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        // Hosts which are only a public suffix
        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));
        // Hosts with at least one label in front of the suffix
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));
        // Single-label hosts and a custom scheme are not flagged
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    /// Sends a series of invalid registration requests and checks the error
    /// code (and, for public suffixes, the description) of each response.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // The body is not JSON at all
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // The body is JSON, but `client_uri` is not a URI
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // A `web` application with a plain `http` redirect URI is rejected
        // with the redirect URI error code
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // This combination of response types and grant types is rejected as
        // invalid metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // The `client_uri` is hosted on a public suffix; it is checked before
        // the redirect URIs, so it is the field named in the description
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Only the localised variant of `client_uri` is hosted on a public
        // suffix; it must be rejected as well
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    /// Registers two valid clients and checks that a client secret is issued
    /// only for the `client_secret_basic` authentication method.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // With the `none` authentication method, no secret is returned
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        // With `client_secret_basic`, a secret is returned
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    /// Checks that identical registrations without a client secret resolve to
    /// the same client, whereas registrations with a secret never do.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A client without a secret, with localised names and several
        // redirect URIs and grant types
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        // Sending the exact same request again yields the same client ID
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Same metadata, with the localised names, redirect URIs and grant
        // types listed in a different order: still the same client ID
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // A client with a secret gets a new client ID...
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        // ...and sending the same request again yields yet another client ID
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }

    /// Checks that, when the device code grant is disabled in the site
    /// configuration, it is dropped from the grant types of the registered
    /// client instead of failing the registration.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_device_code_grant_disabled(pool: PgPool) {
        setup();
        let state = TestState::from_pool_with_site_config(
            pool,
            SiteConfig {
                device_code_grant_enabled: false,
                ..test_site_config()
            },
        )
        .await
        .unwrap();

        // The only requested grant type is the disabled one
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "token_endpoint_auth_method": "none",
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
                "response_types": [],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let client: serde_json::Value = response.json();
        // The registration succeeds, with an empty list of grant types
        assert_json_snapshot!(client, @r#"
        {
          "client_id": "01FSHN9AG09FE39KETP6F390F8",
          "client_id_issued_at": 1642344000,
          "redirect_uris": [],
          "grant_types": [],
          "token_endpoint_auth_method": "none",
          "client_uri": "https://example.com/"
        }
        "#);
    }
}