mas_handlers/oauth2/device/
authorize.rs1use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13 client_authorization::{ClientAuthorization, CredentialsVerificationError},
14 record_error,
15};
16use mas_data_model::{BoxClock, BoxRng};
17use mas_keystore::Encrypter;
18use mas_router::UrlBuilder;
19use mas_storage::{BoxRepository, oauth2::OAuth2DeviceCodeGrantParams};
20use oauth2_types::{
21 errors::{ClientError, ClientErrorCode},
22 requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
23 scope::ScopeToken,
24};
25use rand::distributions::{Alphanumeric, DistString};
26use thiserror::Error;
27use ulid::Ulid;
28
29use crate::{BoundActivityTracker, impl_from_error_for_route};
30
31#[derive(Debug, Error)]
32pub(crate) enum RouteError {
33 #[error(transparent)]
34 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
35
36 #[error("client not found")]
37 ClientNotFound,
38
39 #[error("client {0} is not allowed to use the device code grant")]
40 ClientNotAllowed(Ulid),
41
42 #[error("invalid client credentials for client {client_id}")]
43 InvalidClientCredentials {
44 client_id: Ulid,
45 #[source]
46 source: CredentialsVerificationError,
47 },
48
49 #[error("could not verify client credentials for client {client_id}")]
50 ClientCredentialsVerification {
51 client_id: Ulid,
52 #[source]
53 source: CredentialsVerificationError,
54 },
55}
56
57impl_from_error_for_route!(mas_storage::RepositoryError);
58
59impl IntoResponse for RouteError {
60 fn into_response(self) -> axum::response::Response {
61 let sentry_event_id = record_error!(self, Self::Internal(_));
62
63 let response = match self {
64 Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
65 StatusCode::INTERNAL_SERVER_ERROR,
66 Json(ClientError::from(ClientErrorCode::ServerError)),
67 ),
68 Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
69 StatusCode::UNAUTHORIZED,
70 Json(ClientError::from(ClientErrorCode::InvalidClient)),
71 ),
72 Self::ClientNotAllowed(_) => (
73 StatusCode::UNAUTHORIZED,
74 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
75 ),
76 };
77
78 (sentry_event_id, response).into_response()
79 }
80}
81
82#[tracing::instrument(
83 name = "handlers.oauth2.device.request.post",
84 fields(client.id = client_authorization.client_id()),
85 skip_all,
86)]
87pub(crate) async fn post(
88 mut rng: BoxRng,
89 clock: BoxClock,
90 mut repo: BoxRepository,
91 user_agent: Option<TypedHeader<headers::UserAgent>>,
92 activity_tracker: BoundActivityTracker,
93 State(url_builder): State<UrlBuilder>,
94 State(http_client): State<reqwest::Client>,
95 State(encrypter): State<Encrypter>,
96 client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
97) -> Result<impl IntoResponse, RouteError> {
98 let client = client_authorization
99 .credentials
100 .fetch(&mut repo)
101 .await?
102 .ok_or(RouteError::ClientNotFound)?;
103
104 let method = client
106 .token_endpoint_auth_method
107 .as_ref()
108 .ok_or(RouteError::ClientNotAllowed(client.id))?;
109
110 client_authorization
111 .credentials
112 .verify(&http_client, &encrypter, method, &client)
113 .await
114 .map_err(|err| {
115 if err.is_internal() {
116 RouteError::ClientCredentialsVerification {
117 client_id: client.id,
118 source: err,
119 }
120 } else {
121 RouteError::InvalidClientCredentials {
122 client_id: client.id,
123 source: err,
124 }
125 }
126 })?;
127
128 if !client.grant_types.contains(&GrantType::DeviceCode) {
129 return Err(RouteError::ClientNotAllowed(client.id));
130 }
131
132 let scope = client_authorization
133 .form
134 .and_then(|f| f.scope)
135 .unwrap_or(std::iter::empty::<ScopeToken>().collect());
137
138 let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
139
140 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
141 let ip_address = activity_tracker.ip();
142
143 let device_code = Alphanumeric.sample_string(&mut rng, 32);
144 let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
145
146 let device_code = repo
147 .oauth2_device_code_grant()
148 .add(
149 &mut rng,
150 &clock,
151 OAuth2DeviceCodeGrantParams {
152 client: &client,
153 scope,
154 device_code,
155 user_code,
156 expires_in,
157 user_agent,
158 ip_address,
159 },
160 )
161 .await?;
162
163 repo.save().await?;
164
165 let response = DeviceAuthorizationResponse {
166 device_code: device_code.device_code,
167 user_code: device_code.user_code.clone(),
168 verification_uri: url_builder.device_code_link(),
169 verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
170 expires_in,
171 interval: Some(Duration::microseconds(5 * 1000 * 1000)),
172 };
173
174 Ok((
175 StatusCode::OK,
176 TypedHeader(CacheControl::new().with_no_store()),
177 TypedHeader(Pragma::no_cache()),
178 Json(response),
179 ))
180}
181
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// A registered public client allowed to use the device code grant gets a
    /// well-formed device authorization response.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a public client for the device code grant
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "token_endpoint_auth_method": "none",
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
                "response_types": [],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        // Start a device authorization flow with that client
        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": client_id,
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let response: DeviceAuthorizationResponse = response.json();

        // The device code is 32 alphanumeric characters
        assert_eq!(response.device_code.len(), 32);
        assert!(response.device_code.chars().all(|c| c.is_ascii_alphanumeric()));

        // The user code is 6 upper-cased alphanumeric characters
        assert_eq!(response.user_code.len(), 6);
        assert!(
            response
                .user_code
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
        );

        // The handler always advertises a complete verification URI and a polling interval
        assert!(response.verification_uri_complete.is_some());
        assert!(response.interval.is_some());
    }

    /// A request for a client that does not exist is rejected with a 401.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request_unknown_client(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
            serde_json::json!({
                "client_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
                "scope": "openid",
            }),
        );
        let response = state.request(request).await;
        response.assert_status(StatusCode::UNAUTHORIZED);
    }
}