mas_handlers/oauth2/device/
authorize.rs
1use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13 client_authorization::{ClientAuthorization, CredentialsVerificationError},
14 sentry::SentryEventID,
15};
16use mas_data_model::UserAgent;
17use mas_keystore::Encrypter;
18use mas_router::UrlBuilder;
19use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2DeviceCodeGrantParams};
20use oauth2_types::{
21 errors::{ClientError, ClientErrorCode},
22 requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
23 scope::ScopeToken,
24};
25use rand::distributions::{Alphanumeric, DistString};
26use thiserror::Error;
27
28use crate::{BoundActivityTracker, impl_from_error_for_route};
29
30#[derive(Debug, Error)]
31pub(crate) enum RouteError {
32 #[error(transparent)]
33 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
34
35 #[error("client not found")]
36 ClientNotFound,
37
38 #[error("client not allowed")]
39 ClientNotAllowed,
40
41 #[error("could not verify client credentials")]
42 ClientCredentialsVerification(#[from] CredentialsVerificationError),
43}
44
45impl_from_error_for_route!(mas_storage::RepositoryError);
46
47impl IntoResponse for RouteError {
48 fn into_response(self) -> axum::response::Response {
49 let event_id = sentry::capture_error(&self);
50
51 let response = match self {
52 Self::Internal(_) => (
53 StatusCode::INTERNAL_SERVER_ERROR,
54 Json(ClientError::from(ClientErrorCode::ServerError)),
55 ),
56 Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
57 StatusCode::UNAUTHORIZED,
58 Json(ClientError::from(ClientErrorCode::InvalidClient)),
59 ),
60 Self::ClientNotAllowed => (
61 StatusCode::UNAUTHORIZED,
62 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
63 ),
64 };
65
66 (SentryEventID::from(event_id), response).into_response()
67 }
68}
69
70#[tracing::instrument(
71 name = "handlers.oauth2.device.request.post",
72 fields(client.id = client_authorization.client_id()),
73 skip_all,
74 err,
75)]
76pub(crate) async fn post(
77 mut rng: BoxRng,
78 clock: BoxClock,
79 mut repo: BoxRepository,
80 user_agent: Option<TypedHeader<headers::UserAgent>>,
81 activity_tracker: BoundActivityTracker,
82 State(url_builder): State<UrlBuilder>,
83 State(http_client): State<reqwest::Client>,
84 State(encrypter): State<Encrypter>,
85 client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
86) -> Result<impl IntoResponse, RouteError> {
87 let client = client_authorization
88 .credentials
89 .fetch(&mut repo)
90 .await?
91 .ok_or(RouteError::ClientNotFound)?;
92
93 let method = client
95 .token_endpoint_auth_method
96 .as_ref()
97 .ok_or(RouteError::ClientNotAllowed)?;
98
99 client_authorization
100 .credentials
101 .verify(&http_client, &encrypter, method, &client)
102 .await?;
103
104 if !client.grant_types.contains(&GrantType::DeviceCode) {
105 return Err(RouteError::ClientNotAllowed);
106 }
107
108 let scope = client_authorization
109 .form
110 .and_then(|f| f.scope)
111 .unwrap_or(std::iter::empty::<ScopeToken>().collect());
113
114 let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
115
116 let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
117 let ip_address = activity_tracker.ip();
118
119 let device_code = Alphanumeric.sample_string(&mut rng, 32);
120 let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
121
122 let device_code = repo
123 .oauth2_device_code_grant()
124 .add(
125 &mut rng,
126 &clock,
127 OAuth2DeviceCodeGrantParams {
128 client: &client,
129 scope,
130 device_code,
131 user_code,
132 expires_in,
133 user_agent,
134 ip_address,
135 },
136 )
137 .await?;
138
139 repo.save().await?;
140
141 let response = DeviceAuthorizationResponse {
142 device_code: device_code.device_code,
143 user_code: device_code.user_code.clone(),
144 verification_uri: url_builder.device_code_link(),
145 verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
146 expires_in,
147 interval: Some(Duration::microseconds(5 * 1000 * 1000)),
148 };
149
150 Ok((
151 StatusCode::OK,
152 TypedHeader(CacheControl::new().with_no_store()),
153 TypedHeader(Pragma::no_cache()),
154 Json(response),
155 ))
156}
157
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// Registers a client allowed to use the device code grant, then checks
    /// that a device authorization request succeeds and returns codes of the
    /// expected lengths.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Step 1: register a client with the `none` auth method and the
        // device code grant type.
        let registration_body = serde_json::json!({
            "client_uri": "https://example.com/",
            "token_endpoint_auth_method": "none",
            "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
            "response_types": [],
        });
        let registration_request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(registration_body);

        let registration_response = state.request(registration_request).await;
        registration_response.assert_status(StatusCode::CREATED);

        let registered: ClientRegistrationResponse = registration_response.json();
        let client_id = registered.client_id;

        // Step 2: request a device code for that client.
        let device_form = serde_json::json!({
            "client_id": client_id,
            "scope": "openid",
        });
        let device_request =
            Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(device_form);

        let device_response = state.request(device_request).await;
        device_response.assert_status(StatusCode::OK);

        // Step 3: the handler generates a 32-character device code and a
        // 6-character user code.
        let authorization: DeviceAuthorizationResponse = device_response.json();
        assert_eq!(authorization.device_code.len(), 32);
        assert_eq!(authorization.user_code.len(), 6);
    }
}