mas_handlers/oauth2/registration.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::sentry::SentryEventID;
13use mas_iana::oauth::OAuthClientAuthenticationMethod;
14use mas_keystore::Encrypter;
15use mas_policy::{Policy, Violation};
16use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
17use oauth2_types::{
18    errors::{ClientError, ClientErrorCode},
19    registration::{
20        ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
21        VerifiedClientMetadata,
22    },
23};
24use opentelemetry::{Key, KeyValue, metrics::Counter};
25use psl::Psl;
26use rand::distributions::{Alphanumeric, DistString};
27use serde::Serialize;
28use sha2::Digest as _;
29use thiserror::Error;
30use tracing::info;
31use url::Url;
32
33use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
34
35static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
36    METER
37        .u64_counter("mas.oauth2.registration_request")
38        .with_description("Number of OAuth2 registration requests")
39        .with_unit("{request}")
40        .build()
41});
42const RESULT: Key = Key::from_static_str("result");
43
/// Errors that can occur while handling a client registration request.
///
/// Each variant is turned into an HTTP response with an OAuth2 error code by the
/// `IntoResponse` implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// An unexpected, server-side failure.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON client metadata.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The client metadata was rejected by `ClientMetadata::validate`.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// One of the client's URIs uses a public suffix as its host; the payload is
    /// the name of the offending metadata field (e.g. `client_uri`).
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The registration was rejected by the policy engine, with the list of
    /// violations it reported.
    #[error("denied by the policy: {0:?}")]
    PolicyDenied(Vec<Violation>),
}

// Conversions for infrastructure errors raised while handling the request, so
// they can be propagated with `?`.
// NOTE(review): presumably these wrap the error into `RouteError::Internal` —
// confirm against the `impl_from_error_for_route!` definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
67
68impl IntoResponse for RouteError {
69    fn into_response(self) -> axum::response::Response {
70        let event_id = sentry::capture_error(&self);
71
72        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
73
74        let response = match self {
75            Self::Internal(_) => (
76                StatusCode::INTERNAL_SERVER_ERROR,
77                Json(ClientError::from(ClientErrorCode::ServerError)),
78            )
79                .into_response(),
80
81            // This error happens if we managed to parse the incomiong JSON but it can't be
82            // deserialized to the expected type. In this case we return an
83            // `invalid_client_metadata` error with the details of the error.
84            Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
85                StatusCode::BAD_REQUEST,
86                Json(
87                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
88                        .with_description(e.to_string()),
89                ),
90            )
91                .into_response(),
92
93            // For all other JSON errors we return a `invalid_request` error, since this is
94            // probably due to a malformed request.
95            Self::JsonExtract(_) => (
96                StatusCode::BAD_REQUEST,
97                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
98            )
99                .into_response(),
100
101            // This error comes from the `ClientMetadata::validate` method. We return an
102            // `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
103            // return an `invalid_client_metadata` error.
104            Self::InvalidClientMetadata(
105                ClientMetadataVerificationError::MissingRedirectUris
106                | ClientMetadataVerificationError::RedirectUriWithFragment(_),
107            ) => (
108                StatusCode::BAD_REQUEST,
109                Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
110            )
111                .into_response(),
112
113            Self::InvalidClientMetadata(e) => (
114                StatusCode::BAD_REQUEST,
115                Json(
116                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
117                        .with_description(e.to_string()),
118                ),
119            )
120                .into_response(),
121
122            // This error happens if the any of the client's URIs are public suffixes. We return
123            // an `invalid_redirect_uri` error if it's a `redirect_uri`, else we return an
124            // `invalid_client_metadata` error.
125            Self::UrlIsPublicSuffix("redirect_uri") => (
126                StatusCode::BAD_REQUEST,
127                Json(
128                    ClientError::from(ClientErrorCode::InvalidRedirectUri)
129                        .with_description("redirect_uri is not using a valid domain".to_owned()),
130                ),
131            )
132                .into_response(),
133
134            Self::UrlIsPublicSuffix(field) => (
135                StatusCode::BAD_REQUEST,
136                Json(
137                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
138                        .with_description(format!("{field} is not using a valid domain")),
139                ),
140            )
141                .into_response(),
142
143            // For policy violations, we return an `invalid_client_metadata` error with the details
144            // of the violations in most cases. If a violation includes `redirect_uri` in the
145            // message, we return an `invalid_redirect_uri` error instead.
146            Self::PolicyDenied(violations) => {
147                // TODO: detect them better
148                let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
149                    ClientErrorCode::InvalidRedirectUri
150                } else {
151                    ClientErrorCode::InvalidClientMetadata
152                };
153
154                let collected = &violations
155                    .iter()
156                    .map(|v| v.msg.clone())
157                    .collect::<Vec<String>>();
158                let joined = collected.join("; ");
159
160                (
161                    StatusCode::BAD_REQUEST,
162                    Json(ClientError::from(code).with_description(joined)),
163                )
164                    .into_response()
165            }
166        };
167
168        (SentryEventID::from(event_id), response).into_response()
169    }
170}
171
/// Body of a successful registration response.
///
/// Both fields are flattened, so the issued client credentials and the client
/// metadata end up merged into a single JSON object.
#[derive(Serialize)]
struct RouteResponse {
    /// The client ID, optional client secret, and issuance/expiry timestamps.
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    /// The client metadata, round-tripped from the stored client.
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
179
180/// Check if the host of the given URL is a public suffix
181fn host_is_public_suffix(url: &Url) -> bool {
182    let host = url.host_str().unwrap_or_default().as_bytes();
183    let Some(suffix) = psl::List.suffix(host) else {
184        // There is no suffix, which is the case for empty hosts, like with custom
185        // schemes
186        return false;
187    };
188
189    if !suffix.is_known() {
190        // The suffix is not known, so it's not a public suffix
191        return false;
192    }
193
194    // We want to cover two cases:
195    // - The host is the suffix itself, like `com`
196    // - The host is a dot followed by the suffix, like `.com`
197    if host.len() <= suffix.as_bytes().len() + 1 {
198        // The host only has the suffix in it, so it's a public suffix
199        return true;
200    }
201
202    false
203}
204
205/// Check if any of the URLs in the given `Localized` field is a public suffix
206fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
207    url.iter().any(|(_lang, url)| host_is_public_suffix(url))
208}
209
210#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
211pub(crate) async fn post(
212    mut rng: BoxRng,
213    clock: BoxClock,
214    mut repo: BoxRepository,
215    mut policy: Policy,
216    activity_tracker: BoundActivityTracker,
217    user_agent: Option<TypedHeader<headers::UserAgent>>,
218    State(encrypter): State<Encrypter>,
219    body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
220) -> Result<impl IntoResponse, RouteError> {
221    // Propagate any JSON extraction error
222    let Json(body) = body?;
223
224    // Sort the properties to ensure a stable serialisation order for hashing
225    let body = body.sorted();
226
227    // We need to serialize the body to compute the hash, and to log it
228    let body_json = serde_json::to_string(&body)?;
229
230    info!(body = body_json, "Client registration");
231
232    let user_agent = user_agent.map(|ua| ua.to_string());
233
234    // Validate the body
235    let metadata = body.validate()?;
236
237    // Some extra validation that is hard to do in OPA and not done by the
238    // `validate` method either
239    if let Some(client_uri) = &metadata.client_uri {
240        if localised_url_has_public_suffix(client_uri) {
241            return Err(RouteError::UrlIsPublicSuffix("client_uri"));
242        }
243    }
244
245    if let Some(logo_uri) = &metadata.logo_uri {
246        if localised_url_has_public_suffix(logo_uri) {
247            return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
248        }
249    }
250
251    if let Some(policy_uri) = &metadata.policy_uri {
252        if localised_url_has_public_suffix(policy_uri) {
253            return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
254        }
255    }
256
257    if let Some(tos_uri) = &metadata.tos_uri {
258        if localised_url_has_public_suffix(tos_uri) {
259            return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
260        }
261    }
262
263    if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
264        if host_is_public_suffix(initiate_login_uri) {
265            return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
266        }
267    }
268
269    for redirect_uri in metadata.redirect_uris() {
270        if host_is_public_suffix(redirect_uri) {
271            return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
272        }
273    }
274
275    let res = policy
276        .evaluate_client_registration(mas_policy::ClientRegistrationInput {
277            client_metadata: &metadata,
278            requester: mas_policy::Requester {
279                ip_address: activity_tracker.ip(),
280                user_agent,
281            },
282        })
283        .await?;
284    if !res.valid() {
285        return Err(RouteError::PolicyDenied(res.violations));
286    }
287
288    let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
289        Some(
290            OAuthClientAuthenticationMethod::ClientSecretJwt
291            | OAuthClientAuthenticationMethod::ClientSecretPost
292            | OAuthClientAuthenticationMethod::ClientSecretBasic,
293        ) => {
294            // Let's generate a random client secret
295            let client_secret = Alphanumeric.sample_string(&mut rng, 20);
296            let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
297            (Some(client_secret), Some(encrypted_client_secret))
298        }
299        _ => (None, None),
300    };
301
302    // If the client doesn't have a secret, we may be able to deduplicate it. To
303    // do so, we hash the client metadata, and look for it in the database
304    let (digest_hash, existing_client) = if client_secret.is_none() {
305        // XXX: One interesting caveat is that we hash *before* saving to the database.
306        // It means it takes into account fields that we don't care about *yet*.
307        //
308        // This means that if later we start supporting a particular field, we
309        // will still serve the 'old' client_id, without updating the client in the
310        // database
311        let hash = sha2::Sha256::digest(body_json);
312        let hash = hex::encode(hash);
313        let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
314        (Some(hash), client)
315    } else {
316        (None, None)
317    };
318
319    let client = if let Some(client) = existing_client {
320        tracing::info!(%client.id, "Reusing existing client");
321        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
322        client
323    } else {
324        let client = repo
325            .oauth2_client()
326            .add(
327                &mut rng,
328                &clock,
329                metadata.redirect_uris().to_vec(),
330                digest_hash,
331                encrypted_client_secret,
332                metadata.application_type.clone(),
333                //&metadata.response_types(),
334                metadata.grant_types().to_vec(),
335                metadata
336                    .client_name
337                    .clone()
338                    .map(Localized::to_non_localized),
339                metadata.logo_uri.clone().map(Localized::to_non_localized),
340                metadata.client_uri.clone().map(Localized::to_non_localized),
341                metadata.policy_uri.clone().map(Localized::to_non_localized),
342                metadata.tos_uri.clone().map(Localized::to_non_localized),
343                metadata.jwks_uri.clone(),
344                metadata.jwks.clone(),
345                // XXX: those might not be right, should be function calls
346                metadata.id_token_signed_response_alg.clone(),
347                metadata.userinfo_signed_response_alg.clone(),
348                metadata.token_endpoint_auth_method.clone(),
349                metadata.token_endpoint_auth_signing_alg.clone(),
350                metadata.initiate_login_uri.clone(),
351            )
352            .await?;
353        tracing::info!(%client.id, "Registered new client");
354        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
355        client
356    };
357
358    let response = ClientRegistrationResponse {
359        client_id: client.client_id.clone(),
360        client_secret,
361        // XXX: we should have a `created_at` field on the clients
362        client_id_issued_at: Some(client.id.datetime().into()),
363        client_secret_expires_at: None,
364    };
365
366    // We round-trip back to the metadata to output it in the response
367    // This should never fail, as the client is valid
368    let metadata = client.into_metadata().validate()?;
369
370    repo.save().await?;
371
372    let response = RouteResponse { response, metadata };
373
374    Ok((StatusCode::CREATED, Json(response)))
375}
376
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        // Bare public suffixes, with or without a leading/trailing dot
        assert!(url_is_public_suffix("https://com"));
        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));

        // Registrable domains under a public suffix
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://example.co.uk"));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));

        // Hosts without a known suffix, and URLs without a host at all
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("http://127.0.0.1"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Body is not a JSON
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // Invalid client metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Invalid redirect URI
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // Incoherent response types
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Using a public suffix
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a translated URL
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a redirect URI should be reported as an
        // `invalid_redirect_uri` error
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
        assert_eq!(
            response.error_description.unwrap(),
            "redirect_uri is not using a valid domain"
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A successful registration with no authentication should not return a client
        // secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        // A successful registration with client_secret based authentication should
        // return a client secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Post a client registration twice, we should get the same client ID
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Check that the order of some properties doesn't matter
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Doing that with a client that has a client_secret should not deduplicate
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        // Sanity check that the client_id is different
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}