mas_handlers/oauth2/
revoke.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use hyper::StatusCode;
9use mas_axum_utils::{
10    client_authorization::{ClientAuthorization, CredentialsVerificationError},
11    sentry::SentryEventID,
12};
13use mas_data_model::TokenType;
14use mas_iana::oauth::OAuthTokenTypeHint;
15use mas_keystore::Encrypter;
16use mas_storage::{
17    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
18    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
19};
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    requests::RevocationRequest,
23};
24use thiserror::Error;
25
26use crate::{BoundActivityTracker, impl_from_error_for_route};
27
28#[derive(Debug, Error)]
29pub(crate) enum RouteError {
30    #[error(transparent)]
31    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
32
33    #[error("bad request")]
34    BadRequest,
35
36    #[error("client not found")]
37    ClientNotFound,
38
39    #[error("client not allowed")]
40    ClientNotAllowed,
41
42    #[error("could not verify client credentials")]
43    ClientCredentialsVerification(#[from] CredentialsVerificationError),
44
45    #[error("client is unauthorized")]
46    UnauthorizedClient,
47
48    #[error("unsupported token type")]
49    UnsupportedTokenType,
50
51    #[error("unknown token")]
52    UnknownToken,
53}
54
55impl IntoResponse for RouteError {
56    fn into_response(self) -> axum::response::Response {
57        let event_id = sentry::capture_error(&self);
58        let response = match self {
59            Self::Internal(_) => (
60                StatusCode::INTERNAL_SERVER_ERROR,
61                Json(ClientError::from(ClientErrorCode::ServerError)),
62            )
63                .into_response(),
64
65            Self::BadRequest => (
66                StatusCode::BAD_REQUEST,
67                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
68            )
69                .into_response(),
70
71            Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
72                StatusCode::UNAUTHORIZED,
73                Json(ClientError::from(ClientErrorCode::InvalidClient)),
74            )
75                .into_response(),
76
77            Self::ClientNotAllowed | Self::UnauthorizedClient => (
78                StatusCode::UNAUTHORIZED,
79                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
80            )
81                .into_response(),
82
83            Self::UnsupportedTokenType => (
84                StatusCode::BAD_REQUEST,
85                Json(ClientError::from(ClientErrorCode::UnsupportedTokenType)),
86            )
87                .into_response(),
88
89            // If the token is unknown, we still return a 200 OK response.
90            Self::UnknownToken => StatusCode::OK.into_response(),
91        };
92
93        (SentryEventID::from(event_id), response).into_response()
94    }
95}
96
97impl_from_error_for_route!(mas_storage::RepositoryError);
98
99impl From<mas_data_model::TokenFormatError> for RouteError {
100    fn from(_e: mas_data_model::TokenFormatError) -> Self {
101        Self::UnknownToken
102    }
103}
104
105#[tracing::instrument(
106    name = "handlers.oauth2.revoke.post",
107    fields(client.id = client_authorization.client_id()),
108    skip_all,
109    err,
110)]
111pub(crate) async fn post(
112    clock: BoxClock,
113    mut rng: BoxRng,
114    State(http_client): State<reqwest::Client>,
115    mut repo: BoxRepository,
116    activity_tracker: BoundActivityTracker,
117    State(encrypter): State<Encrypter>,
118    client_authorization: ClientAuthorization<RevocationRequest>,
119) -> Result<impl IntoResponse, RouteError> {
120    let client = client_authorization
121        .credentials
122        .fetch(&mut repo)
123        .await?
124        .ok_or(RouteError::ClientNotFound)?;
125
126    let method = client
127        .token_endpoint_auth_method
128        .as_ref()
129        .ok_or(RouteError::ClientNotAllowed)?;
130
131    client_authorization
132        .credentials
133        .verify(&http_client, &encrypter, method, &client)
134        .await?;
135
136    let Some(form) = client_authorization.form else {
137        return Err(RouteError::BadRequest);
138    };
139
140    let token_type = TokenType::check(&form.token)?;
141
142    // Find the ID of the session to end.
143    let session_id = match (form.token_type_hint, token_type) {
144        (Some(OAuthTokenTypeHint::AccessToken) | None, TokenType::AccessToken) => {
145            let access_token = repo
146                .oauth2_access_token()
147                .find_by_token(&form.token)
148                .await?
149                .ok_or(RouteError::UnknownToken)?;
150
151            if !access_token.is_valid(clock.now()) {
152                return Err(RouteError::UnknownToken);
153            }
154            access_token.session_id
155        }
156
157        (Some(OAuthTokenTypeHint::RefreshToken) | None, TokenType::RefreshToken) => {
158            let refresh_token = repo
159                .oauth2_refresh_token()
160                .find_by_token(&form.token)
161                .await?
162                .ok_or(RouteError::UnknownToken)?;
163
164            if !refresh_token.is_valid() {
165                return Err(RouteError::UnknownToken);
166            }
167
168            refresh_token.session_id
169        }
170
171        // This case can happen if there is a mismatch between the token type hint and the guessed
172        // token type or if the token was a compat access/refresh token. In those cases, we return
173        // an unknown token error.
174        (Some(OAuthTokenTypeHint::AccessToken | OAuthTokenTypeHint::RefreshToken) | None, _) => {
175            return Err(RouteError::UnknownToken);
176        }
177
178        (Some(_), _) => return Err(RouteError::UnsupportedTokenType),
179    };
180
181    let session = repo
182        .oauth2_session()
183        .lookup(session_id)
184        .await?
185        .ok_or(RouteError::UnknownToken)?;
186
187    // Check that the session is still valid.
188    if !session.is_valid() {
189        return Err(RouteError::UnknownToken);
190    }
191
192    // Check that the client ending the session is the same as the client that
193    // created it.
194    if client.id != session.client_id {
195        return Err(RouteError::UnauthorizedClient);
196    }
197
198    activity_tracker
199        .record_oauth2_session(&clock, &session)
200        .await;
201
202    // If the session is associated with a user, make sure we schedule a device
203    // deletion job for all the devices associated with the session.
204    if let Some(user_id) = session.user_id {
205        // Fetch the user
206        let user = repo
207            .user()
208            .lookup(user_id)
209            .await?
210            .ok_or(RouteError::UnknownToken)?;
211
212        // Schedule a job to sync the devices of the user with the homeserver
213        repo.queue_job()
214            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
215            .await?;
216    }
217
218    // Now that we checked everything, we can end the session.
219    repo.oauth2_session().finish(&clock, session).await?;
220
221    repo.save().await?;
222
223    Ok(())
224}
225
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use hyper::Request;
    use mas_data_model::{AccessToken, RefreshToken};
    use mas_router::SimpleRoute;
    use mas_storage::RepositoryAccess;
    use oauth2_types::{
        registration::ClientRegistrationResponse,
        requests::AccessTokenResponse,
        scope::{OPENID, Scope},
    };
    use sqlx::PgPool;

    use super::*;
    use crate::{
        oauth2::generate_token_pair,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Exercises the revocation endpoint end to end: revoking through an
    /// access token, revoking the same token twice, and revoking stale and
    /// current tokens after a refresh grant rotated them.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_revoke_access_token(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a client authenticating with `client_secret_post`.
        let registration_response = state
            .request(
                Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(
                    serde_json::json!({
                        "client_uri": "https://example.com/",
                        "redirect_uris": ["https://example.com/callback"],
                        "token_endpoint_auth_method": "client_secret_post",
                        "response_types": ["code"],
                        "grant_types": ["authorization_code", "refresh_token"],
                    }),
                ),
            )
            .await;
        registration_response.assert_status(StatusCode::CREATED);

        let ClientRegistrationResponse {
            client_id,
            client_secret,
            ..
        } = registration_response.json();
        let client_secret = client_secret.unwrap();

        // Every revocation request authenticates with the registered client's
        // credentials; only the token and its type hint change.
        let revocation = |token: &str, hint: &str| {
            Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
                "token": token,
                "token_type_hint": hint,
                "client_id": client_id,
                "client_secret": client_secret,
            }))
        };

        // Refresh-token grant against the token endpoint, same credentials.
        let refresh_grant = |refresh_token: &str| {
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }))
        };

        // Duration handed to `generate_token_pair` (five minutes).
        let token_duration = Duration::microseconds(5 * 60 * 1000 * 1000);

        // Users and sessions can't be provisioned over HTTP alone, so go
        // through the repository directly.
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();

        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                token_duration,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking through the access token ends the session...
        state
            .request(revocation(&access_token, "access_token"))
            .await
            .assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);

        // ...and revoking it a second time still succeeds.
        state
            .request(revocation(&access_token, "access_token"))
            .await
            .assert_status(StatusCode::OK);

        // The refresh token went away together with the session.
        state
            .request(refresh_grant(&refresh_token))
            .await
            .assert_status(StatusCode::BAD_REQUEST);

        // Second round on a fresh session, this time ending it through the
        // refresh token.
        let mut repo = state.repository().await.unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        let (
            AccessToken {
                access_token: old_access_token,
                ..
            },
            RefreshToken {
                refresh_token: old_refresh_token,
                ..
            },
        ) = generate_token_pair(
            &mut state.rng(),
            &state.clock,
            &mut repo,
            &session,
            token_duration,
        )
        .await
        .unwrap();

        repo.save().await.unwrap();

        // Rotate the tokens with a refresh grant.
        let refresh_response = state.request(refresh_grant(&old_refresh_token)).await;
        refresh_response.assert_status(StatusCode::OK);

        let AccessTokenResponse {
            access_token,
            refresh_token,
            ..
        } = refresh_response.json();
        assert!(state.is_access_token_valid(&access_token).await);
        assert!(!state.is_access_token_valid(&old_access_token).await);

        // Tokens from before the rotation must not affect the live session.
        state
            .request(revocation(&old_access_token, "access_token"))
            .await
            .assert_status(StatusCode::OK);
        assert!(state.is_access_token_valid(&access_token).await);

        state
            .request(revocation(&old_refresh_token, "refresh_token"))
            .await
            .assert_status(StatusCode::OK);
        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the current refresh token ends the session.
        let final_revocation =
            Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
                "token": refresh_token,
                "token_type_hint": "refresh_token",
                "client_id": client_id,
                "client_secret": client_secret,
            }));
        state
            .request(final_revocation)
            .await
            .assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);
    }
}