mas_handlers/oauth2/device/authorize.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13    client_authorization::{ClientAuthorization, CredentialsVerificationError},
14    record_error,
15};
16use mas_data_model::{BoxClock, BoxRng};
17use mas_keystore::Encrypter;
18use mas_router::UrlBuilder;
19use mas_storage::{BoxRepository, oauth2::OAuth2DeviceCodeGrantParams};
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
23    scope::ScopeToken,
24};
25use rand::distributions::{Alphanumeric, DistString};
26use thiserror::Error;
27use ulid::Ulid;
28
29use crate::{BoundActivityTracker, impl_from_error_for_route};
30
31#[derive(Debug, Error)]
32pub(crate) enum RouteError {
33    #[error(transparent)]
34    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
35
36    #[error("client not found")]
37    ClientNotFound,
38
39    #[error("client {0} is not allowed to use the device code grant")]
40    ClientNotAllowed(Ulid),
41
42    #[error("invalid client credentials for client {client_id}")]
43    InvalidClientCredentials {
44        client_id: Ulid,
45        #[source]
46        source: CredentialsVerificationError,
47    },
48
49    #[error("could not verify client credentials for client {client_id}")]
50    ClientCredentialsVerification {
51        client_id: Ulid,
52        #[source]
53        source: CredentialsVerificationError,
54    },
55}
56
57impl_from_error_for_route!(mas_storage::RepositoryError);
58
59impl IntoResponse for RouteError {
60    fn into_response(self) -> axum::response::Response {
61        let sentry_event_id = record_error!(self, Self::Internal(_));
62
63        let response = match self {
64            Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
65                StatusCode::INTERNAL_SERVER_ERROR,
66                Json(ClientError::from(ClientErrorCode::ServerError)),
67            ),
68            Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
69                StatusCode::UNAUTHORIZED,
70                Json(ClientError::from(ClientErrorCode::InvalidClient)),
71            ),
72            Self::ClientNotAllowed(_) => (
73                StatusCode::UNAUTHORIZED,
74                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
75            ),
76        };
77
78        (sentry_event_id, response).into_response()
79    }
80}
81
82#[tracing::instrument(
83    name = "handlers.oauth2.device.request.post",
84    fields(client.id = client_authorization.client_id()),
85    skip_all,
86)]
87pub(crate) async fn post(
88    mut rng: BoxRng,
89    clock: BoxClock,
90    mut repo: BoxRepository,
91    user_agent: Option<TypedHeader<headers::UserAgent>>,
92    activity_tracker: BoundActivityTracker,
93    State(url_builder): State<UrlBuilder>,
94    State(http_client): State<reqwest::Client>,
95    State(encrypter): State<Encrypter>,
96    client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
97) -> Result<impl IntoResponse, RouteError> {
98    let client = client_authorization
99        .credentials
100        .fetch(&mut repo)
101        .await?
102        .ok_or(RouteError::ClientNotFound)?;
103
104    // Reuse the token endpoint auth method to verify the client
105    let method = client
106        .token_endpoint_auth_method
107        .as_ref()
108        .ok_or(RouteError::ClientNotAllowed(client.id))?;
109
110    client_authorization
111        .credentials
112        .verify(&http_client, &encrypter, method, &client)
113        .await
114        .map_err(|err| {
115            if err.is_internal() {
116                RouteError::ClientCredentialsVerification {
117                    client_id: client.id,
118                    source: err,
119                }
120            } else {
121                RouteError::InvalidClientCredentials {
122                    client_id: client.id,
123                    source: err,
124                }
125            }
126        })?;
127
128    if !client.grant_types.contains(&GrantType::DeviceCode) {
129        return Err(RouteError::ClientNotAllowed(client.id));
130    }
131
132    let scope = client_authorization
133        .form
134        .and_then(|f| f.scope)
135        // XXX: Is this really how we do empty scopes?
136        .unwrap_or(std::iter::empty::<ScopeToken>().collect());
137
138    let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
139
140    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
141    let ip_address = activity_tracker.ip();
142
143    let device_code = Alphanumeric.sample_string(&mut rng, 32);
144    let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
145
146    let device_code = repo
147        .oauth2_device_code_grant()
148        .add(
149            &mut rng,
150            &clock,
151            OAuth2DeviceCodeGrantParams {
152                client: &client,
153                scope,
154                device_code,
155                user_code,
156                expires_in,
157                user_agent,
158                ip_address,
159            },
160        )
161        .await?;
162
163    repo.save().await?;
164
165    let response = DeviceAuthorizationResponse {
166        device_code: device_code.device_code,
167        user_code: device_code.user_code.clone(),
168        verification_uri: url_builder.device_code_link(),
169        verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
170        expires_in,
171        interval: Some(Duration::microseconds(5 * 1000 * 1000)),
172    };
173
174    Ok((
175        StatusCode::OK,
176        TypedHeader(CacheControl::new().with_no_store()),
177        TypedHeader(Pragma::no_cache()),
178        Json(response),
179    ))
180}
181
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Provision a public client allowed to use the device code grant
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "token_endpoint_auth_method": "none",
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
                "response_types": [],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        // Happy path: the client is allowed to use the device code grant type
        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": client_id,
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let response: DeviceAuthorizationResponse = response.json();

        // Device code: 32 alphanumeric characters
        assert_eq!(response.device_code.len(), 32);
        assert!(response.device_code.chars().all(|c| c.is_ascii_alphanumeric()));

        // User code: 6 upper-cased alphanumeric characters
        assert_eq!(response.user_code.len(), 6);
        assert!(
            response
                .user_code
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
        );

        // Advertised lifetime and polling interval
        assert_eq!(response.expires_in.num_seconds(), 20 * 60);
        assert_eq!(response.interval.map(|i| i.num_seconds()), Some(5));
        assert!(response.verification_uri_complete.is_some());
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request_unknown_client(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A well-formed client ID that was never registered
        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::UNAUTHORIZED);
    }
}