mas_handlers/oauth2/
registration.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::record_error;
13use mas_data_model::{BoxClock, BoxRng};
14use mas_iana::oauth::OAuthClientAuthenticationMethod;
15use mas_keystore::Encrypter;
16use mas_policy::{EvaluationResult, Policy};
17use mas_storage::{BoxRepository, oauth2::OAuth2ClientRepository};
18use oauth2_types::{
19    errors::{ClientError, ClientErrorCode},
20    registration::{
21        ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
22        VerifiedClientMetadata,
23    },
24};
25use opentelemetry::{Key, KeyValue, metrics::Counter};
26use psl::Psl;
27use rand::distributions::{Alphanumeric, DistString};
28use serde::Serialize;
29use sha2::Digest as _;
30use thiserror::Error;
31use tracing::info;
32use url::Url;
33
34use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
35
36static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
37    METER
38        .u64_counter("mas.oauth2.registration_request")
39        .with_description("Number of OAuth2 registration requests")
40        .with_unit("{request}")
41        .build()
42});
43const RESULT: Key = Key::from_static_str("result");
44
45#[derive(Debug, Error)]
46pub(crate) enum RouteError {
47    #[error(transparent)]
48    Internal(Box<dyn std::error::Error + Send + Sync>),
49
50    #[error(transparent)]
51    JsonExtract(#[from] axum::extract::rejection::JsonRejection),
52
53    #[error("invalid client metadata")]
54    InvalidClientMetadata(#[from] ClientMetadataVerificationError),
55
56    #[error("{0} is a public suffix, not a valid domain")]
57    UrlIsPublicSuffix(&'static str),
58
59    #[error("client registration denied by the policy: {0}")]
60    PolicyDenied(EvaluationResult),
61}
62
63impl_from_error_for_route!(mas_storage::RepositoryError);
64impl_from_error_for_route!(mas_policy::LoadError);
65impl_from_error_for_route!(mas_policy::EvaluationError);
66impl_from_error_for_route!(mas_keystore::aead::Error);
67impl_from_error_for_route!(serde_json::Error);
68
69impl IntoResponse for RouteError {
70    fn into_response(self) -> axum::response::Response {
71        let sentry_event_id = record_error!(self, Self::Internal(_));
72
73        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
74
75        let response = match self {
76            Self::Internal(_) => (
77                StatusCode::INTERNAL_SERVER_ERROR,
78                Json(ClientError::from(ClientErrorCode::ServerError)),
79            )
80                .into_response(),
81
82            // This error happens if we managed to parse the incomiong JSON but it can't be
83            // deserialized to the expected type. In this case we return an
84            // `invalid_client_metadata` error with the details of the error.
85            Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
86                StatusCode::BAD_REQUEST,
87                Json(
88                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
89                        .with_description(e.to_string()),
90                ),
91            )
92                .into_response(),
93
94            // For all other JSON errors we return a `invalid_request` error, since this is
95            // probably due to a malformed request.
96            Self::JsonExtract(_) => (
97                StatusCode::BAD_REQUEST,
98                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
99            )
100                .into_response(),
101
102            // This error comes from the `ClientMetadata::validate` method. We return an
103            // `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
104            // return an `invalid_client_metadata` error.
105            Self::InvalidClientMetadata(
106                ClientMetadataVerificationError::MissingRedirectUris
107                | ClientMetadataVerificationError::RedirectUriWithFragment(_),
108            ) => (
109                StatusCode::BAD_REQUEST,
110                Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
111            )
112                .into_response(),
113
114            Self::InvalidClientMetadata(e) => (
115                StatusCode::BAD_REQUEST,
116                Json(
117                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
118                        .with_description(e.to_string()),
119                ),
120            )
121                .into_response(),
122
123            // This error happens if the any of the client's URIs are public suffixes. We return
124            // an `invalid_redirect_uri` error if it's a `redirect_uri`, else we return an
125            // `invalid_client_metadata` error.
126            Self::UrlIsPublicSuffix("redirect_uri") => (
127                StatusCode::BAD_REQUEST,
128                Json(
129                    ClientError::from(ClientErrorCode::InvalidRedirectUri)
130                        .with_description("redirect_uri is not using a valid domain".to_owned()),
131                ),
132            )
133                .into_response(),
134
135            Self::UrlIsPublicSuffix(field) => (
136                StatusCode::BAD_REQUEST,
137                Json(
138                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
139                        .with_description(format!("{field} is not using a valid domain")),
140                ),
141            )
142                .into_response(),
143
144            // For policy violations, we return an `invalid_client_metadata` error with the details
145            // of the violations in most cases. If a violation includes `redirect_uri` in the
146            // message, we return an `invalid_redirect_uri` error instead.
147            Self::PolicyDenied(evaluation) => {
148                // TODO: detect them better
149                let code = if evaluation
150                    .violations
151                    .iter()
152                    .any(|v| v.msg.contains("redirect_uri"))
153                {
154                    ClientErrorCode::InvalidRedirectUri
155                } else {
156                    ClientErrorCode::InvalidClientMetadata
157                };
158
159                let collected = &evaluation
160                    .violations
161                    .iter()
162                    .map(|v| v.msg.clone())
163                    .collect::<Vec<String>>();
164                let joined = collected.join("; ");
165
166                (
167                    StatusCode::BAD_REQUEST,
168                    Json(ClientError::from(code).with_description(joined)),
169                )
170                    .into_response()
171            }
172        };
173
174        (sentry_event_id, response).into_response()
175    }
176}
177
178#[derive(Serialize)]
179struct RouteResponse {
180    #[serde(flatten)]
181    response: ClientRegistrationResponse,
182    #[serde(flatten)]
183    metadata: VerifiedClientMetadata,
184}
185
186/// Check if the host of the given URL is a public suffix
187fn host_is_public_suffix(url: &Url) -> bool {
188    let host = url.host_str().unwrap_or_default().as_bytes();
189    let Some(suffix) = psl::List.suffix(host) else {
190        // There is no suffix, which is the case for empty hosts, like with custom
191        // schemes
192        return false;
193    };
194
195    if !suffix.is_known() {
196        // The suffix is not known, so it's not a public suffix
197        return false;
198    }
199
200    // We want to cover two cases:
201    // - The host is the suffix itself, like `com`
202    // - The host is a dot followed by the suffix, like `.com`
203    if host.len() <= suffix.as_bytes().len() + 1 {
204        // The host only has the suffix in it, so it's a public suffix
205        return true;
206    }
207
208    false
209}
210
211/// Check if any of the URLs in the given `Localized` field is a public suffix
212fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
213    url.iter().any(|(_lang, url)| host_is_public_suffix(url))
214}
215
216#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
217pub(crate) async fn post(
218    mut rng: BoxRng,
219    clock: BoxClock,
220    mut repo: BoxRepository,
221    mut policy: Policy,
222    activity_tracker: BoundActivityTracker,
223    user_agent: Option<TypedHeader<headers::UserAgent>>,
224    State(encrypter): State<Encrypter>,
225    body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
226) -> Result<impl IntoResponse, RouteError> {
227    // Propagate any JSON extraction error
228    let Json(body) = body?;
229
230    // Sort the properties to ensure a stable serialisation order for hashing
231    let body = body.sorted();
232
233    // We need to serialize the body to compute the hash, and to log it
234    let body_json = serde_json::to_string(&body)?;
235
236    info!(body = body_json, "Client registration");
237
238    let user_agent = user_agent.map(|ua| ua.to_string());
239
240    // Validate the body
241    let metadata = body.validate()?;
242
243    // Some extra validation that is hard to do in OPA and not done by the
244    // `validate` method either
245    if let Some(client_uri) = &metadata.client_uri
246        && localised_url_has_public_suffix(client_uri)
247    {
248        return Err(RouteError::UrlIsPublicSuffix("client_uri"));
249    }
250
251    if let Some(logo_uri) = &metadata.logo_uri
252        && localised_url_has_public_suffix(logo_uri)
253    {
254        return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
255    }
256
257    if let Some(policy_uri) = &metadata.policy_uri
258        && localised_url_has_public_suffix(policy_uri)
259    {
260        return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
261    }
262
263    if let Some(tos_uri) = &metadata.tos_uri
264        && localised_url_has_public_suffix(tos_uri)
265    {
266        return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
267    }
268
269    if let Some(initiate_login_uri) = &metadata.initiate_login_uri
270        && host_is_public_suffix(initiate_login_uri)
271    {
272        return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
273    }
274
275    for redirect_uri in metadata.redirect_uris() {
276        if host_is_public_suffix(redirect_uri) {
277            return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
278        }
279    }
280
281    let res = policy
282        .evaluate_client_registration(mas_policy::ClientRegistrationInput {
283            client_metadata: &metadata,
284            requester: mas_policy::Requester {
285                ip_address: activity_tracker.ip(),
286                user_agent,
287            },
288        })
289        .await?;
290    if !res.valid() {
291        return Err(RouteError::PolicyDenied(res));
292    }
293
294    let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
295        Some(
296            OAuthClientAuthenticationMethod::ClientSecretJwt
297            | OAuthClientAuthenticationMethod::ClientSecretPost
298            | OAuthClientAuthenticationMethod::ClientSecretBasic,
299        ) => {
300            // Let's generate a random client secret
301            let client_secret = Alphanumeric.sample_string(&mut rng, 20);
302            let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
303            (Some(client_secret), Some(encrypted_client_secret))
304        }
305        _ => (None, None),
306    };
307
308    // If the client doesn't have a secret, we may be able to deduplicate it. To
309    // do so, we hash the client metadata, and look for it in the database
310    let (digest_hash, existing_client) = if client_secret.is_none() {
311        // XXX: One interesting caveat is that we hash *before* saving to the database.
312        // It means it takes into account fields that we don't care about *yet*.
313        //
314        // This means that if later we start supporting a particular field, we
315        // will still serve the 'old' client_id, without updating the client in the
316        // database
317        let hash = sha2::Sha256::digest(body_json);
318        let hash = hex::encode(hash);
319        let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
320        (Some(hash), client)
321    } else {
322        (None, None)
323    };
324
325    let client = if let Some(client) = existing_client {
326        tracing::info!(%client.id, "Reusing existing client");
327        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
328        client
329    } else {
330        let client = repo
331            .oauth2_client()
332            .add(
333                &mut rng,
334                &clock,
335                metadata.redirect_uris().to_vec(),
336                digest_hash,
337                encrypted_client_secret,
338                metadata.application_type.clone(),
339                //&metadata.response_types(),
340                metadata.grant_types().to_vec(),
341                metadata
342                    .client_name
343                    .clone()
344                    .map(Localized::to_non_localized),
345                metadata.logo_uri.clone().map(Localized::to_non_localized),
346                metadata.client_uri.clone().map(Localized::to_non_localized),
347                metadata.policy_uri.clone().map(Localized::to_non_localized),
348                metadata.tos_uri.clone().map(Localized::to_non_localized),
349                metadata.jwks_uri.clone(),
350                metadata.jwks.clone(),
351                // XXX: those might not be right, should be function calls
352                metadata.id_token_signed_response_alg.clone(),
353                metadata.userinfo_signed_response_alg.clone(),
354                metadata.token_endpoint_auth_method.clone(),
355                metadata.token_endpoint_auth_signing_alg.clone(),
356                metadata.initiate_login_uri.clone(),
357            )
358            .await?;
359        tracing::info!(%client.id, "Registered new client");
360        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
361        client
362    };
363
364    let response = ClientRegistrationResponse {
365        client_id: client.client_id.clone(),
366        client_secret,
367        // XXX: we should have a `created_at` field on the clients
368        client_id_issued_at: Some(client.id.datetime().into()),
369        client_secret_expires_at: None,
370    };
371
372    // We round-trip back to the metadata to output it in the response
373    // This should never fail, as the client is valid
374    let metadata = client.into_metadata().validate()?;
375
376    repo.save().await?;
377
378    let response = RouteResponse { response, metadata };
379
380    Ok((StatusCode::CREATED, Json(response)))
381}
382
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        // Bare public suffixes, with or without a leading or trailing dot
        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://com"));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));

        // The port, the path and the case of the host don't matter
        assert!(url_is_public_suffix("https://github.io:8443/callback"));
        assert!(url_is_public_suffix("https://GITHUB.IO/"));

        // Registrable domains under a public suffix, even with short labels
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://a.co.uk"));
        assert!(!url_is_public_suffix("https://example.co.uk"));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));

        // Hosts not covered by the public suffix list, and URLs without a host
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Body is not a JSON
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // Invalid client metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Invalid redirect URI
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // Incoherent response types
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Using a public suffix
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a translated URL
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A successful registration with no authentication should not return a client
        // secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        // A successful registration with client_secret based authentication should
        // return a client secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Post a client registration twice, we should get the same client ID
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Check that the order of some properties doesn't matter
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Doing that with a client that has a client_secret should not deduplicate
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        // Sanity check that the client_id is different
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}