mas_handlers/oauth2/
token.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::typed_header::TypedHeader;
11use chrono::Duration;
12use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
13use hyper::StatusCode;
14use mas_axum_utils::{
15    client_authorization::{ClientAuthorization, CredentialsVerificationError},
16    record_error,
17};
18use mas_data_model::{
19    AuthorizationGrantStage, BoxClock, BoxRng, Client, Clock, Device, DeviceCodeGrantState,
20    SiteConfig, TokenType,
21};
22use mas_i18n::DataLocale;
23use mas_keystore::{Encrypter, Keystore};
24use mas_matrix::HomeserverConnection;
25use mas_oidc_client::types::scope::ScopeToken;
26use mas_policy::Policy;
27use mas_router::UrlBuilder;
28use mas_storage::{
29    BoxRepository, RepositoryAccess,
30    oauth2::{
31        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
32        OAuth2RefreshTokenRepository, OAuth2SessionRepository,
33    },
34    user::BrowserSessionRepository,
35};
36use mas_templates::{DeviceNameContext, TemplateContext, Templates};
37use oauth2_types::{
38    errors::{ClientError, ClientErrorCode},
39    pkce::CodeChallengeError,
40    requests::{
41        AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
42        DeviceCodeGrant, GrantType, RefreshTokenGrant,
43    },
44    scope,
45};
46use opentelemetry::{Key, KeyValue, metrics::Counter};
47use thiserror::Error;
48use tracing::{debug, info, warn};
49use ulid::Ulid;
50
51use super::{generate_id_token, generate_token_pair};
52use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
53
54static TOKEN_REQUEST_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
55    METER
56        .u64_counter("mas.oauth2.token_request")
57        .with_description("How many OAuth 2.0 token requests have gone through")
58        .with_unit("{request}")
59        .build()
60});
61const GRANT_TYPE: Key = Key::from_static_str("grant_type");
62const RESULT: Key = Key::from_static_str("successful");
63
/// Every way a request to the token endpoint can fail in this module.
///
/// The `#[error]` strings are meant for logs and error reports; what the
/// client actually receives is decided when the error is turned into a
/// response.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Unexpected failure from a lower layer (storage, templates, policy
    /// engine, locale parsing, ID token signing). NOTE(review): presumably
    /// built by the `impl_from_error_for_route!` conversions — confirm at the
    /// macro definition.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// The request body was missing, or PKCE parameters were only half
    /// present (a challenge without a verifier, or vice-versa).
    #[error("bad request")]
    BadRequest,

    /// The presented `code_verifier` did not match the stored PKCE challenge.
    #[error("pkce verification failed")]
    PkceVerification(#[from] CodeChallengeError),

    /// The client referenced by the credentials does not exist.
    #[error("client not found")]
    ClientNotFound,

    /// The client has no token endpoint authentication method configured.
    #[error("client not allowed to use the token endpoint: {0}")]
    ClientNotAllowed(Ulid),

    /// The client presented credentials that failed verification.
    #[error("invalid client credentials for client {client_id}")]
    InvalidClientCredentials {
        client_id: Ulid,
        #[source]
        source: CredentialsVerificationError,
    },

    /// Credential verification itself failed on our side (the underlying
    /// error reports `is_internal()`), so we could not decide either way.
    #[error("could not verify client credentials for client {client_id}")]
    ClientCredentialsVerification {
        client_id: Ulid,
        #[source]
        source: CredentialsVerificationError,
    },

    /// No authorization grant (or device code grant) matches the presented
    /// code.
    #[error("grant not found")]
    GrantNotFound,

    /// The authorization grant exists but cannot be exchanged: cancelled,
    /// still pending, already exchanged, too old, or lacking a user session.
    #[error("invalid grant {0}")]
    InvalidGrant(Ulid),

    /// No refresh token matches the presented one.
    #[error("refresh token not found")]
    RefreshTokenNotFound,

    /// The refresh token (or its successor) was already used or revoked.
    #[error("refresh token {0} is invalid")]
    RefreshTokenInvalid(Ulid),

    /// The OAuth 2.0 session behind the refresh token is no longer valid.
    #[error("session {0} is invalid")]
    SessionInvalid(Ulid),

    /// The authenticated client is not the one the session/grant belongs to.
    #[error("client id mismatch: expected {expected}, got {actual}")]
    ClientIDMismatch { expected: Ulid, actual: Ulid },

    /// The policy engine rejected the request; the violations are reported
    /// to the client.
    #[error("policy denied the request: {0}")]
    DeniedByPolicy(mas_policy::EvaluationResult),

    /// The request used a grant type this endpoint does not handle.
    #[error("unsupported grant type")]
    UnsupportedGrantType,

    /// The client is not registered for the requested grant type.
    #[error("client {0} is not authorized to use this grant type")]
    UnauthorizedClient(Ulid),

    /// The authorization code was issued to another client. (The variant
    /// name's misspelling is kept so existing references keep compiling.)
    #[error("unexpected client {was} (expected {expected})")]
    UnexptectedClient { was: Ulid, expected: Ulid },

    /// A browser session referenced by a grant/session could not be loaded.
    #[error("failed to load browser session {0}")]
    NoSuchBrowserSession(Ulid),

    /// An OAuth 2.0 session referenced by a grant/token could not be loaded.
    #[error("failed to load oauth session {0}")]
    NoSuchOAuthSession(Ulid),

    /// A consumed refresh token points to a successor that cannot be loaded.
    #[error(
        "failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
    )]
    NoSuchNextRefreshToken { next: Ulid, previous: Ulid },

    /// The access token paired with the successor refresh token cannot be
    /// loaded.
    #[error(
        "failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})"
    )]
    NoSuchNextAccessToken {
        access_token: Ulid,
        refresh_token: Ulid,
    },

    /// A still-valid successor refresh token has no access token attached.
    #[error("no access token associated with the refresh token {refresh_token:?}")]
    NoAccessTokenOnRefreshToken { refresh_token: Ulid },

    /// The device code grant is past its expiry time.
    #[error("device code grant expired")]
    DeviceCodeExpired,

    /// The user has not yet approved or rejected the device code grant.
    #[error("device code grant is still pending")]
    DeviceCodePending,

    /// The user rejected the device code grant.
    #[error("device code grant was rejected")]
    DeviceCodeRejected,

    /// The device code grant was already exchanged for tokens.
    #[error("device code grant was already exchanged")]
    DeviceCodeExchanged,

    /// The homeserver refused to create/update a device requested in the
    /// scope.
    #[error("failed to provision device")]
    ProvisionDeviceFailed(#[source] anyhow::Error),
}
162
163impl IntoResponse for RouteError {
164    fn into_response(self) -> axum::response::Response {
165        let sentry_event_id = record_error!(
166            self,
167            Self::Internal(_)
168                | Self::ClientCredentialsVerification { .. }
169                | Self::NoSuchBrowserSession(_)
170                | Self::NoSuchOAuthSession(_)
171                | Self::ProvisionDeviceFailed(_)
172                | Self::NoSuchNextRefreshToken { .. }
173                | Self::NoSuchNextAccessToken { .. }
174                | Self::NoAccessTokenOnRefreshToken { .. }
175        );
176
177        TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
178
179        let response = match self {
180            Self::Internal(_)
181            | Self::ClientCredentialsVerification { .. }
182            | Self::NoSuchBrowserSession(_)
183            | Self::NoSuchOAuthSession(_)
184            | Self::ProvisionDeviceFailed(_)
185            | Self::NoSuchNextRefreshToken { .. }
186            | Self::NoSuchNextAccessToken { .. }
187            | Self::NoAccessTokenOnRefreshToken { .. } => (
188                StatusCode::INTERNAL_SERVER_ERROR,
189                Json(ClientError::from(ClientErrorCode::ServerError)),
190            ),
191
192            Self::BadRequest => (
193                StatusCode::BAD_REQUEST,
194                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
195            ),
196
197            Self::PkceVerification(err) => (
198                StatusCode::BAD_REQUEST,
199                Json(
200                    ClientError::from(ClientErrorCode::InvalidGrant)
201                        .with_description(format!("PKCE verification failed: {err}")),
202                ),
203            ),
204
205            Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
206                StatusCode::UNAUTHORIZED,
207                Json(ClientError::from(ClientErrorCode::InvalidClient)),
208            ),
209
210            Self::ClientNotAllowed(_)
211            | Self::UnauthorizedClient(_)
212            | Self::UnexptectedClient { .. } => (
213                StatusCode::UNAUTHORIZED,
214                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
215            ),
216
217            Self::DeniedByPolicy(evaluation) => (
218                StatusCode::FORBIDDEN,
219                Json(
220                    ClientError::from(ClientErrorCode::InvalidScope).with_description(
221                        evaluation
222                            .violations
223                            .into_iter()
224                            .map(|violation| violation.msg)
225                            .collect::<Vec<_>>()
226                            .join(", "),
227                    ),
228                ),
229            ),
230
231            Self::DeviceCodeRejected => (
232                StatusCode::FORBIDDEN,
233                Json(ClientError::from(ClientErrorCode::AccessDenied)),
234            ),
235
236            Self::DeviceCodeExpired => (
237                StatusCode::FORBIDDEN,
238                Json(ClientError::from(ClientErrorCode::ExpiredToken)),
239            ),
240
241            Self::DeviceCodePending => (
242                StatusCode::FORBIDDEN,
243                Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
244            ),
245
246            Self::InvalidGrant(_)
247            | Self::DeviceCodeExchanged
248            | Self::RefreshTokenNotFound
249            | Self::RefreshTokenInvalid(_)
250            | Self::SessionInvalid(_)
251            | Self::ClientIDMismatch { .. }
252            | Self::GrantNotFound => (
253                StatusCode::BAD_REQUEST,
254                Json(ClientError::from(ClientErrorCode::InvalidGrant)),
255            ),
256
257            Self::UnsupportedGrantType => (
258                StatusCode::BAD_REQUEST,
259                Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
260            ),
261        };
262
263        (sentry_event_id, response).into_response()
264    }
265}
266
267impl_from_error_for_route!(mas_i18n::DataError);
268impl_from_error_for_route!(mas_templates::TemplateError);
269impl_from_error_for_route!(mas_storage::RepositoryError);
270impl_from_error_for_route!(mas_policy::EvaluationError);
271impl_from_error_for_route!(super::IdTokenSignatureError);
272
273#[tracing::instrument(
274    name = "handlers.oauth2.token.post",
275    fields(client.id = client_authorization.client_id()),
276    skip_all,
277)]
278pub(crate) async fn post(
279    mut rng: BoxRng,
280    clock: BoxClock,
281    State(http_client): State<reqwest::Client>,
282    State(key_store): State<Keystore>,
283    State(url_builder): State<UrlBuilder>,
284    activity_tracker: BoundActivityTracker,
285    mut repo: BoxRepository,
286    State(homeserver): State<Arc<dyn HomeserverConnection>>,
287    State(site_config): State<SiteConfig>,
288    State(encrypter): State<Encrypter>,
289    State(templates): State<Templates>,
290    policy: Policy,
291    user_agent: Option<TypedHeader<headers::UserAgent>>,
292    client_authorization: ClientAuthorization<AccessTokenRequest>,
293) -> Result<impl IntoResponse, RouteError> {
294    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
295    let client = client_authorization
296        .credentials
297        .fetch(&mut repo)
298        .await?
299        .ok_or(RouteError::ClientNotFound)?;
300
301    let method = client
302        .token_endpoint_auth_method
303        .as_ref()
304        .ok_or(RouteError::ClientNotAllowed(client.id))?;
305
306    client_authorization
307        .credentials
308        .verify(&http_client, &encrypter, method, &client)
309        .await
310        .map_err(|err| {
311            // Classify the error differntly, depending on whether it's an 'internal' error,
312            // or just because the client presented invalid credentials.
313            if err.is_internal() {
314                RouteError::ClientCredentialsVerification {
315                    client_id: client.id,
316                    source: err,
317                }
318            } else {
319                RouteError::InvalidClientCredentials {
320                    client_id: client.id,
321                    source: err,
322                }
323            }
324        })?;
325
326    let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
327
328    let grant_type = form.grant_type();
329
330    let (reply, repo) = match form {
331        AccessTokenRequest::AuthorizationCode(grant) => {
332            authorization_code_grant(
333                &mut rng,
334                &clock,
335                &activity_tracker,
336                &grant,
337                &client,
338                &key_store,
339                &url_builder,
340                &site_config,
341                repo,
342                &homeserver,
343                &templates,
344                user_agent,
345            )
346            .await?
347        }
348        AccessTokenRequest::RefreshToken(grant) => {
349            refresh_token_grant(
350                &mut rng,
351                &clock,
352                &activity_tracker,
353                &grant,
354                &client,
355                &site_config,
356                repo,
357                user_agent,
358            )
359            .await?
360        }
361        AccessTokenRequest::ClientCredentials(grant) => {
362            client_credentials_grant(
363                &mut rng,
364                &clock,
365                &activity_tracker,
366                &grant,
367                &client,
368                &site_config,
369                repo,
370                policy,
371                user_agent,
372            )
373            .await?
374        }
375        AccessTokenRequest::DeviceCode(grant) => {
376            device_code_grant(
377                &mut rng,
378                &clock,
379                &activity_tracker,
380                &grant,
381                &client,
382                &key_store,
383                &url_builder,
384                &site_config,
385                repo,
386                &homeserver,
387                user_agent,
388            )
389            .await?
390        }
391        _ => {
392            return Err(RouteError::UnsupportedGrantType);
393        }
394    };
395
396    repo.save().await?;
397
398    TOKEN_REQUEST_COUNTER.add(
399        1,
400        &[
401            KeyValue::new(GRANT_TYPE, grant_type),
402            KeyValue::new(RESULT, "success"),
403        ],
404    );
405
406    let mut headers = HeaderMap::new();
407    headers.typed_insert(CacheControl::new().with_no_store());
408    headers.typed_insert(Pragma::no_cache());
409
410    Ok((headers, Json(reply)))
411}
412
413async fn authorization_code_grant(
414    mut rng: &mut BoxRng,
415    clock: &impl Clock,
416    activity_tracker: &BoundActivityTracker,
417    grant: &AuthorizationCodeGrant,
418    client: &Client,
419    key_store: &Keystore,
420    url_builder: &UrlBuilder,
421    site_config: &SiteConfig,
422    mut repo: BoxRepository,
423    homeserver: &Arc<dyn HomeserverConnection>,
424    templates: &Templates,
425    user_agent: Option<String>,
426) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
427    // Check that the client is allowed to use this grant type
428    if !client.grant_types.contains(&GrantType::AuthorizationCode) {
429        return Err(RouteError::UnauthorizedClient(client.id));
430    }
431
432    let authz_grant = repo
433        .oauth2_authorization_grant()
434        .find_by_code(&grant.code)
435        .await?
436        .ok_or(RouteError::GrantNotFound)?;
437
438    let now = clock.now();
439
440    let session_id = match authz_grant.stage {
441        AuthorizationGrantStage::Cancelled { cancelled_at } => {
442            debug!(%cancelled_at, "Authorization grant was cancelled");
443            return Err(RouteError::InvalidGrant(authz_grant.id));
444        }
445        AuthorizationGrantStage::Exchanged {
446            exchanged_at,
447            fulfilled_at,
448            session_id,
449        } => {
450            warn!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
451
452            // Ending the session if the token was already exchanged more than 20s ago
453            if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
454                warn!(oauth_session.id = %session_id, "Ending potentially compromised session");
455                let session = repo
456                    .oauth2_session()
457                    .lookup(session_id)
458                    .await?
459                    .ok_or(RouteError::NoSuchOAuthSession(session_id))?;
460
461                //if !session.is_finished() {
462                repo.oauth2_session().finish(clock, session).await?;
463                repo.save().await?;
464                //}
465            }
466
467            return Err(RouteError::InvalidGrant(authz_grant.id));
468        }
469        AuthorizationGrantStage::Pending => {
470            warn!("Authorization grant has not been fulfilled yet");
471            return Err(RouteError::InvalidGrant(authz_grant.id));
472        }
473        AuthorizationGrantStage::Fulfilled {
474            session_id,
475            fulfilled_at,
476        } => {
477            if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
478                warn!("Code exchange took more than 10 minutes");
479                return Err(RouteError::InvalidGrant(authz_grant.id));
480            }
481
482            session_id
483        }
484    };
485
486    let mut session = repo
487        .oauth2_session()
488        .lookup(session_id)
489        .await?
490        .ok_or(RouteError::NoSuchOAuthSession(session_id))?;
491
492    // Generate a device name
493    let lang: DataLocale = authz_grant.locale.as_deref().unwrap_or("en").parse()?;
494    let ctx = DeviceNameContext::new(client.clone(), user_agent.clone()).with_language(lang);
495    let device_name = templates.render_device_name(&ctx)?;
496
497    if let Some(user_agent) = user_agent {
498        session = repo
499            .oauth2_session()
500            .record_user_agent(session, user_agent)
501            .await?;
502    }
503
504    // This should never happen, since we looked up in the database using the code
505    let code = authz_grant
506        .code
507        .as_ref()
508        .ok_or(RouteError::InvalidGrant(authz_grant.id))?;
509
510    if client.id != session.client_id {
511        return Err(RouteError::UnexptectedClient {
512            was: client.id,
513            expected: session.client_id,
514        });
515    }
516
517    match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
518        (None, None) => {}
519        // We have a challenge but no verifier (or vice-versa)? Bad request.
520        (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest),
521        // If we have both, we need to check the code validity
522        (Some(pkce), Some(verifier)) => {
523            pkce.verify(verifier)?;
524        }
525    }
526
527    let Some(user_session_id) = session.user_session_id else {
528        tracing::warn!("No user session associated with this OAuth2 session");
529        return Err(RouteError::InvalidGrant(authz_grant.id));
530    };
531
532    let browser_session = repo
533        .browser_session()
534        .lookup(user_session_id)
535        .await?
536        .ok_or(RouteError::NoSuchBrowserSession(user_session_id))?;
537
538    let last_authentication = repo
539        .browser_session()
540        .get_last_authentication(&browser_session)
541        .await?;
542
543    let ttl = site_config.access_token_ttl;
544    let (access_token, refresh_token) =
545        generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;
546
547    let id_token = if session.scope.contains(&scope::OPENID) {
548        Some(generate_id_token(
549            &mut rng,
550            clock,
551            url_builder,
552            key_store,
553            client,
554            Some(&authz_grant),
555            &browser_session,
556            Some(&access_token),
557            last_authentication.as_ref(),
558        )?)
559    } else {
560        None
561    };
562
563    let mut params = AccessTokenResponse::new(access_token.access_token)
564        .with_expires_in(ttl)
565        .with_refresh_token(refresh_token.refresh_token)
566        .with_scope(session.scope.clone());
567
568    if let Some(id_token) = id_token {
569        params = params.with_id_token(id_token);
570    }
571
572    // Lock the user sync to make sure we don't get into a race condition
573    repo.user()
574        .acquire_lock_for_sync(&browser_session.user)
575        .await?;
576
577    // Look for device to provision
578    for scope in &*session.scope {
579        if let Some(device) = Device::from_scope_token(scope) {
580            homeserver
581                .upsert_device(
582                    &browser_session.user.username,
583                    device.as_str(),
584                    Some(&device_name),
585                )
586                .await
587                .map_err(RouteError::ProvisionDeviceFailed)?;
588        }
589    }
590
591    repo.oauth2_authorization_grant()
592        .exchange(clock, authz_grant)
593        .await?;
594
595    // XXX: there is a potential (but unlikely) race here, where the activity for
596    // the session is recorded before the transaction is committed. We would have to
597    // save the repository here to fix that.
598    activity_tracker
599        .record_oauth2_session(clock, &session)
600        .await;
601
602    Ok((params, repo))
603}
604
/// Handles the `refresh_token` grant (RFC 6749 §6).
///
/// Rotates the presented refresh token:
/// 1. Issues a new access/refresh token pair for the same session.
/// 2. Marks the presented token as consumed and links it to its successor.
/// 3. Revokes the access token that was paired with the presented refresh
///    token, unless that token is already revoked.
///
/// An already-consumed refresh token is accepted once more in one case, a
/// "double refresh": its successor refresh token is still valid and the
/// successor's access token was never used. The client presumably lost the
/// previous response. The unused successor tokens are then revoked and a
/// fresh pair is issued. Any other reuse is rejected as a possible replay.
///
/// On success the repository is handed back so the caller can commit it.
///
/// # Errors
///
/// - [`RouteError::UnauthorizedClient`] if the client may not use this grant type.
/// - [`RouteError::RefreshTokenNotFound`] if the token is unknown.
/// - [`RouteError::SessionInvalid`] if the session is no longer valid.
/// - [`RouteError::ClientIDMismatch`] if the token belongs to another client.
/// - [`RouteError::RefreshTokenInvalid`] on a detected replay.
/// - `NoSuch*` / `NoAccessTokenOnRefreshToken` variants when linked records are
///   missing, and internal errors on storage failures.
async fn refresh_token_grant(
    rng: &mut BoxRng,
    clock: &impl Clock,
    activity_tracker: &BoundActivityTracker,
    grant: &RefreshTokenGrant,
    client: &Client,
    site_config: &SiteConfig,
    mut repo: BoxRepository,
    user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
    // Check that the client is allowed to use this grant type
    if !client.grant_types.contains(&GrantType::RefreshToken) {
        return Err(RouteError::UnauthorizedClient(client.id));
    }

    let refresh_token = repo
        .oauth2_refresh_token()
        .find_by_token(&grant.refresh_token)
        .await?
        .ok_or(RouteError::RefreshTokenNotFound)?;

    let mut session = repo
        .oauth2_session()
        .lookup(refresh_token.session_id)
        .await?
        .ok_or(RouteError::NoSuchOAuthSession(refresh_token.session_id))?;

    // Let's for now record the user agent on each refresh, that should be
    // responsive enough and not too much of a burden on the database.
    // NOTE(review): this write happens before the validity checks below; it is
    // presumably discarded on error, since the caller only saves the repository
    // after a successful grant — confirm.
    if let Some(user_agent) = user_agent {
        session = repo
            .oauth2_session()
            .record_user_agent(session, user_agent)
            .await?;
    }

    if !session.is_valid() {
        return Err(RouteError::SessionInvalid(session.id));
    }

    if client.id != session.client_id {
        // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
        return Err(RouteError::ClientIDMismatch {
            expected: session.client_id,
            actual: client.id,
        });
    }

    if !refresh_token.is_valid() {
        // We're seeing a refresh token that has already been consumed; this might be a
        // double-refresh or a replay attack

        // First, get the next refresh token
        let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
            // If we don't have a 'next' refresh token, it may just be because this was
            // before we were recording those. Let's just treat it as a replay.
            return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
        };

        let Some(next_refresh_token) = repo
            .oauth2_refresh_token()
            .lookup(next_refresh_token_id)
            .await?
        else {
            return Err(RouteError::NoSuchNextRefreshToken {
                next: next_refresh_token_id,
                previous: refresh_token.id,
            });
        };

        // Check if the next refresh token was already consumed or not
        if !next_refresh_token.is_valid() {
            // XXX: This is a replay, we *may* want to invalidate the session
            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
        }

        // Check if the associated access token was already used
        let Some(access_token_id) = next_refresh_token.access_token_id else {
            // This should in theory not happen: this means an access token got cleaned up,
            // but the refresh token was still valid.
            return Err(RouteError::NoAccessTokenOnRefreshToken {
                refresh_token: next_refresh_token.id,
            });
        };

        // Load it
        let next_access_token = repo
            .oauth2_access_token()
            .lookup(access_token_id)
            .await?
            .ok_or(RouteError::NoSuchNextAccessToken {
                access_token: access_token_id,
                refresh_token: next_refresh_token_id,
            })?;

        // A used successor access token means the client did receive the
        // previous response, so this reuse is not a lost-response retry.
        if next_access_token.is_used() {
            // XXX: This is a replay, we *may* want to invalidate the session
            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
        }

        // Looks like it's a double-refresh, client lost their refresh token on
        // the way back. Let's revoke the unused access and refresh tokens, and
        // issue new ones
        info!(
            oauth2_session.id = %session.id,
            oauth2_client.id = %client.id,
            %refresh_token.id,
            "Refresh token already used, but issued refresh and access tokens are unused. Assuming those were lost; revoking those and reissuing new ones."
        );

        repo.oauth2_access_token()
            .revoke(clock, next_access_token)
            .await?;

        repo.oauth2_refresh_token()
            .revoke(clock, next_refresh_token)
            .await?;
    }

    // NOTE(review): like the other grants, the activity is recorded before the
    // caller commits the transaction.
    activity_tracker
        .record_oauth2_session(clock, &session)
        .await;

    let ttl = site_config.access_token_ttl;
    let (new_access_token, new_refresh_token) =
        generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;

    // Link the presented token to its successor. NOTE(review): in the
    // double-refresh path the token is already consumed, so this relies on the
    // repository accepting a re-consume — confirm in the storage layer.
    let refresh_token = repo
        .oauth2_refresh_token()
        .consume(clock, refresh_token, &new_refresh_token)
        .await?;

    // Revoke the access token that was issued alongside the presented refresh token
    if let Some(access_token_id) = refresh_token.access_token_id {
        let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
        if let Some(access_token) = access_token {
            // If it is a double-refresh, it might already be revoked
            if !access_token.state.is_revoked() {
                repo.oauth2_access_token()
                    .revoke(clock, access_token)
                    .await?;
            }
        }
    }

    let params = AccessTokenResponse::new(new_access_token.access_token)
        .with_expires_in(ttl)
        .with_refresh_token(new_refresh_token.refresh_token)
        .with_scope(session.scope);

    Ok((params, repo))
}
756
757async fn client_credentials_grant(
758    rng: &mut BoxRng,
759    clock: &impl Clock,
760    activity_tracker: &BoundActivityTracker,
761    grant: &ClientCredentialsGrant,
762    client: &Client,
763    site_config: &SiteConfig,
764    mut repo: BoxRepository,
765    mut policy: Policy,
766    user_agent: Option<String>,
767) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
768    // Check that the client is allowed to use this grant type
769    if !client.grant_types.contains(&GrantType::ClientCredentials) {
770        return Err(RouteError::UnauthorizedClient(client.id));
771    }
772
773    // Default to an empty scope if none is provided
774    let scope = grant
775        .scope
776        .clone()
777        .unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
778
779    // Make the request go through the policy engine
780    let res = policy
781        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
782            user: None,
783            client,
784            scope: &scope,
785            grant_type: mas_policy::GrantType::ClientCredentials,
786            requester: mas_policy::Requester {
787                ip_address: activity_tracker.ip(),
788                user_agent: user_agent.clone(),
789            },
790        })
791        .await?;
792    if !res.valid() {
793        return Err(RouteError::DeniedByPolicy(res));
794    }
795
796    // Start the session
797    let mut session = repo
798        .oauth2_session()
799        .add_from_client_credentials(rng, clock, client, scope)
800        .await?;
801
802    if let Some(user_agent) = user_agent {
803        session = repo
804            .oauth2_session()
805            .record_user_agent(session, user_agent)
806            .await?;
807    }
808
809    let ttl = site_config.access_token_ttl;
810    let access_token_str = TokenType::AccessToken.generate(rng);
811
812    let access_token = repo
813        .oauth2_access_token()
814        .add(rng, clock, &session, access_token_str, Some(ttl))
815        .await?;
816
817    let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
818
819    // XXX: there is a potential (but unlikely) race here, where the activity for
820    // the session is recorded before the transaction is committed. We would have to
821    // save the repository here to fix that.
822    activity_tracker
823        .record_oauth2_session(clock, &session)
824        .await;
825
826    if !session.scope.is_empty() {
827        // We only return the scope if it's not empty
828        params = params.with_scope(session.scope);
829    }
830
831    Ok((params, repo))
832}
833
834async fn device_code_grant(
835    rng: &mut BoxRng,
836    clock: &impl Clock,
837    activity_tracker: &BoundActivityTracker,
838    grant: &DeviceCodeGrant,
839    client: &Client,
840    key_store: &Keystore,
841    url_builder: &UrlBuilder,
842    site_config: &SiteConfig,
843    mut repo: BoxRepository,
844    homeserver: &Arc<dyn HomeserverConnection>,
845    user_agent: Option<String>,
846) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
847    // Check that the client is allowed to use this grant type
848    if !client.grant_types.contains(&GrantType::DeviceCode) {
849        return Err(RouteError::UnauthorizedClient(client.id));
850    }
851
852    let grant = repo
853        .oauth2_device_code_grant()
854        .find_by_device_code(&grant.device_code)
855        .await?
856        .ok_or(RouteError::GrantNotFound)?;
857
858    // Check that the client match
859    if client.id != grant.client_id {
860        return Err(RouteError::ClientIDMismatch {
861            expected: grant.client_id,
862            actual: client.id,
863        });
864    }
865
866    if grant.expires_at < clock.now() {
867        return Err(RouteError::DeviceCodeExpired);
868    }
869
870    let browser_session_id = match &grant.state {
871        DeviceCodeGrantState::Pending => {
872            return Err(RouteError::DeviceCodePending);
873        }
874        DeviceCodeGrantState::Rejected { .. } => {
875            return Err(RouteError::DeviceCodeRejected);
876        }
877        DeviceCodeGrantState::Exchanged { .. } => {
878            return Err(RouteError::DeviceCodeExchanged);
879        }
880        DeviceCodeGrantState::Fulfilled {
881            browser_session_id, ..
882        } => *browser_session_id,
883    };
884
885    let browser_session = repo
886        .browser_session()
887        .lookup(browser_session_id)
888        .await?
889        .ok_or(RouteError::NoSuchBrowserSession(browser_session_id))?;
890
891    // Start the session
892    let mut session = repo
893        .oauth2_session()
894        .add_from_browser_session(rng, clock, client, &browser_session, grant.scope.clone())
895        .await?;
896
897    repo.oauth2_device_code_grant()
898        .exchange(clock, grant, &session)
899        .await?;
900
901    // XXX: should we get the user agent from the device code grant instead?
902    if let Some(user_agent) = user_agent {
903        session = repo
904            .oauth2_session()
905            .record_user_agent(session, user_agent)
906            .await?;
907    }
908
909    let ttl = site_config.access_token_ttl;
910    let access_token_str = TokenType::AccessToken.generate(rng);
911
912    let access_token = repo
913        .oauth2_access_token()
914        .add(rng, clock, &session, access_token_str, Some(ttl))
915        .await?;
916
917    let mut params =
918        AccessTokenResponse::new(access_token.access_token.clone()).with_expires_in(ttl);
919
920    // If the client uses the refresh token grant type, we also generate a refresh
921    // token
922    if client.grant_types.contains(&GrantType::RefreshToken) {
923        let refresh_token_str = TokenType::RefreshToken.generate(rng);
924
925        let refresh_token = repo
926            .oauth2_refresh_token()
927            .add(rng, clock, &session, &access_token, refresh_token_str)
928            .await?;
929
930        params = params.with_refresh_token(refresh_token.refresh_token);
931    }
932
933    // If the client asked for an ID token, we generate one
934    if session.scope.contains(&scope::OPENID) {
935        let id_token = generate_id_token(
936            rng,
937            clock,
938            url_builder,
939            key_store,
940            client,
941            None,
942            &browser_session,
943            Some(&access_token),
944            None,
945        )?;
946
947        params = params.with_id_token(id_token);
948    }
949
950    // Lock the user sync to make sure we don't get into a race condition
951    repo.user()
952        .acquire_lock_for_sync(&browser_session.user)
953        .await?;
954
955    // Look for device to provision
956    for scope in &*session.scope {
957        if let Some(device) = Device::from_scope_token(scope) {
958            homeserver
959                .upsert_device(&browser_session.user.username, device.as_str(), None)
960                .await
961                .map_err(RouteError::ProvisionDeviceFailed)?;
962        }
963    }
964
965    // XXX: there is a potential (but unlikely) race here, where the activity for
966    // the session is recorded before the transaction is committed. We would have to
967    // save the repository here to fix that.
968    activity_tracker
969        .record_oauth2_session(clock, &session)
970        .await;
971
972    if !session.scope.is_empty() {
973        // We only return the scope if it's not empty
974        params = params.with_scope(session.scope);
975    }
976
977    Ok((params, repo))
978}
979
980#[cfg(test)]
981mod tests {
982    use hyper::Request;
983    use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
984    use mas_router::SimpleRoute;
985    use oauth2_types::{
986        registration::ClientRegistrationResponse,
987        requests::{DeviceAuthorizationResponse, ResponseMode},
988        scope::{OPENID, Scope},
989    };
990    use sqlx::PgPool;
991
992    use super::*;
993    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
994
995    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
996    async fn test_auth_code_grant(pool: PgPool) {
997        setup();
998        let state = TestState::from_pool(pool).await.unwrap();
999
1000        // Provision a client
1001        let request =
1002            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1003                "client_uri": "https://example.com/",
1004                "redirect_uris": ["https://example.com/callback"],
1005                "token_endpoint_auth_method": "none",
1006                "response_types": ["code"],
1007                "grant_types": ["authorization_code"],
1008            }));
1009
1010        let response = state.request(request).await;
1011        response.assert_status(StatusCode::CREATED);
1012
1013        let ClientRegistrationResponse { client_id, .. } = response.json();
1014
1015        // Let's provision a user and create a session for them. This part is hard to
1016        // test with just HTTP requests, so we'll use the repository directly.
1017        let mut repo = state.repository().await.unwrap();
1018
1019        let user = repo
1020            .user()
1021            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1022            .await
1023            .unwrap();
1024
1025        let browser_session = repo
1026            .browser_session()
1027            .add(&mut state.rng(), &state.clock, &user, None)
1028            .await
1029            .unwrap();
1030
1031        // Lookup the client in the database.
1032        let client = repo
1033            .oauth2_client()
1034            .find_by_client_id(&client_id)
1035            .await
1036            .unwrap()
1037            .unwrap();
1038
1039        // Start a grant
1040        let code = "thisisaverysecurecode";
1041        let grant = repo
1042            .oauth2_authorization_grant()
1043            .add(
1044                &mut state.rng(),
1045                &state.clock,
1046                &client,
1047                "https://example.com/redirect".parse().unwrap(),
1048                Scope::from_iter([OPENID]),
1049                Some(AuthorizationCode {
1050                    code: code.to_owned(),
1051                    pkce: None,
1052                }),
1053                Some("state".to_owned()),
1054                Some("nonce".to_owned()),
1055                ResponseMode::Query,
1056                false,
1057                None,
1058                None,
1059            )
1060            .await
1061            .unwrap();
1062
1063        let session = repo
1064            .oauth2_session()
1065            .add_from_browser_session(
1066                &mut state.rng(),
1067                &state.clock,
1068                &client,
1069                &browser_session,
1070                grant.scope.clone(),
1071            )
1072            .await
1073            .unwrap();
1074
1075        // And fulfill it
1076        let grant = repo
1077            .oauth2_authorization_grant()
1078            .fulfill(&state.clock, &session, grant)
1079            .await
1080            .unwrap();
1081
1082        repo.save().await.unwrap();
1083
1084        // Now call the token endpoint to get an access token.
1085        let request =
1086            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1087                "grant_type": "authorization_code",
1088                "code": code,
1089                "redirect_uri": grant.redirect_uri,
1090                "client_id": client.client_id,
1091            }));
1092
1093        let response = state.request(request).await;
1094        response.assert_status(StatusCode::OK);
1095
1096        let AccessTokenResponse { access_token, .. } = response.json();
1097
1098        // Check that the token is valid
1099        assert!(state.is_access_token_valid(&access_token).await);
1100
1101        // Exchange it again, this it should fail
1102        let request =
1103            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1104                "grant_type": "authorization_code",
1105                "code": code,
1106                "redirect_uri": grant.redirect_uri,
1107                "client_id": client.client_id,
1108            }));
1109
1110        let response = state.request(request).await;
1111        response.assert_status(StatusCode::BAD_REQUEST);
1112        let error: ClientError = response.json();
1113        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1114
1115        // The token should still be valid
1116        assert!(state.is_access_token_valid(&access_token).await);
1117
1118        // Now wait a bit
1119        state.clock.advance(Duration::try_minutes(1).unwrap());
1120
1121        // Exchange it again, this it should fail
1122        let request =
1123            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1124                "grant_type": "authorization_code",
1125                "code": code,
1126                "redirect_uri": grant.redirect_uri,
1127                "client_id": client.client_id,
1128            }));
1129
1130        let response = state.request(request).await;
1131        response.assert_status(StatusCode::BAD_REQUEST);
1132        let error: ClientError = response.json();
1133        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1134
1135        // And it should have revoked the token we got
1136        assert!(!state.is_access_token_valid(&access_token).await);
1137
1138        // Try another one and wait for too long before exchanging it
1139        let mut repo = state.repository().await.unwrap();
1140        let code = "thisisanothercode";
1141        let grant = repo
1142            .oauth2_authorization_grant()
1143            .add(
1144                &mut state.rng(),
1145                &state.clock,
1146                &client,
1147                "https://example.com/redirect".parse().unwrap(),
1148                Scope::from_iter([OPENID]),
1149                Some(AuthorizationCode {
1150                    code: code.to_owned(),
1151                    pkce: None,
1152                }),
1153                Some("state".to_owned()),
1154                Some("nonce".to_owned()),
1155                ResponseMode::Query,
1156                false,
1157                None,
1158                None,
1159            )
1160            .await
1161            .unwrap();
1162
1163        let session = repo
1164            .oauth2_session()
1165            .add_from_browser_session(
1166                &mut state.rng(),
1167                &state.clock,
1168                &client,
1169                &browser_session,
1170                grant.scope.clone(),
1171            )
1172            .await
1173            .unwrap();
1174
1175        // And fulfill it
1176        let grant = repo
1177            .oauth2_authorization_grant()
1178            .fulfill(&state.clock, &session, grant)
1179            .await
1180            .unwrap();
1181
1182        repo.save().await.unwrap();
1183
1184        // Now wait a bit
1185        state
1186            .clock
1187            .advance(Duration::microseconds(15 * 60 * 1000 * 1000));
1188
1189        // Exchange it, it should fail
1190        let request =
1191            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1192                "grant_type": "authorization_code",
1193                "code": code,
1194                "redirect_uri": grant.redirect_uri,
1195                "client_id": client.client_id,
1196            }));
1197
1198        let response = state.request(request).await;
1199        response.assert_status(StatusCode::BAD_REQUEST);
1200        let ClientError { error, .. } = response.json();
1201        assert_eq!(error, ClientErrorCode::InvalidGrant);
1202    }
1203
1204    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1205    async fn test_refresh_token_grant(pool: PgPool) {
1206        setup();
1207        let state = TestState::from_pool(pool).await.unwrap();
1208
1209        // Provision a client
1210        let request =
1211            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1212                "client_uri": "https://example.com/",
1213                "redirect_uris": ["https://example.com/callback"],
1214                "token_endpoint_auth_method": "none",
1215                "response_types": ["code"],
1216                "grant_types": ["authorization_code", "refresh_token"],
1217            }));
1218
1219        let response = state.request(request).await;
1220        response.assert_status(StatusCode::CREATED);
1221
1222        let ClientRegistrationResponse { client_id, .. } = response.json();
1223
1224        // Let's provision a user and create a session for them. This part is hard to
1225        // test with just HTTP requests, so we'll use the repository directly.
1226        let mut repo = state.repository().await.unwrap();
1227
1228        let user = repo
1229            .user()
1230            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1231            .await
1232            .unwrap();
1233
1234        let browser_session = repo
1235            .browser_session()
1236            .add(&mut state.rng(), &state.clock, &user, None)
1237            .await
1238            .unwrap();
1239
1240        // Lookup the client in the database.
1241        let client = repo
1242            .oauth2_client()
1243            .find_by_client_id(&client_id)
1244            .await
1245            .unwrap()
1246            .unwrap();
1247
1248        // Get a token pair
1249        let session = repo
1250            .oauth2_session()
1251            .add_from_browser_session(
1252                &mut state.rng(),
1253                &state.clock,
1254                &client,
1255                &browser_session,
1256                Scope::from_iter([OPENID]),
1257            )
1258            .await
1259            .unwrap();
1260
1261        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1262            generate_token_pair(
1263                &mut state.rng(),
1264                &state.clock,
1265                &mut repo,
1266                &session,
1267                Duration::microseconds(5 * 60 * 1000 * 1000),
1268            )
1269            .await
1270            .unwrap();
1271
1272        repo.save().await.unwrap();
1273
1274        // First check that the token is valid
1275        assert!(state.is_access_token_valid(&access_token).await);
1276
1277        // Now call the token endpoint to get an access token.
1278        let request =
1279            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1280                "grant_type": "refresh_token",
1281                "refresh_token": refresh_token,
1282                "client_id": client.client_id,
1283            }));
1284
1285        let response = state.request(request).await;
1286        response.assert_status(StatusCode::OK);
1287
1288        let old_access_token = access_token;
1289        let old_refresh_token = refresh_token;
1290        let response: AccessTokenResponse = response.json();
1291        let access_token = response.access_token;
1292        let refresh_token = response.refresh_token.expect("to have a refresh token");
1293
1294        // Check that the new token is valid
1295        assert!(state.is_access_token_valid(&access_token).await);
1296
1297        // Check that the old token is no longer valid
1298        assert!(!state.is_access_token_valid(&old_access_token).await);
1299
1300        // Call it again with the old token, it should fail
1301        let request =
1302            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1303                "grant_type": "refresh_token",
1304                "refresh_token": old_refresh_token,
1305                "client_id": client.client_id,
1306            }));
1307
1308        let response = state.request(request).await;
1309        response.assert_status(StatusCode::BAD_REQUEST);
1310        let ClientError { error, .. } = response.json();
1311        assert_eq!(error, ClientErrorCode::InvalidGrant);
1312
1313        // Call it again with the new token, it should work
1314        let request =
1315            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1316                "grant_type": "refresh_token",
1317                "refresh_token": refresh_token,
1318                "client_id": client.client_id,
1319            }));
1320
1321        let response = state.request(request).await;
1322        response.assert_status(StatusCode::OK);
1323        let _: AccessTokenResponse = response.json();
1324    }
1325
1326    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1327    async fn test_double_refresh(pool: PgPool) {
1328        setup();
1329        let state = TestState::from_pool(pool).await.unwrap();
1330
1331        // Provision a client
1332        let request =
1333            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1334                "client_uri": "https://example.com/",
1335                "redirect_uris": ["https://example.com/callback"],
1336                "token_endpoint_auth_method": "none",
1337                "response_types": ["code"],
1338                "grant_types": ["authorization_code", "refresh_token"],
1339            }));
1340
1341        let response = state.request(request).await;
1342        response.assert_status(StatusCode::CREATED);
1343
1344        let ClientRegistrationResponse { client_id, .. } = response.json();
1345
1346        // Let's provision a user and create a session for them. This part is hard to
1347        // test with just HTTP requests, so we'll use the repository directly.
1348        let mut repo = state.repository().await.unwrap();
1349
1350        let user = repo
1351            .user()
1352            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1353            .await
1354            .unwrap();
1355
1356        let browser_session = repo
1357            .browser_session()
1358            .add(&mut state.rng(), &state.clock, &user, None)
1359            .await
1360            .unwrap();
1361
1362        // Lookup the client in the database.
1363        let client = repo
1364            .oauth2_client()
1365            .find_by_client_id(&client_id)
1366            .await
1367            .unwrap()
1368            .unwrap();
1369
1370        // Get a token pair
1371        let session = repo
1372            .oauth2_session()
1373            .add_from_browser_session(
1374                &mut state.rng(),
1375                &state.clock,
1376                &client,
1377                &browser_session,
1378                Scope::from_iter([OPENID]),
1379            )
1380            .await
1381            .unwrap();
1382
1383        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1384            generate_token_pair(
1385                &mut state.rng(),
1386                &state.clock,
1387                &mut repo,
1388                &session,
1389                Duration::microseconds(5 * 60 * 1000 * 1000),
1390            )
1391            .await
1392            .unwrap();
1393
1394        repo.save().await.unwrap();
1395
1396        // First check that the token is valid
1397        assert!(state.is_access_token_valid(&access_token).await);
1398
1399        // Now call the token endpoint to get an access token.
1400        let request =
1401            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1402                "grant_type": "refresh_token",
1403                "refresh_token": refresh_token,
1404                "client_id": client.client_id,
1405            }));
1406
1407        let first_response = state.request(request).await;
1408        first_response.assert_status(StatusCode::OK);
1409        let first_response: AccessTokenResponse = first_response.json();
1410
1411        // Call a second time, it should work, as we haven't done anything yet with the
1412        // token
1413        let request =
1414            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1415                "grant_type": "refresh_token",
1416                "refresh_token": refresh_token,
1417                "client_id": client.client_id,
1418            }));
1419
1420        let second_response = state.request(request).await;
1421        second_response.assert_status(StatusCode::OK);
1422        let second_response: AccessTokenResponse = second_response.json();
1423
1424        // Check that we got new tokens
1425        assert_ne!(first_response.access_token, second_response.access_token);
1426        assert_ne!(first_response.refresh_token, second_response.refresh_token);
1427
1428        // Check that the old-new token is invalid
1429        assert!(
1430            !state
1431                .is_access_token_valid(&first_response.access_token)
1432                .await
1433        );
1434
1435        // Check that the new-new token is valid
1436        assert!(
1437            state
1438                .is_access_token_valid(&second_response.access_token)
1439                .await
1440        );
1441
1442        // Do a third refresh, this one should not work, as we've used the new
1443        // access token
1444        let request =
1445            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1446                "grant_type": "refresh_token",
1447                "refresh_token": refresh_token,
1448                "client_id": client.client_id,
1449            }));
1450
1451        let third_response = state.request(request).await;
1452        third_response.assert_status(StatusCode::BAD_REQUEST);
1453
1454        // The other reason we consider a new refresh token to be 'used' is if
1455        // it was already used in a refresh
1456        // So, if we do a refresh with the second_response.refresh_token, then
1457        // another refresh with the result, redoing one with
1458        // second_response.refresh_token again should fail
1459        let request =
1460            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1461                "grant_type": "refresh_token",
1462                "refresh_token": second_response.refresh_token,
1463                "client_id": client.client_id,
1464            }));
1465
1466        // This one is fine
1467        let fourth_response = state.request(request).await;
1468        fourth_response.assert_status(StatusCode::OK);
1469        let fourth_response: AccessTokenResponse = fourth_response.json();
1470
1471        // Do another one, it should be fine as well
1472        let request =
1473            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1474                "grant_type": "refresh_token",
1475                "refresh_token": fourth_response.refresh_token,
1476                "client_id": client.client_id,
1477            }));
1478
1479        let fifth_response = state.request(request).await;
1480        fifth_response.assert_status(StatusCode::OK);
1481
1482        // But now, if we re-do with the second_response.refresh_token, it should
1483        // fail
1484        let request =
1485            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1486                "grant_type": "refresh_token",
1487                "refresh_token": second_response.refresh_token,
1488                "client_id": client.client_id,
1489            }));
1490
1491        let sixth_response = state.request(request).await;
1492        sixth_response.assert_status(StatusCode::BAD_REQUEST);
1493    }
1494
1495    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1496    async fn test_client_credentials(pool: PgPool) {
1497        setup();
1498        let state = TestState::from_pool(pool).await.unwrap();
1499
1500        // Provision a client
1501        let request =
1502            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1503                "client_uri": "https://example.com/",
1504                "token_endpoint_auth_method": "client_secret_post",
1505                "grant_types": ["client_credentials"],
1506            }));
1507
1508        let response = state.request(request).await;
1509        response.assert_status(StatusCode::CREATED);
1510
1511        let response: ClientRegistrationResponse = response.json();
1512        let client_id = response.client_id;
1513        let client_secret = response.client_secret.expect("to have a client secret");
1514
1515        // Call the token endpoint with an empty scope
1516        let request =
1517            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1518                "grant_type": "client_credentials",
1519                "client_id": client_id,
1520                "client_secret": client_secret,
1521            }));
1522
1523        let response = state.request(request).await;
1524        response.assert_status(StatusCode::OK);
1525
1526        let response: AccessTokenResponse = response.json();
1527        assert!(response.refresh_token.is_none());
1528        assert!(response.expires_in.is_some());
1529        assert!(response.scope.is_none());
1530
1531        // Revoke the token
1532        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1533            "token": response.access_token,
1534            "client_id": client_id,
1535            "client_secret": client_secret,
1536        }));
1537
1538        let response = state.request(request).await;
1539        response.assert_status(StatusCode::OK);
1540
1541        // We should be allowed to ask for the GraphQL API scope
1542        let request =
1543            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1544                "grant_type": "client_credentials",
1545                "client_id": client_id,
1546                "client_secret": client_secret,
1547                "scope": "urn:mas:graphql:*"
1548            }));
1549
1550        let response = state.request(request).await;
1551        response.assert_status(StatusCode::OK);
1552
1553        let response: AccessTokenResponse = response.json();
1554        assert!(response.refresh_token.is_none());
1555        assert!(response.expires_in.is_some());
1556        assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
1557
1558        // Revoke the token
1559        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1560            "token": response.access_token,
1561            "client_id": client_id,
1562            "client_secret": client_secret,
1563        }));
1564
1565        let response = state.request(request).await;
1566        response.assert_status(StatusCode::OK);
1567
1568        // We should be NOT allowed to ask for the MAS admin scope
1569        let request =
1570            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1571                "grant_type": "client_credentials",
1572                "client_id": client_id,
1573                "client_secret": client_secret,
1574                "scope": "urn:mas:admin"
1575            }));
1576
1577        let response = state.request(request).await;
1578        response.assert_status(StatusCode::FORBIDDEN);
1579
1580        let ClientError { error, .. } = response.json();
1581        assert_eq!(error, ClientErrorCode::InvalidScope);
1582
1583        // Now, if we add the client to the admin list in the policy, it should work
1584        let state = {
1585            let mut state = state;
1586            state.policy_factory = crate::test_utils::policy_factory(
1587                "example.com",
1588                serde_json::json!({
1589                    "admin_clients": [client_id]
1590                }),
1591            )
1592            .await
1593            .unwrap();
1594            state
1595        };
1596
1597        let request =
1598            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1599                "grant_type": "client_credentials",
1600                "client_id": client_id,
1601                "client_secret": client_secret,
1602                "scope": "urn:mas:admin"
1603            }));
1604
1605        let response = state.request(request).await;
1606        response.assert_status(StatusCode::OK);
1607
1608        let response: AccessTokenResponse = response.json();
1609        assert!(response.refresh_token.is_none());
1610        assert!(response.expires_in.is_some());
1611        assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
1612
1613        // Revoke the token
1614        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1615            "token": response.access_token,
1616            "client_id": client_id,
1617            "client_secret": client_secret,
1618        }));
1619
1620        let response = state.request(request).await;
1621        response.assert_status(StatusCode::OK);
1622    }
1623
1624    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1625    async fn test_device_code_grant(pool: PgPool) {
1626        setup();
1627        let state = TestState::from_pool(pool).await.unwrap();
1628
1629        // Provision a client
1630        let request =
1631            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1632                "client_uri": "https://example.com/",
1633                "token_endpoint_auth_method": "none",
1634                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
1635                "response_types": [],
1636            }));
1637
1638        let response = state.request(request).await;
1639        response.assert_status(StatusCode::CREATED);
1640
1641        let response: ClientRegistrationResponse = response.json();
1642        let client_id = response.client_id;
1643
1644        // Start a device code grant
1645        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1646            serde_json::json!({
1647                "client_id": client_id,
1648                "scope": "openid",
1649            }),
1650        );
1651        let response = state.request(request).await;
1652        response.assert_status(StatusCode::OK);
1653
1654        let device_grant: DeviceAuthorizationResponse = response.json();
1655
1656        // Poll the token endpoint, it should be pending
1657        let request =
1658            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1659                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1660                "device_code": device_grant.device_code,
1661                "client_id": client_id,
1662            }));
1663        let response = state.request(request).await;
1664        response.assert_status(StatusCode::FORBIDDEN);
1665
1666        let ClientError { error, .. } = response.json();
1667        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1668
1669        // Let's provision a user and create a browser session for them. This part is
1670        // hard to test with just HTTP requests, so we'll use the repository
1671        // directly.
1672        let mut repo = state.repository().await.unwrap();
1673
1674        let user = repo
1675            .user()
1676            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1677            .await
1678            .unwrap();
1679
1680        let browser_session = repo
1681            .browser_session()
1682            .add(&mut state.rng(), &state.clock, &user, None)
1683            .await
1684            .unwrap();
1685
1686        // Find the grant
1687        let grant = repo
1688            .oauth2_device_code_grant()
1689            .find_by_user_code(&device_grant.user_code)
1690            .await
1691            .unwrap()
1692            .unwrap();
1693
1694        // And fulfill it
1695        let grant = repo
1696            .oauth2_device_code_grant()
1697            .fulfill(&state.clock, grant, &browser_session)
1698            .await
1699            .unwrap();
1700
1701        repo.save().await.unwrap();
1702
1703        // Now call the token endpoint to get an access token.
1704        let request =
1705            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1706                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1707                "device_code": grant.device_code,
1708                "client_id": client_id,
1709            }));
1710
1711        let response = state.request(request).await;
1712        response.assert_status(StatusCode::OK);
1713
1714        let response: AccessTokenResponse = response.json();
1715
1716        // Check that the token is valid
1717        assert!(state.is_access_token_valid(&response.access_token).await);
1718        // We advertised the refresh token grant type, so we should have a refresh token
1719        assert!(response.refresh_token.is_some());
1720        // We asked for the openid scope, so we should have an ID token
1721        assert!(response.id_token.is_some());
1722
1723        // Calling it again should fail
1724        let request =
1725            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1726                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1727                "device_code": grant.device_code,
1728                "client_id": client_id,
1729            }));
1730        let response = state.request(request).await;
1731        response.assert_status(StatusCode::BAD_REQUEST);
1732
1733        let ClientError { error, .. } = response.json();
1734        assert_eq!(error, ClientErrorCode::InvalidGrant);
1735
1736        // Do another grant and make it expire
1737        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1738            serde_json::json!({
1739                "client_id": client_id,
1740                "scope": "openid",
1741            }),
1742        );
1743        let response = state.request(request).await;
1744        response.assert_status(StatusCode::OK);
1745
1746        let device_grant: DeviceAuthorizationResponse = response.json();
1747
1748        // Poll the token endpoint, it should be pending
1749        let request =
1750            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1751                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1752                "device_code": device_grant.device_code,
1753                "client_id": client_id,
1754            }));
1755        let response = state.request(request).await;
1756        response.assert_status(StatusCode::FORBIDDEN);
1757
1758        let ClientError { error, .. } = response.json();
1759        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1760
1761        state.clock.advance(Duration::try_hours(1).unwrap());
1762
1763        // Poll again, it should be expired
1764        let request =
1765            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1766                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1767                "device_code": device_grant.device_code,
1768                "client_id": client_id,
1769            }));
1770        let response = state.request(request).await;
1771        response.assert_status(StatusCode::FORBIDDEN);
1772
1773        let ClientError { error, .. } = response.json();
1774        assert_eq!(error, ClientErrorCode::ExpiredToken);
1775
1776        // Do another grant and reject it
1777        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1778            serde_json::json!({
1779                "client_id": client_id,
1780                "scope": "openid",
1781            }),
1782        );
1783        let response = state.request(request).await;
1784        response.assert_status(StatusCode::OK);
1785
1786        let device_grant: DeviceAuthorizationResponse = response.json();
1787
1788        // Find the grant and reject it
1789        let mut repo = state.repository().await.unwrap();
1790
1791        // Find the grant
1792        let grant = repo
1793            .oauth2_device_code_grant()
1794            .find_by_user_code(&device_grant.user_code)
1795            .await
1796            .unwrap()
1797            .unwrap();
1798
1799        // And reject it
1800        let grant = repo
1801            .oauth2_device_code_grant()
1802            .reject(&state.clock, grant, &browser_session)
1803            .await
1804            .unwrap();
1805
1806        repo.save().await.unwrap();
1807
1808        // Poll the token endpoint, it should be rejected
1809        let request =
1810            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1811                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1812                "device_code": grant.device_code,
1813                "client_id": client_id,
1814            }));
1815        let response = state.request(request).await;
1816        response.assert_status(StatusCode::FORBIDDEN);
1817
1818        let ClientError { error, .. } = response.json();
1819        assert_eq!(error, ClientErrorCode::AccessDenied);
1820    }
1821
1822    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1823    async fn test_unsupported_grant(pool: PgPool) {
1824        setup();
1825        let state = TestState::from_pool(pool).await.unwrap();
1826
1827        // Provision a client
1828        let request =
1829            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1830                "client_uri": "https://example.com/",
1831                "redirect_uris": ["https://example.com/callback"],
1832                "token_endpoint_auth_method": "client_secret_post",
1833                "grant_types": ["password"],
1834                "response_types": [],
1835            }));
1836
1837        let response = state.request(request).await;
1838        response.assert_status(StatusCode::CREATED);
1839
1840        let response: ClientRegistrationResponse = response.json();
1841        let client_id = response.client_id;
1842        let client_secret = response.client_secret.expect("to have a client secret");
1843
1844        // Call the token endpoint with an unsupported grant type
1845        let request =
1846            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1847                "grant_type": "password",
1848                "client_id": client_id,
1849                "client_secret": client_secret,
1850                "username": "john",
1851                "password": "hunter2",
1852            }));
1853
1854        let response = state.request(request).await;
1855        response.assert_status(StatusCode::BAD_REQUEST);
1856        let ClientError { error, .. } = response.json();
1857        assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
1858    }
1859}