1use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::record_error;
13use mas_data_model::{BoxClock, BoxRng};
14use mas_iana::oauth::OAuthClientAuthenticationMethod;
15use mas_keystore::Encrypter;
16use mas_policy::{EvaluationResult, Policy};
17use mas_storage::{BoxRepository, oauth2::OAuth2ClientRepository};
18use oauth2_types::{
19 errors::{ClientError, ClientErrorCode},
20 registration::{
21 ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
22 VerifiedClientMetadata,
23 },
24};
25use opentelemetry::{Key, KeyValue, metrics::Counter};
26use psl::Psl;
27use rand::distributions::{Alphanumeric, DistString};
28use serde::Serialize;
29use sha2::Digest as _;
30use thiserror::Error;
31use tracing::info;
32use url::Url;
33
34use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
35
36static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
37 METER
38 .u64_counter("mas.oauth2.registration_request")
39 .with_description("Number of OAuth2 registration requests")
40 .with_unit("{request}")
41 .build()
42});
43const RESULT: Key = Key::from_static_str("result");
44
45#[derive(Debug, Error)]
46pub(crate) enum RouteError {
47 #[error(transparent)]
48 Internal(Box<dyn std::error::Error + Send + Sync>),
49
50 #[error(transparent)]
51 JsonExtract(#[from] axum::extract::rejection::JsonRejection),
52
53 #[error("invalid client metadata")]
54 InvalidClientMetadata(#[from] ClientMetadataVerificationError),
55
56 #[error("{0} is a public suffix, not a valid domain")]
57 UrlIsPublicSuffix(&'static str),
58
59 #[error("client registration denied by the policy: {0}")]
60 PolicyDenied(EvaluationResult),
61}
62
63impl_from_error_for_route!(mas_storage::RepositoryError);
64impl_from_error_for_route!(mas_policy::LoadError);
65impl_from_error_for_route!(mas_policy::EvaluationError);
66impl_from_error_for_route!(mas_keystore::aead::Error);
67impl_from_error_for_route!(serde_json::Error);
68
69impl IntoResponse for RouteError {
70 fn into_response(self) -> axum::response::Response {
71 let sentry_event_id = record_error!(self, Self::Internal(_));
72
73 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
74
75 let response = match self {
76 Self::Internal(_) => (
77 StatusCode::INTERNAL_SERVER_ERROR,
78 Json(ClientError::from(ClientErrorCode::ServerError)),
79 )
80 .into_response(),
81
82 Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
86 StatusCode::BAD_REQUEST,
87 Json(
88 ClientError::from(ClientErrorCode::InvalidClientMetadata)
89 .with_description(e.to_string()),
90 ),
91 )
92 .into_response(),
93
94 Self::JsonExtract(_) => (
97 StatusCode::BAD_REQUEST,
98 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
99 )
100 .into_response(),
101
102 Self::InvalidClientMetadata(
106 ClientMetadataVerificationError::MissingRedirectUris
107 | ClientMetadataVerificationError::RedirectUriWithFragment(_),
108 ) => (
109 StatusCode::BAD_REQUEST,
110 Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
111 )
112 .into_response(),
113
114 Self::InvalidClientMetadata(e) => (
115 StatusCode::BAD_REQUEST,
116 Json(
117 ClientError::from(ClientErrorCode::InvalidClientMetadata)
118 .with_description(e.to_string()),
119 ),
120 )
121 .into_response(),
122
123 Self::UrlIsPublicSuffix("redirect_uri") => (
127 StatusCode::BAD_REQUEST,
128 Json(
129 ClientError::from(ClientErrorCode::InvalidRedirectUri)
130 .with_description("redirect_uri is not using a valid domain".to_owned()),
131 ),
132 )
133 .into_response(),
134
135 Self::UrlIsPublicSuffix(field) => (
136 StatusCode::BAD_REQUEST,
137 Json(
138 ClientError::from(ClientErrorCode::InvalidClientMetadata)
139 .with_description(format!("{field} is not using a valid domain")),
140 ),
141 )
142 .into_response(),
143
144 Self::PolicyDenied(evaluation) => {
148 let code = if evaluation
150 .violations
151 .iter()
152 .any(|v| v.msg.contains("redirect_uri"))
153 {
154 ClientErrorCode::InvalidRedirectUri
155 } else {
156 ClientErrorCode::InvalidClientMetadata
157 };
158
159 let collected = &evaluation
160 .violations
161 .iter()
162 .map(|v| v.msg.clone())
163 .collect::<Vec<String>>();
164 let joined = collected.join("; ");
165
166 (
167 StatusCode::BAD_REQUEST,
168 Json(ClientError::from(code).with_description(joined)),
169 )
170 .into_response()
171 }
172 };
173
174 (sentry_event_id, response).into_response()
175 }
176}
177
178#[derive(Serialize)]
179struct RouteResponse {
180 #[serde(flatten)]
181 response: ClientRegistrationResponse,
182 #[serde(flatten)]
183 metadata: VerifiedClientMetadata,
184}
185
186fn host_is_public_suffix(url: &Url) -> bool {
188 let host = url.host_str().unwrap_or_default().as_bytes();
189 let Some(suffix) = psl::List.suffix(host) else {
190 return false;
193 };
194
195 if !suffix.is_known() {
196 return false;
198 }
199
200 if host.len() <= suffix.as_bytes().len() + 1 {
204 return true;
206 }
207
208 false
209}
210
211fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
213 url.iter().any(|(_lang, url)| host_is_public_suffix(url))
214}
215
216#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
217pub(crate) async fn post(
218 mut rng: BoxRng,
219 clock: BoxClock,
220 mut repo: BoxRepository,
221 mut policy: Policy,
222 activity_tracker: BoundActivityTracker,
223 user_agent: Option<TypedHeader<headers::UserAgent>>,
224 State(encrypter): State<Encrypter>,
225 body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
226) -> Result<impl IntoResponse, RouteError> {
227 let Json(body) = body?;
229
230 let body = body.sorted();
232
233 let body_json = serde_json::to_string(&body)?;
235
236 info!(body = body_json, "Client registration");
237
238 let user_agent = user_agent.map(|ua| ua.to_string());
239
240 let metadata = body.validate()?;
242
243 if let Some(client_uri) = &metadata.client_uri
246 && localised_url_has_public_suffix(client_uri)
247 {
248 return Err(RouteError::UrlIsPublicSuffix("client_uri"));
249 }
250
251 if let Some(logo_uri) = &metadata.logo_uri
252 && localised_url_has_public_suffix(logo_uri)
253 {
254 return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
255 }
256
257 if let Some(policy_uri) = &metadata.policy_uri
258 && localised_url_has_public_suffix(policy_uri)
259 {
260 return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
261 }
262
263 if let Some(tos_uri) = &metadata.tos_uri
264 && localised_url_has_public_suffix(tos_uri)
265 {
266 return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
267 }
268
269 if let Some(initiate_login_uri) = &metadata.initiate_login_uri
270 && host_is_public_suffix(initiate_login_uri)
271 {
272 return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
273 }
274
275 for redirect_uri in metadata.redirect_uris() {
276 if host_is_public_suffix(redirect_uri) {
277 return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
278 }
279 }
280
281 let res = policy
282 .evaluate_client_registration(mas_policy::ClientRegistrationInput {
283 client_metadata: &metadata,
284 requester: mas_policy::Requester {
285 ip_address: activity_tracker.ip(),
286 user_agent,
287 },
288 })
289 .await?;
290 if !res.valid() {
291 return Err(RouteError::PolicyDenied(res));
292 }
293
294 let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
295 Some(
296 OAuthClientAuthenticationMethod::ClientSecretJwt
297 | OAuthClientAuthenticationMethod::ClientSecretPost
298 | OAuthClientAuthenticationMethod::ClientSecretBasic,
299 ) => {
300 let client_secret = Alphanumeric.sample_string(&mut rng, 20);
302 let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
303 (Some(client_secret), Some(encrypted_client_secret))
304 }
305 _ => (None, None),
306 };
307
308 let (digest_hash, existing_client) = if client_secret.is_none() {
311 let hash = sha2::Sha256::digest(body_json);
318 let hash = hex::encode(hash);
319 let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
320 (Some(hash), client)
321 } else {
322 (None, None)
323 };
324
325 let client = if let Some(client) = existing_client {
326 tracing::info!(%client.id, "Reusing existing client");
327 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
328 client
329 } else {
330 let client = repo
331 .oauth2_client()
332 .add(
333 &mut rng,
334 &clock,
335 metadata.redirect_uris().to_vec(),
336 digest_hash,
337 encrypted_client_secret,
338 metadata.application_type.clone(),
339 metadata.grant_types().to_vec(),
341 metadata
342 .client_name
343 .clone()
344 .map(Localized::to_non_localized),
345 metadata.logo_uri.clone().map(Localized::to_non_localized),
346 metadata.client_uri.clone().map(Localized::to_non_localized),
347 metadata.policy_uri.clone().map(Localized::to_non_localized),
348 metadata.tos_uri.clone().map(Localized::to_non_localized),
349 metadata.jwks_uri.clone(),
350 metadata.jwks.clone(),
351 metadata.id_token_signed_response_alg.clone(),
353 metadata.userinfo_signed_response_alg.clone(),
354 metadata.token_endpoint_auth_method.clone(),
355 metadata.token_endpoint_auth_signing_alg.clone(),
356 metadata.initiate_login_uri.clone(),
357 )
358 .await?;
359 tracing::info!(%client.id, "Registered new client");
360 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
361 client
362 };
363
364 let response = ClientRegistrationResponse {
365 client_id: client.client_id.clone(),
366 client_secret,
367 client_id_issued_at: Some(client.id.datetime().into()),
369 client_secret_expires_at: None,
370 };
371
372 let metadata = client.into_metadata().validate()?;
375
376 repo.save().await?;
377
378 let response = RouteResponse { response, metadata };
379
380 Ok((StatusCode::CREATED, Json(response)))
381}
382
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Bare public suffixes are flagged; registrable domains, single-label hosts and
    /// custom-scheme URIs are not.
    #[test]
    fn test_public_suffix_list() {
        let is_public_suffix = |raw: &str| host_is_public_suffix(&Url::parse(raw).unwrap());

        let flagged = [
            "https://.com",
            "https://.com.",
            "https://co.uk",
            "https://github.io",
        ];
        let allowed = [
            "https://example.com",
            "https://example.com.",
            "https://x.com",
            "https://x.com.",
            "https://matrix-org.github.io",
            "http://localhost",
            "org.matrix:/callback",
            "http://somerandominternaldomain",
        ];

        for raw in flagged {
            assert!(is_public_suffix(raw), "expected {raw} to be a public suffix");
        }
        for raw in allowed {
            assert!(!is_public_suffix(raw), "expected {raw} not to be a public suffix");
        }
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let endpoint = mas_router::OAuth2RegistrationEndpoint::PATH;

        // A body that is not JSON at all.
        let req = Request::post(endpoint)
            .body("this is not a json".to_owned())
            .unwrap();
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidRequest);

        // JSON whose `client_uri` is not a URI.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "this is not a uri",
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidClientMetadata);

        // A `web` client with a plain-http redirect URI.
        let req = Request::post(endpoint).json(serde_json::json!({
            "application_type": "web",
            "client_uri": "https://example.com/",
            "redirect_uris": ["http://this-is-insecure.com/"],
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidRedirectUri);

        // Response type and grant type that don't go together.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "redirect_uris": ["https://example.com/"],
            "response_types": ["id_token"],
            "grant_types": ["authorization_code"],
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidClientMetadata);

        // A `client_uri` on a bare public suffix is reported before the redirect URI.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://github.io/",
            "redirect_uris": ["https://github.io/"],
            "response_types": ["code"],
            "grant_types": ["authorization_code"],
            "token_endpoint_auth_method": "client_secret_basic",
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            err.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // The same check applies to localized variants of `client_uri`.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "client_uri#fr-FR": "https://github.io/",
            "redirect_uris": ["https://example.com/"],
            "response_types": ["code"],
            "grant_types": ["authorization_code"],
            "token_endpoint_auth_method": "client_secret_basic",
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let err: ClientError = res.json();
        assert_eq!(err.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            err.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let endpoint = mas_router::OAuth2RegistrationEndpoint::PATH;

        // Public client: no secret is issued.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "redirect_uris": ["https://example.com/"],
            "response_types": ["code"],
            "grant_types": ["authorization_code"],
            "token_endpoint_auth_method": "none",
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert!(registration.client_secret.is_none());

        // Confidential client: a secret is issued.
        let req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "redirect_uris": ["https://example.com/"],
            "response_types": ["code"],
            "grant_types": ["authorization_code"],
            "token_endpoint_auth_method": "client_secret_basic",
        }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert!(registration.client_secret.is_some());
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let endpoint = mas_router::OAuth2RegistrationEndpoint::PATH;

        let public_req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "client_name": "Example",
            "client_name#en": "Example",
            "client_name#fr": "Exemple",
            "client_name#de": "Beispiel",
            "redirect_uris": ["https://example.com/", "https://example.com/callback"],
            "response_types": ["code"],
            "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
            "token_endpoint_auth_method": "none",
        }));

        let res = state.request(public_req.clone()).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        let public_client_id = registration.client_id;

        // Registering the exact same public client again yields the same client.
        let res = state.request(public_req).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert_eq!(registration.client_id, public_client_id);

        // Reordered keys and array elements are still recognized as the same client.
        let reordered_req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "client_name": "Example",
            "client_name#de": "Beispiel",
            "client_name#fr": "Exemple",
            "client_name#en": "Example",
            "redirect_uris": ["https://example.com/callback", "https://example.com/"],
            "response_types": ["code"],
            "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
            "token_endpoint_auth_method": "none",
        }));
        let res = state.request(reordered_req).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert_eq!(registration.client_id, public_client_id);

        // Confidential clients are never deduplicated.
        let confidential_req = Request::post(endpoint).json(serde_json::json!({
            "client_uri": "https://example.com/",
            "redirect_uris": ["https://example.com/"],
            "response_types": ["code"],
            "grant_types": ["authorization_code"],
            "token_endpoint_auth_method": "client_secret_basic",
        }));

        let res = state.request(confidential_req.clone()).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert_ne!(registration.client_id, public_client_id);
        let confidential_client_id = registration.client_id;

        let res = state.request(confidential_req).await;
        res.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = res.json();
        assert_ne!(registration.client_id, confidential_client_id);
    }
}