mas_handlers/oauth2/device/authorize.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13    client_authorization::{ClientAuthorization, CredentialsVerificationError},
14    sentry::SentryEventID,
15};
16use mas_data_model::UserAgent;
17use mas_keystore::Encrypter;
18use mas_router::UrlBuilder;
19use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2DeviceCodeGrantParams};
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
23    scope::ScopeToken,
24};
25use rand::distributions::{Alphanumeric, DistString};
26use thiserror::Error;
27
28use crate::{BoundActivityTracker, impl_from_error_for_route};
29
30#[derive(Debug, Error)]
31pub(crate) enum RouteError {
32    #[error(transparent)]
33    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
34
35    #[error("client not found")]
36    ClientNotFound,
37
38    #[error("client not allowed")]
39    ClientNotAllowed,
40
41    #[error("could not verify client credentials")]
42    ClientCredentialsVerification(#[from] CredentialsVerificationError),
43}
44
45impl_from_error_for_route!(mas_storage::RepositoryError);
46
47impl IntoResponse for RouteError {
48    fn into_response(self) -> axum::response::Response {
49        let event_id = sentry::capture_error(&self);
50
51        let response = match self {
52            Self::Internal(_) => (
53                StatusCode::INTERNAL_SERVER_ERROR,
54                Json(ClientError::from(ClientErrorCode::ServerError)),
55            ),
56            Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
57                StatusCode::UNAUTHORIZED,
58                Json(ClientError::from(ClientErrorCode::InvalidClient)),
59            ),
60            Self::ClientNotAllowed => (
61                StatusCode::UNAUTHORIZED,
62                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
63            ),
64        };
65
66        (SentryEventID::from(event_id), response).into_response()
67    }
68}
69
70#[tracing::instrument(
71    name = "handlers.oauth2.device.request.post",
72    fields(client.id = client_authorization.client_id()),
73    skip_all,
74    err,
75)]
76pub(crate) async fn post(
77    mut rng: BoxRng,
78    clock: BoxClock,
79    mut repo: BoxRepository,
80    user_agent: Option<TypedHeader<headers::UserAgent>>,
81    activity_tracker: BoundActivityTracker,
82    State(url_builder): State<UrlBuilder>,
83    State(http_client): State<reqwest::Client>,
84    State(encrypter): State<Encrypter>,
85    client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
86) -> Result<impl IntoResponse, RouteError> {
87    let client = client_authorization
88        .credentials
89        .fetch(&mut repo)
90        .await?
91        .ok_or(RouteError::ClientNotFound)?;
92
93    // Reuse the token endpoint auth method to verify the client
94    let method = client
95        .token_endpoint_auth_method
96        .as_ref()
97        .ok_or(RouteError::ClientNotAllowed)?;
98
99    client_authorization
100        .credentials
101        .verify(&http_client, &encrypter, method, &client)
102        .await?;
103
104    if !client.grant_types.contains(&GrantType::DeviceCode) {
105        return Err(RouteError::ClientNotAllowed);
106    }
107
108    let scope = client_authorization
109        .form
110        .and_then(|f| f.scope)
111        // XXX: Is this really how we do empty scopes?
112        .unwrap_or(std::iter::empty::<ScopeToken>().collect());
113
114    let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
115
116    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
117    let ip_address = activity_tracker.ip();
118
119    let device_code = Alphanumeric.sample_string(&mut rng, 32);
120    let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
121
122    let device_code = repo
123        .oauth2_device_code_grant()
124        .add(
125            &mut rng,
126            &clock,
127            OAuth2DeviceCodeGrantParams {
128                client: &client,
129                scope,
130                device_code,
131                user_code,
132                expires_in,
133                user_agent,
134                ip_address,
135            },
136        )
137        .await?;
138
139    repo.save().await?;
140
141    let response = DeviceAuthorizationResponse {
142        device_code: device_code.device_code,
143        user_code: device_code.user_code.clone(),
144        verification_uri: url_builder.device_code_link(),
145        verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
146        expires_in,
147        interval: Some(Duration::microseconds(5 * 1000 * 1000)),
148    };
149
150    Ok((
151        StatusCode::OK,
152        TypedHeader(CacheControl::new().with_no_store()),
153        TypedHeader(Pragma::no_cache()),
154        Json(response),
155    ))
156}
157
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// Happy path: a client registered with the device code grant type gets a
    /// device code, a user code and a complete verification URI.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Provision a client
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "token_endpoint_auth_method": "none",
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
                "response_types": [],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        // Test the happy path: the client is allowed to use the device code grant type
        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": client_id,
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let response: DeviceAuthorizationResponse = response.json();
        assert_eq!(response.device_code.len(), 32);
        assert_eq!(response.user_code.len(), 6);
        // The handler upper-cases an alphanumeric sample, so only A-Z and 0-9 may appear
        assert!(
            response
                .user_code
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
        );
        assert!(response.verification_uri_complete.is_some());
    }

    /// A client ID that was never registered must be rejected as an invalid client.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request_unknown_client(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A syntactically valid ULID that does not belong to any registered client
        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": "00000000000000000000000000",
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::UNAUTHORIZED);
    }
}