1use axum::{Json, extract::State, response::IntoResponse};
8use hyper::StatusCode;
9use mas_axum_utils::{
10 client_authorization::{ClientAuthorization, CredentialsVerificationError},
11 sentry::SentryEventID,
12};
13use mas_data_model::TokenType;
14use mas_iana::oauth::OAuthTokenTypeHint;
15use mas_keystore::Encrypter;
16use mas_storage::{
17 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
18 queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
19};
20use oauth2_types::{
21 errors::{ClientError, ClientErrorCode},
22 requests::RevocationRequest,
23};
24use thiserror::Error;
25
26use crate::{BoundActivityTracker, impl_from_error_for_route};
27
28#[derive(Debug, Error)]
29pub(crate) enum RouteError {
30 #[error(transparent)]
31 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
32
33 #[error("bad request")]
34 BadRequest,
35
36 #[error("client not found")]
37 ClientNotFound,
38
39 #[error("client not allowed")]
40 ClientNotAllowed,
41
42 #[error("could not verify client credentials")]
43 ClientCredentialsVerification(#[from] CredentialsVerificationError),
44
45 #[error("client is unauthorized")]
46 UnauthorizedClient,
47
48 #[error("unsupported token type")]
49 UnsupportedTokenType,
50
51 #[error("unknown token")]
52 UnknownToken,
53}
54
55impl IntoResponse for RouteError {
56 fn into_response(self) -> axum::response::Response {
57 let event_id = sentry::capture_error(&self);
58 let response = match self {
59 Self::Internal(_) => (
60 StatusCode::INTERNAL_SERVER_ERROR,
61 Json(ClientError::from(ClientErrorCode::ServerError)),
62 )
63 .into_response(),
64
65 Self::BadRequest => (
66 StatusCode::BAD_REQUEST,
67 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
68 )
69 .into_response(),
70
71 Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
72 StatusCode::UNAUTHORIZED,
73 Json(ClientError::from(ClientErrorCode::InvalidClient)),
74 )
75 .into_response(),
76
77 Self::ClientNotAllowed | Self::UnauthorizedClient => (
78 StatusCode::UNAUTHORIZED,
79 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
80 )
81 .into_response(),
82
83 Self::UnsupportedTokenType => (
84 StatusCode::BAD_REQUEST,
85 Json(ClientError::from(ClientErrorCode::UnsupportedTokenType)),
86 )
87 .into_response(),
88
89 Self::UnknownToken => StatusCode::OK.into_response(),
91 };
92
93 (SentryEventID::from(event_id), response).into_response()
94 }
95}
96
// Lets repository failures propagate with `?` inside the handler by converting
// them into a `RouteError` through the crate-level `impl_from_error_for_route!`
// macro (presumably into `RouteError::Internal`; the macro is defined elsewhere
// — confirm there).
impl_from_error_for_route!(mas_storage::RepositoryError);
98
99impl From<mas_data_model::TokenFormatError> for RouteError {
100 fn from(_e: mas_data_model::TokenFormatError) -> Self {
101 Self::UnknownToken
102 }
103}
104
105#[tracing::instrument(
106 name = "handlers.oauth2.revoke.post",
107 fields(client.id = client_authorization.client_id()),
108 skip_all,
109 err,
110)]
111pub(crate) async fn post(
112 clock: BoxClock,
113 mut rng: BoxRng,
114 State(http_client): State<reqwest::Client>,
115 mut repo: BoxRepository,
116 activity_tracker: BoundActivityTracker,
117 State(encrypter): State<Encrypter>,
118 client_authorization: ClientAuthorization<RevocationRequest>,
119) -> Result<impl IntoResponse, RouteError> {
120 let client = client_authorization
121 .credentials
122 .fetch(&mut repo)
123 .await?
124 .ok_or(RouteError::ClientNotFound)?;
125
126 let method = client
127 .token_endpoint_auth_method
128 .as_ref()
129 .ok_or(RouteError::ClientNotAllowed)?;
130
131 client_authorization
132 .credentials
133 .verify(&http_client, &encrypter, method, &client)
134 .await?;
135
136 let Some(form) = client_authorization.form else {
137 return Err(RouteError::BadRequest);
138 };
139
140 let token_type = TokenType::check(&form.token)?;
141
142 let session_id = match (form.token_type_hint, token_type) {
144 (Some(OAuthTokenTypeHint::AccessToken) | None, TokenType::AccessToken) => {
145 let access_token = repo
146 .oauth2_access_token()
147 .find_by_token(&form.token)
148 .await?
149 .ok_or(RouteError::UnknownToken)?;
150
151 if !access_token.is_valid(clock.now()) {
152 return Err(RouteError::UnknownToken);
153 }
154 access_token.session_id
155 }
156
157 (Some(OAuthTokenTypeHint::RefreshToken) | None, TokenType::RefreshToken) => {
158 let refresh_token = repo
159 .oauth2_refresh_token()
160 .find_by_token(&form.token)
161 .await?
162 .ok_or(RouteError::UnknownToken)?;
163
164 if !refresh_token.is_valid() {
165 return Err(RouteError::UnknownToken);
166 }
167
168 refresh_token.session_id
169 }
170
171 (Some(OAuthTokenTypeHint::AccessToken | OAuthTokenTypeHint::RefreshToken) | None, _) => {
175 return Err(RouteError::UnknownToken);
176 }
177
178 (Some(_), _) => return Err(RouteError::UnsupportedTokenType),
179 };
180
181 let session = repo
182 .oauth2_session()
183 .lookup(session_id)
184 .await?
185 .ok_or(RouteError::UnknownToken)?;
186
187 if !session.is_valid() {
189 return Err(RouteError::UnknownToken);
190 }
191
192 if client.id != session.client_id {
195 return Err(RouteError::UnauthorizedClient);
196 }
197
198 activity_tracker
199 .record_oauth2_session(&clock, &session)
200 .await;
201
202 if let Some(user_id) = session.user_id {
205 let user = repo
207 .user()
208 .lookup(user_id)
209 .await?
210 .ok_or(RouteError::UnknownToken)?;
211
212 repo.queue_job()
214 .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
215 .await?;
216 }
217
218 repo.oauth2_session().finish(&clock, session).await?;
220
221 repo.save().await?;
222
223 Ok(())
224}
225
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use hyper::Request;
    use mas_data_model::{AccessToken, RefreshToken};
    use mas_router::SimpleRoute;
    use mas_storage::RepositoryAccess;
    use oauth2_types::{
        registration::ClientRegistrationResponse,
        requests::AccessTokenResponse,
        scope::{OPENID, Scope},
    };
    use sqlx::PgPool;

    use super::*;
    use crate::{
        oauth2::generate_token_pair,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// End-to-end check of the revocation endpoint. It verifies that:
    /// - revoking an access token invalidates it and also stops its paired
    ///   refresh token from being used;
    /// - revoking an already-revoked token still answers `200 OK`;
    /// - revoking tokens superseded by a refresh leaves the new access token
    ///   valid;
    /// - revoking the current refresh token invalidates the new access token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_revoke_access_token(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Dynamically register a client that authenticates with
        // `client_secret_post` and may use the refresh_token grant.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/callback"],
                "token_endpoint_auth_method": "client_secret_post",
                "response_types": ["code"],
                "grant_types": ["authorization_code", "refresh_token"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let client_registration: ClientRegistrationResponse = response.json();

        let client_id = client_registration.client_id;
        let client_secret = client_registration.client_secret.unwrap();

        // Set up a user, a browser session and an OAuth 2.0 session directly
        // through the repository, bypassing the authorization flow.
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();

        // Load the client that was just registered through the API.
        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        // Issue an access/refresh token pair for the session; the access token
        // expires after 5 minutes (expressed in microseconds).
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoke the access token: it must no longer be valid afterwards.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(!state.is_access_token_valid(&access_token).await);

        // Revoking the same token again still succeeds with 200 OK.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        // The refresh token of the revoked session can no longer be used.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);

        // Second scenario: start a fresh session with a new token pair.
        let mut repo = state.repository().await.unwrap();
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Use the refresh token, which yields a new token pair.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let old_access_token = access_token;
        let old_refresh_token = refresh_token;
        let AccessTokenResponse {
            access_token,
            refresh_token,
            ..
        } = response.json();
        // After the refresh, the new access token is valid and the old one is not.
        assert!(state.is_access_token_valid(&access_token).await);
        assert!(!state.is_access_token_valid(&old_access_token).await);

        // Revoking the superseded access token answers 200 OK but leaves the
        // new access token valid.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(state.is_access_token_valid(&access_token).await);

        // Same for the superseded (already used) refresh token.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the current refresh token ends the session, which also
        // invalidates the current access token.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(!state.is_access_token_valid(&access_token).await);
    }
}