1use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::sentry::SentryEventID;
13use mas_iana::oauth::OAuthClientAuthenticationMethod;
14use mas_keystore::Encrypter;
15use mas_policy::{Policy, Violation};
16use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
17use oauth2_types::{
18 errors::{ClientError, ClientErrorCode},
19 registration::{
20 ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
21 VerifiedClientMetadata,
22 },
23};
24use opentelemetry::{Key, KeyValue, metrics::Counter};
25use psl::Psl;
26use rand::distributions::{Alphanumeric, DistString};
27use serde::Serialize;
28use sha2::Digest as _;
29use thiserror::Error;
30use tracing::info;
31use url::Url;
32
33use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
34
35static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
36 METER
37 .u64_counter("mas.oauth2.registration_request")
38 .with_description("Number of OAuth2 registration requests")
39 .with_unit("{request}")
40 .build()
41});
42const RESULT: Key = Key::from_static_str("result");
43
/// Failures that can happen while handling a client registration request.
///
/// The `IntoResponse` implementation for this type decides which HTTP status
/// and OAuth 2.0 error code each variant is reported with.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// An unexpected failure, carried as an opaque boxed error.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body was rejected by the JSON extractor.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The submitted client metadata did not pass verification.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// A URL in the metadata has a bare public suffix as its host. The
    /// payload is the name of the offending metadata field.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The policy evaluation rejected the registration; the payload holds
    /// the violations it reported.
    #[error("denied by the policy: {0:?}")]
    PolicyDenied(Vec<Violation>),
}
61
// Let the handler use `?` on these error types even though `RouteError` has
// no `#[from]` variant for them.
// NOTE(review): the macro is defined elsewhere in the crate; presumably it
// generates a `From<E> for RouteError` that boxes the error into
// `RouteError::Internal` — confirm against the macro definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
67
68impl IntoResponse for RouteError {
69 fn into_response(self) -> axum::response::Response {
70 let event_id = sentry::capture_error(&self);
71
72 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
73
74 let response = match self {
75 Self::Internal(_) => (
76 StatusCode::INTERNAL_SERVER_ERROR,
77 Json(ClientError::from(ClientErrorCode::ServerError)),
78 )
79 .into_response(),
80
81 Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
85 StatusCode::BAD_REQUEST,
86 Json(
87 ClientError::from(ClientErrorCode::InvalidClientMetadata)
88 .with_description(e.to_string()),
89 ),
90 )
91 .into_response(),
92
93 Self::JsonExtract(_) => (
96 StatusCode::BAD_REQUEST,
97 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
98 )
99 .into_response(),
100
101 Self::InvalidClientMetadata(
105 ClientMetadataVerificationError::MissingRedirectUris
106 | ClientMetadataVerificationError::RedirectUriWithFragment(_),
107 ) => (
108 StatusCode::BAD_REQUEST,
109 Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
110 )
111 .into_response(),
112
113 Self::InvalidClientMetadata(e) => (
114 StatusCode::BAD_REQUEST,
115 Json(
116 ClientError::from(ClientErrorCode::InvalidClientMetadata)
117 .with_description(e.to_string()),
118 ),
119 )
120 .into_response(),
121
122 Self::UrlIsPublicSuffix("redirect_uri") => (
126 StatusCode::BAD_REQUEST,
127 Json(
128 ClientError::from(ClientErrorCode::InvalidRedirectUri)
129 .with_description("redirect_uri is not using a valid domain".to_owned()),
130 ),
131 )
132 .into_response(),
133
134 Self::UrlIsPublicSuffix(field) => (
135 StatusCode::BAD_REQUEST,
136 Json(
137 ClientError::from(ClientErrorCode::InvalidClientMetadata)
138 .with_description(format!("{field} is not using a valid domain")),
139 ),
140 )
141 .into_response(),
142
143 Self::PolicyDenied(violations) => {
147 let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
149 ClientErrorCode::InvalidRedirectUri
150 } else {
151 ClientErrorCode::InvalidClientMetadata
152 };
153
154 let collected = &violations
155 .iter()
156 .map(|v| v.msg.clone())
157 .collect::<Vec<String>>();
158 let joined = collected.join("; ");
159
160 (
161 StatusCode::BAD_REQUEST,
162 Json(ClientError::from(code).with_description(joined)),
163 )
164 .into_response()
165 }
166 };
167
168 (SentryEventID::from(event_id), response).into_response()
169 }
170}
171
/// Body returned on a successful registration.
///
/// Both fields are flattened, so they serialize into a single JSON object
/// rather than two nested ones.
#[derive(Serialize)]
struct RouteResponse {
    /// Registration-specific values: client ID, secret and timestamps.
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    /// The verified metadata of the registered client.
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
179
180fn host_is_public_suffix(url: &Url) -> bool {
182 let host = url.host_str().unwrap_or_default().as_bytes();
183 let Some(suffix) = psl::List.suffix(host) else {
184 return false;
187 };
188
189 if !suffix.is_known() {
190 return false;
192 }
193
194 if host.len() <= suffix.as_bytes().len() + 1 {
198 return true;
200 }
201
202 false
203}
204
205fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
207 url.iter().any(|(_lang, url)| host_is_public_suffix(url))
208}
209
210#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
211pub(crate) async fn post(
212 mut rng: BoxRng,
213 clock: BoxClock,
214 mut repo: BoxRepository,
215 mut policy: Policy,
216 activity_tracker: BoundActivityTracker,
217 user_agent: Option<TypedHeader<headers::UserAgent>>,
218 State(encrypter): State<Encrypter>,
219 body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
220) -> Result<impl IntoResponse, RouteError> {
221 let Json(body) = body?;
223
224 let body = body.sorted();
226
227 let body_json = serde_json::to_string(&body)?;
229
230 info!(body = body_json, "Client registration");
231
232 let user_agent = user_agent.map(|ua| ua.to_string());
233
234 let metadata = body.validate()?;
236
237 if let Some(client_uri) = &metadata.client_uri {
240 if localised_url_has_public_suffix(client_uri) {
241 return Err(RouteError::UrlIsPublicSuffix("client_uri"));
242 }
243 }
244
245 if let Some(logo_uri) = &metadata.logo_uri {
246 if localised_url_has_public_suffix(logo_uri) {
247 return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
248 }
249 }
250
251 if let Some(policy_uri) = &metadata.policy_uri {
252 if localised_url_has_public_suffix(policy_uri) {
253 return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
254 }
255 }
256
257 if let Some(tos_uri) = &metadata.tos_uri {
258 if localised_url_has_public_suffix(tos_uri) {
259 return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
260 }
261 }
262
263 if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
264 if host_is_public_suffix(initiate_login_uri) {
265 return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
266 }
267 }
268
269 for redirect_uri in metadata.redirect_uris() {
270 if host_is_public_suffix(redirect_uri) {
271 return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
272 }
273 }
274
275 let res = policy
276 .evaluate_client_registration(mas_policy::ClientRegistrationInput {
277 client_metadata: &metadata,
278 requester: mas_policy::Requester {
279 ip_address: activity_tracker.ip(),
280 user_agent,
281 },
282 })
283 .await?;
284 if !res.valid() {
285 return Err(RouteError::PolicyDenied(res.violations));
286 }
287
288 let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
289 Some(
290 OAuthClientAuthenticationMethod::ClientSecretJwt
291 | OAuthClientAuthenticationMethod::ClientSecretPost
292 | OAuthClientAuthenticationMethod::ClientSecretBasic,
293 ) => {
294 let client_secret = Alphanumeric.sample_string(&mut rng, 20);
296 let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
297 (Some(client_secret), Some(encrypted_client_secret))
298 }
299 _ => (None, None),
300 };
301
302 let (digest_hash, existing_client) = if client_secret.is_none() {
305 let hash = sha2::Sha256::digest(body_json);
312 let hash = hex::encode(hash);
313 let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
314 (Some(hash), client)
315 } else {
316 (None, None)
317 };
318
319 let client = if let Some(client) = existing_client {
320 tracing::info!(%client.id, "Reusing existing client");
321 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
322 client
323 } else {
324 let client = repo
325 .oauth2_client()
326 .add(
327 &mut rng,
328 &clock,
329 metadata.redirect_uris().to_vec(),
330 digest_hash,
331 encrypted_client_secret,
332 metadata.application_type.clone(),
333 metadata.grant_types().to_vec(),
335 metadata
336 .client_name
337 .clone()
338 .map(Localized::to_non_localized),
339 metadata.logo_uri.clone().map(Localized::to_non_localized),
340 metadata.client_uri.clone().map(Localized::to_non_localized),
341 metadata.policy_uri.clone().map(Localized::to_non_localized),
342 metadata.tos_uri.clone().map(Localized::to_non_localized),
343 metadata.jwks_uri.clone(),
344 metadata.jwks.clone(),
345 metadata.id_token_signed_response_alg.clone(),
347 metadata.userinfo_signed_response_alg.clone(),
348 metadata.token_endpoint_auth_method.clone(),
349 metadata.token_endpoint_auth_signing_alg.clone(),
350 metadata.initiate_login_uri.clone(),
351 )
352 .await?;
353 tracing::info!(%client.id, "Registered new client");
354 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
355 client
356 };
357
358 let response = ClientRegistrationResponse {
359 client_id: client.client_id.clone(),
360 client_secret,
361 client_id_issued_at: Some(client.id.datetime().into()),
363 client_secret_expires_at: None,
364 };
365
366 let metadata = client.into_metadata().validate()?;
369
370 repo.save().await?;
371
372 let response = RouteResponse { response, metadata };
373
374 Ok((StatusCode::CREATED, Json(response)))
375}
376
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Checks the public suffix detection against hosts that are bare
    /// suffixes and hosts that are not.
    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        // Hosts made of nothing but a public suffix
        let bare_suffixes = [
            "https://.com",
            "https://.com.",
            "https://co.uk",
            "https://github.io",
        ];
        for url in bare_suffixes {
            assert!(url_is_public_suffix(url), "{url}");
        }

        // Regular domains, unknown suffixes and host-less URLs
        let acceptable = [
            "https://example.com",
            "https://example.com.",
            "https://x.com",
            "https://x.com.",
            "https://matrix-org.github.io",
            "http://localhost",
            "org.matrix:/callback",
            "http://somerandominternaldomain",
        ];
        for url in acceptable {
            assert!(!url_is_public_suffix(url), "{url}");
        }
    }

    /// Sends a series of invalid registrations and checks the error code
    /// (and, for public suffixes, the description) of each response.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // The body is not JSON at all
        let req = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidRequest);

        // The metadata holds a value which is not a URI
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidClientMetadata);

        // A web application with a plain HTTP redirect URI
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidRedirectUri);

        // The `id_token` response type with only the authorization code grant
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidClientMetadata);

        // URLs hosted directly on a public suffix
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            error.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // A public suffix that only shows up in a translated URL
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = res.json();
        assert_eq!(error.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            error.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    /// Checks that a client secret is only returned when the requested
    /// authentication method is based on one.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // No client authentication: no secret expected
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert!(registered.client_secret.is_none());

        // Secret-based client authentication: a secret is expected
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert!(registered.client_secret.is_some());
    }

    /// Checks that identical registrations without a secret resolve to the
    /// same client, and that registrations with a secret never do.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Registering the very same metadata twice yields the same client ID
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let res = state.request(req.clone()).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        let client_id = registered.client_id;

        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert_eq!(registered.client_id, client_id);

        // Same metadata with properties and list items in another order:
        // still the same client
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert_eq!(registered.client_id, client_id);

        // A client with a secret gets a fresh registration every time
        let req =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let res = state.request(req.clone()).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert_ne!(registered.client_id, client_id);
        let client_id = registered.client_id;

        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let registered: ClientRegistrationResponse = res.json();
        assert_ne!(registered.client_id, client_id);
    }
}