mas_handlers/oauth2/token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::Arc;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::typed_header::TypedHeader;
11use chrono::Duration;
12use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
13use hyper::StatusCode;
14use mas_axum_utils::{
15    client_authorization::{ClientAuthorization, CredentialsVerificationError},
16    sentry::SentryEventID,
17};
18use mas_data_model::{
19    AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent,
20};
21use mas_keystore::{Encrypter, Keystore};
22use mas_matrix::HomeserverConnection;
23use mas_oidc_client::types::scope::ScopeToken;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27    BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
28    oauth2::{
29        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
30        OAuth2RefreshTokenRepository, OAuth2SessionRepository,
31    },
32    user::BrowserSessionRepository,
33};
34use oauth2_types::{
35    errors::{ClientError, ClientErrorCode},
36    pkce::CodeChallengeError,
37    requests::{
38        AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
39        DeviceCodeGrant, GrantType, RefreshTokenGrant,
40    },
41    scope,
42};
43use thiserror::Error;
44use tracing::{debug, info};
45use ulid::Ulid;
46
47use super::{generate_id_token, generate_token_pair};
48use crate::{BoundActivityTracker, impl_from_error_for_route};
49
50#[derive(Debug, Error)]
51pub(crate) enum RouteError {
52    #[error(transparent)]
53    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
54
55    #[error("bad request")]
56    BadRequest,
57
58    #[error("pkce verification failed")]
59    PkceVerification(#[from] CodeChallengeError),
60
61    #[error("client not found")]
62    ClientNotFound,
63
64    #[error("client not allowed")]
65    ClientNotAllowed,
66
67    #[error("could not verify client credentials")]
68    ClientCredentialsVerification(#[from] CredentialsVerificationError),
69
70    #[error("grant not found")]
71    GrantNotFound,
72
73    #[error("invalid grant")]
74    InvalidGrant,
75
76    #[error("refresh token not found")]
77    RefreshTokenNotFound,
78
79    #[error("refresh token {0} is invalid")]
80    RefreshTokenInvalid(Ulid),
81
82    #[error("session {0} is invalid")]
83    SessionInvalid(Ulid),
84
85    #[error("client id mismatch: expected {expected}, got {actual}")]
86    ClientIDMismatch { expected: Ulid, actual: Ulid },
87
88    #[error("policy denied the request")]
89    DeniedByPolicy(Vec<mas_policy::Violation>),
90
91    #[error("unsupported grant type")]
92    UnsupportedGrantType,
93
94    #[error("unauthorized client")]
95    UnauthorizedClient,
96
97    #[error("failed to load browser session")]
98    NoSuchBrowserSession,
99
100    #[error("failed to load oauth session")]
101    NoSuchOAuthSession,
102
103    #[error(
104        "failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
105    )]
106    NoSuchNextRefreshToken { next: Ulid, previous: Ulid },
107
108    #[error(
109        "failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})"
110    )]
111    NoSuchNextAccessToken {
112        access_token: Ulid,
113        refresh_token: Ulid,
114    },
115
116    #[error("no access token associated with the refresh token {refresh_token:?}")]
117    NoAccessTokenOnRefreshToken { refresh_token: Ulid },
118
119    #[error("device code grant expired")]
120    DeviceCodeExpired,
121
122    #[error("device code grant is still pending")]
123    DeviceCodePending,
124
125    #[error("device code grant was rejected")]
126    DeviceCodeRejected,
127
128    #[error("device code grant was already exchanged")]
129    DeviceCodeExchanged,
130
131    #[error("failed to provision device")]
132    ProvisionDeviceFailed(#[source] anyhow::Error),
133}
134
135impl IntoResponse for RouteError {
136    fn into_response(self) -> axum::response::Response {
137        let event_id = sentry::capture_error(&self);
138
139        let response = match self {
140            Self::Internal(_)
141            | Self::NoSuchBrowserSession
142            | Self::NoSuchOAuthSession
143            | Self::ProvisionDeviceFailed(_)
144            | Self::NoSuchNextRefreshToken { .. }
145            | Self::NoSuchNextAccessToken { .. }
146            | Self::NoAccessTokenOnRefreshToken { .. } => (
147                StatusCode::INTERNAL_SERVER_ERROR,
148                Json(ClientError::from(ClientErrorCode::ServerError)),
149            ),
150            Self::BadRequest => (
151                StatusCode::BAD_REQUEST,
152                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
153            ),
154            Self::PkceVerification(err) => (
155                StatusCode::BAD_REQUEST,
156                Json(
157                    ClientError::from(ClientErrorCode::InvalidGrant)
158                        .with_description(format!("PKCE verification failed: {err}")),
159                ),
160            ),
161            Self::ClientNotFound | Self::ClientCredentialsVerification(_) => (
162                StatusCode::UNAUTHORIZED,
163                Json(ClientError::from(ClientErrorCode::InvalidClient)),
164            ),
165            Self::ClientNotAllowed | Self::UnauthorizedClient => (
166                StatusCode::UNAUTHORIZED,
167                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
168            ),
169            Self::DeniedByPolicy(violations) => (
170                StatusCode::FORBIDDEN,
171                Json(
172                    ClientError::from(ClientErrorCode::InvalidScope).with_description(
173                        violations
174                            .into_iter()
175                            .map(|violation| violation.msg)
176                            .collect::<Vec<_>>()
177                            .join(", "),
178                    ),
179                ),
180            ),
181            Self::DeviceCodeRejected => (
182                StatusCode::FORBIDDEN,
183                Json(ClientError::from(ClientErrorCode::AccessDenied)),
184            ),
185            Self::DeviceCodeExpired => (
186                StatusCode::FORBIDDEN,
187                Json(ClientError::from(ClientErrorCode::ExpiredToken)),
188            ),
189            Self::DeviceCodePending => (
190                StatusCode::FORBIDDEN,
191                Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
192            ),
193            Self::InvalidGrant
194            | Self::DeviceCodeExchanged
195            | Self::RefreshTokenNotFound
196            | Self::RefreshTokenInvalid(_)
197            | Self::SessionInvalid(_)
198            | Self::ClientIDMismatch { .. }
199            | Self::GrantNotFound => (
200                StatusCode::BAD_REQUEST,
201                Json(ClientError::from(ClientErrorCode::InvalidGrant)),
202            ),
203            Self::UnsupportedGrantType => (
204                StatusCode::BAD_REQUEST,
205                Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
206            ),
207        };
208
209        (SentryEventID::from(event_id), response).into_response()
210    }
211}
212
213impl_from_error_for_route!(mas_storage::RepositoryError);
214impl_from_error_for_route!(mas_policy::EvaluationError);
215impl_from_error_for_route!(super::IdTokenSignatureError);
216
217#[tracing::instrument(
218    name = "handlers.oauth2.token.post",
219    fields(client.id = client_authorization.client_id()),
220    skip_all,
221    err,
222)]
223pub(crate) async fn post(
224    mut rng: BoxRng,
225    clock: BoxClock,
226    State(http_client): State<reqwest::Client>,
227    State(key_store): State<Keystore>,
228    State(url_builder): State<UrlBuilder>,
229    activity_tracker: BoundActivityTracker,
230    mut repo: BoxRepository,
231    State(homeserver): State<Arc<dyn HomeserverConnection>>,
232    State(site_config): State<SiteConfig>,
233    State(encrypter): State<Encrypter>,
234    policy: Policy,
235    user_agent: Option<TypedHeader<headers::UserAgent>>,
236    client_authorization: ClientAuthorization<AccessTokenRequest>,
237) -> Result<impl IntoResponse, RouteError> {
238    let user_agent = user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned()));
239    let client = client_authorization
240        .credentials
241        .fetch(&mut repo)
242        .await?
243        .ok_or(RouteError::ClientNotFound)?;
244
245    let method = client
246        .token_endpoint_auth_method
247        .as_ref()
248        .ok_or(RouteError::ClientNotAllowed)?;
249
250    client_authorization
251        .credentials
252        .verify(&http_client, &encrypter, method, &client)
253        .await?;
254
255    let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
256
257    let (reply, repo) = match form {
258        AccessTokenRequest::AuthorizationCode(grant) => {
259            authorization_code_grant(
260                &mut rng,
261                &clock,
262                &activity_tracker,
263                &grant,
264                &client,
265                &key_store,
266                &url_builder,
267                &site_config,
268                repo,
269                &homeserver,
270                user_agent,
271            )
272            .await?
273        }
274        AccessTokenRequest::RefreshToken(grant) => {
275            refresh_token_grant(
276                &mut rng,
277                &clock,
278                &activity_tracker,
279                &grant,
280                &client,
281                &site_config,
282                repo,
283                user_agent,
284            )
285            .await?
286        }
287        AccessTokenRequest::ClientCredentials(grant) => {
288            client_credentials_grant(
289                &mut rng,
290                &clock,
291                &activity_tracker,
292                &grant,
293                &client,
294                &site_config,
295                repo,
296                policy,
297                user_agent,
298            )
299            .await?
300        }
301        AccessTokenRequest::DeviceCode(grant) => {
302            device_code_grant(
303                &mut rng,
304                &clock,
305                &activity_tracker,
306                &grant,
307                &client,
308                &key_store,
309                &url_builder,
310                &site_config,
311                repo,
312                &homeserver,
313                user_agent,
314            )
315            .await?
316        }
317        _ => {
318            return Err(RouteError::UnsupportedGrantType);
319        }
320    };
321
322    repo.save().await?;
323
324    let mut headers = HeaderMap::new();
325    headers.typed_insert(CacheControl::new().with_no_store());
326    headers.typed_insert(Pragma::no_cache());
327
328    Ok((headers, Json(reply)))
329}
330
331#[allow(clippy::too_many_lines)] // TODO: refactor some parts out
332async fn authorization_code_grant(
333    mut rng: &mut BoxRng,
334    clock: &impl Clock,
335    activity_tracker: &BoundActivityTracker,
336    grant: &AuthorizationCodeGrant,
337    client: &Client,
338    key_store: &Keystore,
339    url_builder: &UrlBuilder,
340    site_config: &SiteConfig,
341    mut repo: BoxRepository,
342    homeserver: &Arc<dyn HomeserverConnection>,
343    user_agent: Option<UserAgent>,
344) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
345    // Check that the client is allowed to use this grant type
346    if !client.grant_types.contains(&GrantType::AuthorizationCode) {
347        return Err(RouteError::UnauthorizedClient);
348    }
349
350    let authz_grant = repo
351        .oauth2_authorization_grant()
352        .find_by_code(&grant.code)
353        .await?
354        .ok_or(RouteError::GrantNotFound)?;
355
356    let now = clock.now();
357
358    let session_id = match authz_grant.stage {
359        AuthorizationGrantStage::Cancelled { cancelled_at } => {
360            debug!(%cancelled_at, "Authorization grant was cancelled");
361            return Err(RouteError::InvalidGrant);
362        }
363        AuthorizationGrantStage::Exchanged {
364            exchanged_at,
365            fulfilled_at,
366            session_id,
367        } => {
368            debug!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
369
370            // Ending the session if the token was already exchanged more than 20s ago
371            if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
372                debug!("Ending potentially compromised session");
373                let session = repo
374                    .oauth2_session()
375                    .lookup(session_id)
376                    .await?
377                    .ok_or(RouteError::NoSuchOAuthSession)?;
378                repo.oauth2_session().finish(clock, session).await?;
379                repo.save().await?;
380            }
381
382            return Err(RouteError::InvalidGrant);
383        }
384        AuthorizationGrantStage::Pending => {
385            debug!("Authorization grant has not been fulfilled yet");
386            return Err(RouteError::InvalidGrant);
387        }
388        AuthorizationGrantStage::Fulfilled {
389            session_id,
390            fulfilled_at,
391        } => {
392            if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
393                debug!("Code exchange took more than 10 minutes");
394                return Err(RouteError::InvalidGrant);
395            }
396
397            session_id
398        }
399    };
400
401    let mut session = repo
402        .oauth2_session()
403        .lookup(session_id)
404        .await?
405        .ok_or(RouteError::NoSuchOAuthSession)?;
406
407    if let Some(user_agent) = user_agent {
408        session = repo
409            .oauth2_session()
410            .record_user_agent(session, user_agent)
411            .await?;
412    }
413
414    // This should never happen, since we looked up in the database using the code
415    let code = authz_grant.code.as_ref().ok_or(RouteError::InvalidGrant)?;
416
417    if client.id != session.client_id {
418        return Err(RouteError::UnauthorizedClient);
419    }
420
421    match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
422        (None, None) => {}
423        // We have a challenge but no verifier (or vice-versa)? Bad request.
424        (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest),
425        // If we have both, we need to check the code validity
426        (Some(pkce), Some(verifier)) => {
427            pkce.verify(verifier)?;
428        }
429    }
430
431    let Some(user_session_id) = session.user_session_id else {
432        tracing::warn!("No user session associated with this OAuth2 session");
433        return Err(RouteError::InvalidGrant);
434    };
435
436    let browser_session = repo
437        .browser_session()
438        .lookup(user_session_id)
439        .await?
440        .ok_or(RouteError::NoSuchBrowserSession)?;
441
442    let last_authentication = repo
443        .browser_session()
444        .get_last_authentication(&browser_session)
445        .await?;
446
447    let ttl = site_config.access_token_ttl;
448    let (access_token, refresh_token) =
449        generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;
450
451    let id_token = if session.scope.contains(&scope::OPENID) {
452        Some(generate_id_token(
453            &mut rng,
454            clock,
455            url_builder,
456            key_store,
457            client,
458            Some(&authz_grant),
459            &browser_session,
460            Some(&access_token),
461            last_authentication.as_ref(),
462        )?)
463    } else {
464        None
465    };
466
467    let mut params = AccessTokenResponse::new(access_token.access_token)
468        .with_expires_in(ttl)
469        .with_refresh_token(refresh_token.refresh_token)
470        .with_scope(session.scope.clone());
471
472    if let Some(id_token) = id_token {
473        params = params.with_id_token(id_token);
474    }
475
476    // Lock the user sync to make sure we don't get into a race condition
477    repo.user()
478        .acquire_lock_for_sync(&browser_session.user)
479        .await?;
480
481    // Look for device to provision
482    let mxid = homeserver.mxid(&browser_session.user.username);
483    for scope in &*session.scope {
484        if let Some(device) = Device::from_scope_token(scope) {
485            homeserver
486                .create_device(&mxid, device.as_str())
487                .await
488                .map_err(RouteError::ProvisionDeviceFailed)?;
489        }
490    }
491
492    repo.oauth2_authorization_grant()
493        .exchange(clock, authz_grant)
494        .await?;
495
496    // XXX: there is a potential (but unlikely) race here, where the activity for
497    // the session is recorded before the transaction is committed. We would have to
498    // save the repository here to fix that.
499    activity_tracker
500        .record_oauth2_session(clock, &session)
501        .await;
502
503    Ok((params, repo))
504}
505
506#[allow(clippy::too_many_lines)]
507async fn refresh_token_grant(
508    rng: &mut BoxRng,
509    clock: &impl Clock,
510    activity_tracker: &BoundActivityTracker,
511    grant: &RefreshTokenGrant,
512    client: &Client,
513    site_config: &SiteConfig,
514    mut repo: BoxRepository,
515    user_agent: Option<UserAgent>,
516) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
517    // Check that the client is allowed to use this grant type
518    if !client.grant_types.contains(&GrantType::RefreshToken) {
519        return Err(RouteError::UnauthorizedClient);
520    }
521
522    let refresh_token = repo
523        .oauth2_refresh_token()
524        .find_by_token(&grant.refresh_token)
525        .await?
526        .ok_or(RouteError::RefreshTokenNotFound)?;
527
528    let mut session = repo
529        .oauth2_session()
530        .lookup(refresh_token.session_id)
531        .await?
532        .ok_or(RouteError::NoSuchOAuthSession)?;
533
534    // Let's for now record the user agent on each refresh, that should be
535    // responsive enough and not too much of a burden on the database.
536    if let Some(user_agent) = user_agent {
537        session = repo
538            .oauth2_session()
539            .record_user_agent(session, user_agent)
540            .await?;
541    }
542
543    if !session.is_valid() {
544        return Err(RouteError::SessionInvalid(session.id));
545    }
546
547    if client.id != session.client_id {
548        // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
549        return Err(RouteError::ClientIDMismatch {
550            expected: session.client_id,
551            actual: client.id,
552        });
553    }
554
555    if !refresh_token.is_valid() {
556        // We're seing a refresh token that already has been consumed, this might be a
557        // double-refresh or a replay attack
558
559        // First, get the next refresh token
560        let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
561            // If we don't have a 'next' refresh token, it may just be because this was
562            // before we were recording those. Let's just treat it as a replay.
563            return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
564        };
565
566        let Some(next_refresh_token) = repo
567            .oauth2_refresh_token()
568            .lookup(next_refresh_token_id)
569            .await?
570        else {
571            return Err(RouteError::NoSuchNextRefreshToken {
572                next: next_refresh_token_id,
573                previous: refresh_token.id,
574            });
575        };
576
577        // Check if the next refresh token was already consumed or not
578        if !next_refresh_token.is_valid() {
579            // XXX: This is a replay, we *may* want to invalidate the session
580            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
581        }
582
583        // Check if the associated access token was already used
584        let Some(access_token_id) = next_refresh_token.access_token_id else {
585            // This should in theory not happen: this means an access token got cleaned up,
586            // but the refresh token was still valid.
587            return Err(RouteError::NoAccessTokenOnRefreshToken {
588                refresh_token: next_refresh_token.id,
589            });
590        };
591
592        // Load it
593        let next_access_token = repo
594            .oauth2_access_token()
595            .lookup(access_token_id)
596            .await?
597            .ok_or(RouteError::NoSuchNextAccessToken {
598                access_token: access_token_id,
599                refresh_token: next_refresh_token_id,
600            })?;
601
602        if next_access_token.is_used() {
603            // XXX: This is a replay, we *may* want to invalidate the session
604            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
605        }
606
607        // Looks like it's a double-refresh, client lost their refresh token on
608        // the way back. Let's revoke the unused access and refresh tokens, and
609        // issue new ones
610        info!(
611            oauth2_session.id = %session.id,
612            oauth2_client.id = %client.id,
613            %refresh_token.id,
614            "Refresh token already used, but issued refresh and access tokens are unused. Assuming those were lost; revoking those and reissuing new ones."
615        );
616
617        repo.oauth2_access_token()
618            .revoke(clock, next_access_token)
619            .await?;
620
621        repo.oauth2_refresh_token()
622            .revoke(clock, next_refresh_token)
623            .await?;
624    }
625
626    activity_tracker
627        .record_oauth2_session(clock, &session)
628        .await;
629
630    let ttl = site_config.access_token_ttl;
631    let (new_access_token, new_refresh_token) =
632        generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;
633
634    let refresh_token = repo
635        .oauth2_refresh_token()
636        .consume(clock, refresh_token, &new_refresh_token)
637        .await?;
638
639    if let Some(access_token_id) = refresh_token.access_token_id {
640        let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
641        if let Some(access_token) = access_token {
642            // If it is a double-refresh, it might already be revoked
643            if !access_token.state.is_revoked() {
644                repo.oauth2_access_token()
645                    .revoke(clock, access_token)
646                    .await?;
647            }
648        }
649    }
650
651    let params = AccessTokenResponse::new(new_access_token.access_token)
652        .with_expires_in(ttl)
653        .with_refresh_token(new_refresh_token.refresh_token)
654        .with_scope(session.scope);
655
656    Ok((params, repo))
657}
658
659async fn client_credentials_grant(
660    rng: &mut BoxRng,
661    clock: &impl Clock,
662    activity_tracker: &BoundActivityTracker,
663    grant: &ClientCredentialsGrant,
664    client: &Client,
665    site_config: &SiteConfig,
666    mut repo: BoxRepository,
667    mut policy: Policy,
668    user_agent: Option<UserAgent>,
669) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
670    // Check that the client is allowed to use this grant type
671    if !client.grant_types.contains(&GrantType::ClientCredentials) {
672        return Err(RouteError::UnauthorizedClient);
673    }
674
675    // Default to an empty scope if none is provided
676    let scope = grant
677        .scope
678        .clone()
679        .unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
680
681    // Make the request go through the policy engine
682    let res = policy
683        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
684            user: None,
685            client,
686            scope: &scope,
687            grant_type: mas_policy::GrantType::ClientCredentials,
688            requester: mas_policy::Requester {
689                ip_address: activity_tracker.ip(),
690                user_agent: user_agent.clone().map(|ua| ua.raw),
691            },
692        })
693        .await?;
694    if !res.valid() {
695        return Err(RouteError::DeniedByPolicy(res.violations));
696    }
697
698    // Start the session
699    let mut session = repo
700        .oauth2_session()
701        .add_from_client_credentials(rng, clock, client, scope)
702        .await?;
703
704    if let Some(user_agent) = user_agent {
705        session = repo
706            .oauth2_session()
707            .record_user_agent(session, user_agent)
708            .await?;
709    }
710
711    let ttl = site_config.access_token_ttl;
712    let access_token_str = TokenType::AccessToken.generate(rng);
713
714    let access_token = repo
715        .oauth2_access_token()
716        .add(rng, clock, &session, access_token_str, Some(ttl))
717        .await?;
718
719    let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
720
721    // XXX: there is a potential (but unlikely) race here, where the activity for
722    // the session is recorded before the transaction is committed. We would have to
723    // save the repository here to fix that.
724    activity_tracker
725        .record_oauth2_session(clock, &session)
726        .await;
727
728    if !session.scope.is_empty() {
729        // We only return the scope if it's not empty
730        params = params.with_scope(session.scope);
731    }
732
733    Ok((params, repo))
734}
735
736async fn device_code_grant(
737    rng: &mut BoxRng,
738    clock: &impl Clock,
739    activity_tracker: &BoundActivityTracker,
740    grant: &DeviceCodeGrant,
741    client: &Client,
742    key_store: &Keystore,
743    url_builder: &UrlBuilder,
744    site_config: &SiteConfig,
745    mut repo: BoxRepository,
746    homeserver: &Arc<dyn HomeserverConnection>,
747    user_agent: Option<UserAgent>,
748) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
749    // Check that the client is allowed to use this grant type
750    if !client.grant_types.contains(&GrantType::DeviceCode) {
751        return Err(RouteError::UnauthorizedClient);
752    }
753
754    let grant = repo
755        .oauth2_device_code_grant()
756        .find_by_device_code(&grant.device_code)
757        .await?
758        .ok_or(RouteError::GrantNotFound)?;
759
760    // Check that the client match
761    if client.id != grant.client_id {
762        return Err(RouteError::ClientIDMismatch {
763            expected: grant.client_id,
764            actual: client.id,
765        });
766    }
767
768    if grant.expires_at < clock.now() {
769        return Err(RouteError::DeviceCodeExpired);
770    }
771
772    let browser_session_id = match &grant.state {
773        DeviceCodeGrantState::Pending => {
774            return Err(RouteError::DeviceCodePending);
775        }
776        DeviceCodeGrantState::Rejected { .. } => {
777            return Err(RouteError::DeviceCodeRejected);
778        }
779        DeviceCodeGrantState::Exchanged { .. } => {
780            return Err(RouteError::DeviceCodeExchanged);
781        }
782        DeviceCodeGrantState::Fulfilled {
783            browser_session_id, ..
784        } => browser_session_id,
785    };
786
787    let browser_session = repo
788        .browser_session()
789        .lookup(*browser_session_id)
790        .await?
791        .ok_or(RouteError::NoSuchBrowserSession)?;
792
793    // Start the session
794    let mut session = repo
795        .oauth2_session()
796        .add_from_browser_session(rng, clock, client, &browser_session, grant.scope.clone())
797        .await?;
798
799    repo.oauth2_device_code_grant()
800        .exchange(clock, grant, &session)
801        .await?;
802
803    // XXX: should we get the user agent from the device code grant instead?
804    if let Some(user_agent) = user_agent {
805        session = repo
806            .oauth2_session()
807            .record_user_agent(session, user_agent)
808            .await?;
809    }
810
811    let ttl = site_config.access_token_ttl;
812    let access_token_str = TokenType::AccessToken.generate(rng);
813
814    let access_token = repo
815        .oauth2_access_token()
816        .add(rng, clock, &session, access_token_str, Some(ttl))
817        .await?;
818
819    let mut params =
820        AccessTokenResponse::new(access_token.access_token.clone()).with_expires_in(ttl);
821
822    // If the client uses the refresh token grant type, we also generate a refresh
823    // token
824    if client.grant_types.contains(&GrantType::RefreshToken) {
825        let refresh_token_str = TokenType::RefreshToken.generate(rng);
826
827        let refresh_token = repo
828            .oauth2_refresh_token()
829            .add(rng, clock, &session, &access_token, refresh_token_str)
830            .await?;
831
832        params = params.with_refresh_token(refresh_token.refresh_token);
833    }
834
835    // If the client asked for an ID token, we generate one
836    if session.scope.contains(&scope::OPENID) {
837        let id_token = generate_id_token(
838            rng,
839            clock,
840            url_builder,
841            key_store,
842            client,
843            None,
844            &browser_session,
845            Some(&access_token),
846            None,
847        )?;
848
849        params = params.with_id_token(id_token);
850    }
851
852    // Lock the user sync to make sure we don't get into a race condition
853    repo.user()
854        .acquire_lock_for_sync(&browser_session.user)
855        .await?;
856
857    // Look for device to provision
858    let mxid = homeserver.mxid(&browser_session.user.username);
859    for scope in &*session.scope {
860        if let Some(device) = Device::from_scope_token(scope) {
861            homeserver
862                .create_device(&mxid, device.as_str())
863                .await
864                .map_err(RouteError::ProvisionDeviceFailed)?;
865        }
866    }
867
868    // XXX: there is a potential (but unlikely) race here, where the activity for
869    // the session is recorded before the transaction is committed. We would have to
870    // save the repository here to fix that.
871    activity_tracker
872        .record_oauth2_session(clock, &session)
873        .await;
874
875    if !session.scope.is_empty() {
876        // We only return the scope if it's not empty
877        params = params.with_scope(session.scope);
878    }
879
880    Ok((params, repo))
881}
882
883#[cfg(test)]
884mod tests {
885    use hyper::Request;
886    use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
887    use mas_router::SimpleRoute;
888    use oauth2_types::{
889        registration::ClientRegistrationResponse,
890        requests::{DeviceAuthorizationResponse, ResponseMode},
891        scope::{OPENID, Scope},
892    };
893    use sqlx::PgPool;
894
895    use super::*;
896    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
897
898    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
899    async fn test_auth_code_grant(pool: PgPool) {
900        setup();
901        let state = TestState::from_pool(pool).await.unwrap();
902
903        // Provision a client
904        let request =
905            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
906                "client_uri": "https://example.com/",
907                "redirect_uris": ["https://example.com/callback"],
908                "token_endpoint_auth_method": "none",
909                "response_types": ["code"],
910                "grant_types": ["authorization_code"],
911            }));
912
913        let response = state.request(request).await;
914        response.assert_status(StatusCode::CREATED);
915
916        let ClientRegistrationResponse { client_id, .. } = response.json();
917
918        // Let's provision a user and create a session for them. This part is hard to
919        // test with just HTTP requests, so we'll use the repository directly.
920        let mut repo = state.repository().await.unwrap();
921
922        let user = repo
923            .user()
924            .add(&mut state.rng(), &state.clock, "alice".to_owned())
925            .await
926            .unwrap();
927
928        let browser_session = repo
929            .browser_session()
930            .add(&mut state.rng(), &state.clock, &user, None)
931            .await
932            .unwrap();
933
934        // Lookup the client in the database.
935        let client = repo
936            .oauth2_client()
937            .find_by_client_id(&client_id)
938            .await
939            .unwrap()
940            .unwrap();
941
942        // Start a grant
943        let code = "thisisaverysecurecode";
944        let grant = repo
945            .oauth2_authorization_grant()
946            .add(
947                &mut state.rng(),
948                &state.clock,
949                &client,
950                "https://example.com/redirect".parse().unwrap(),
951                Scope::from_iter([OPENID]),
952                Some(AuthorizationCode {
953                    code: code.to_owned(),
954                    pkce: None,
955                }),
956                Some("state".to_owned()),
957                Some("nonce".to_owned()),
958                None,
959                ResponseMode::Query,
960                false,
961                false,
962                None,
963            )
964            .await
965            .unwrap();
966
967        let session = repo
968            .oauth2_session()
969            .add_from_browser_session(
970                &mut state.rng(),
971                &state.clock,
972                &client,
973                &browser_session,
974                grant.scope.clone(),
975            )
976            .await
977            .unwrap();
978
979        // And fulfill it
980        let grant = repo
981            .oauth2_authorization_grant()
982            .fulfill(&state.clock, &session, grant)
983            .await
984            .unwrap();
985
986        repo.save().await.unwrap();
987
988        // Now call the token endpoint to get an access token.
989        let request =
990            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
991                "grant_type": "authorization_code",
992                "code": code,
993                "redirect_uri": grant.redirect_uri,
994                "client_id": client.client_id,
995            }));
996
997        let response = state.request(request).await;
998        response.assert_status(StatusCode::OK);
999
1000        let AccessTokenResponse { access_token, .. } = response.json();
1001
1002        // Check that the token is valid
1003        assert!(state.is_access_token_valid(&access_token).await);
1004
1005        // Exchange it again, this it should fail
1006        let request =
1007            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1008                "grant_type": "authorization_code",
1009                "code": code,
1010                "redirect_uri": grant.redirect_uri,
1011                "client_id": client.client_id,
1012            }));
1013
1014        let response = state.request(request).await;
1015        response.assert_status(StatusCode::BAD_REQUEST);
1016        let error: ClientError = response.json();
1017        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1018
1019        // The token should still be valid
1020        assert!(state.is_access_token_valid(&access_token).await);
1021
1022        // Now wait a bit
1023        state.clock.advance(Duration::try_minutes(1).unwrap());
1024
1025        // Exchange it again, this it should fail
1026        let request =
1027            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1028                "grant_type": "authorization_code",
1029                "code": code,
1030                "redirect_uri": grant.redirect_uri,
1031                "client_id": client.client_id,
1032            }));
1033
1034        let response = state.request(request).await;
1035        response.assert_status(StatusCode::BAD_REQUEST);
1036        let error: ClientError = response.json();
1037        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1038
1039        // And it should have revoked the token we got
1040        assert!(!state.is_access_token_valid(&access_token).await);
1041
1042        // Try another one and wait for too long before exchanging it
1043        let mut repo = state.repository().await.unwrap();
1044        let code = "thisisanothercode";
1045        let grant = repo
1046            .oauth2_authorization_grant()
1047            .add(
1048                &mut state.rng(),
1049                &state.clock,
1050                &client,
1051                "https://example.com/redirect".parse().unwrap(),
1052                Scope::from_iter([OPENID]),
1053                Some(AuthorizationCode {
1054                    code: code.to_owned(),
1055                    pkce: None,
1056                }),
1057                Some("state".to_owned()),
1058                Some("nonce".to_owned()),
1059                None,
1060                ResponseMode::Query,
1061                false,
1062                false,
1063                None,
1064            )
1065            .await
1066            .unwrap();
1067
1068        let session = repo
1069            .oauth2_session()
1070            .add_from_browser_session(
1071                &mut state.rng(),
1072                &state.clock,
1073                &client,
1074                &browser_session,
1075                grant.scope.clone(),
1076            )
1077            .await
1078            .unwrap();
1079
1080        // And fulfill it
1081        let grant = repo
1082            .oauth2_authorization_grant()
1083            .fulfill(&state.clock, &session, grant)
1084            .await
1085            .unwrap();
1086
1087        repo.save().await.unwrap();
1088
1089        // Now wait a bit
1090        state
1091            .clock
1092            .advance(Duration::microseconds(15 * 60 * 1000 * 1000));
1093
1094        // Exchange it, it should fail
1095        let request =
1096            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1097                "grant_type": "authorization_code",
1098                "code": code,
1099                "redirect_uri": grant.redirect_uri,
1100                "client_id": client.client_id,
1101            }));
1102
1103        let response = state.request(request).await;
1104        response.assert_status(StatusCode::BAD_REQUEST);
1105        let ClientError { error, .. } = response.json();
1106        assert_eq!(error, ClientErrorCode::InvalidGrant);
1107    }
1108
1109    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1110    async fn test_refresh_token_grant(pool: PgPool) {
1111        setup();
1112        let state = TestState::from_pool(pool).await.unwrap();
1113
1114        // Provision a client
1115        let request =
1116            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1117                "client_uri": "https://example.com/",
1118                "redirect_uris": ["https://example.com/callback"],
1119                "token_endpoint_auth_method": "none",
1120                "response_types": ["code"],
1121                "grant_types": ["authorization_code", "refresh_token"],
1122            }));
1123
1124        let response = state.request(request).await;
1125        response.assert_status(StatusCode::CREATED);
1126
1127        let ClientRegistrationResponse { client_id, .. } = response.json();
1128
1129        // Let's provision a user and create a session for them. This part is hard to
1130        // test with just HTTP requests, so we'll use the repository directly.
1131        let mut repo = state.repository().await.unwrap();
1132
1133        let user = repo
1134            .user()
1135            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1136            .await
1137            .unwrap();
1138
1139        let browser_session = repo
1140            .browser_session()
1141            .add(&mut state.rng(), &state.clock, &user, None)
1142            .await
1143            .unwrap();
1144
1145        // Lookup the client in the database.
1146        let client = repo
1147            .oauth2_client()
1148            .find_by_client_id(&client_id)
1149            .await
1150            .unwrap()
1151            .unwrap();
1152
1153        // Get a token pair
1154        let session = repo
1155            .oauth2_session()
1156            .add_from_browser_session(
1157                &mut state.rng(),
1158                &state.clock,
1159                &client,
1160                &browser_session,
1161                Scope::from_iter([OPENID]),
1162            )
1163            .await
1164            .unwrap();
1165
1166        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1167            generate_token_pair(
1168                &mut state.rng(),
1169                &state.clock,
1170                &mut repo,
1171                &session,
1172                Duration::microseconds(5 * 60 * 1000 * 1000),
1173            )
1174            .await
1175            .unwrap();
1176
1177        repo.save().await.unwrap();
1178
1179        // First check that the token is valid
1180        assert!(state.is_access_token_valid(&access_token).await);
1181
1182        // Now call the token endpoint to get an access token.
1183        let request =
1184            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1185                "grant_type": "refresh_token",
1186                "refresh_token": refresh_token,
1187                "client_id": client.client_id,
1188            }));
1189
1190        let response = state.request(request).await;
1191        response.assert_status(StatusCode::OK);
1192
1193        let old_access_token = access_token;
1194        let old_refresh_token = refresh_token;
1195        let response: AccessTokenResponse = response.json();
1196        let access_token = response.access_token;
1197        let refresh_token = response.refresh_token.expect("to have a refresh token");
1198
1199        // Check that the new token is valid
1200        assert!(state.is_access_token_valid(&access_token).await);
1201
1202        // Check that the old token is no longer valid
1203        assert!(!state.is_access_token_valid(&old_access_token).await);
1204
1205        // Call it again with the old token, it should fail
1206        let request =
1207            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1208                "grant_type": "refresh_token",
1209                "refresh_token": old_refresh_token,
1210                "client_id": client.client_id,
1211            }));
1212
1213        let response = state.request(request).await;
1214        response.assert_status(StatusCode::BAD_REQUEST);
1215        let ClientError { error, .. } = response.json();
1216        assert_eq!(error, ClientErrorCode::InvalidGrant);
1217
1218        // Call it again with the new token, it should work
1219        let request =
1220            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1221                "grant_type": "refresh_token",
1222                "refresh_token": refresh_token,
1223                "client_id": client.client_id,
1224            }));
1225
1226        let response = state.request(request).await;
1227        response.assert_status(StatusCode::OK);
1228        let _: AccessTokenResponse = response.json();
1229    }
1230
1231    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1232    async fn test_double_refresh(pool: PgPool) {
1233        setup();
1234        let state = TestState::from_pool(pool).await.unwrap();
1235
1236        // Provision a client
1237        let request =
1238            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1239                "client_uri": "https://example.com/",
1240                "redirect_uris": ["https://example.com/callback"],
1241                "token_endpoint_auth_method": "none",
1242                "response_types": ["code"],
1243                "grant_types": ["authorization_code", "refresh_token"],
1244            }));
1245
1246        let response = state.request(request).await;
1247        response.assert_status(StatusCode::CREATED);
1248
1249        let ClientRegistrationResponse { client_id, .. } = response.json();
1250
1251        // Let's provision a user and create a session for them. This part is hard to
1252        // test with just HTTP requests, so we'll use the repository directly.
1253        let mut repo = state.repository().await.unwrap();
1254
1255        let user = repo
1256            .user()
1257            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1258            .await
1259            .unwrap();
1260
1261        let browser_session = repo
1262            .browser_session()
1263            .add(&mut state.rng(), &state.clock, &user, None)
1264            .await
1265            .unwrap();
1266
1267        // Lookup the client in the database.
1268        let client = repo
1269            .oauth2_client()
1270            .find_by_client_id(&client_id)
1271            .await
1272            .unwrap()
1273            .unwrap();
1274
1275        // Get a token pair
1276        let session = repo
1277            .oauth2_session()
1278            .add_from_browser_session(
1279                &mut state.rng(),
1280                &state.clock,
1281                &client,
1282                &browser_session,
1283                Scope::from_iter([OPENID]),
1284            )
1285            .await
1286            .unwrap();
1287
1288        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1289            generate_token_pair(
1290                &mut state.rng(),
1291                &state.clock,
1292                &mut repo,
1293                &session,
1294                Duration::microseconds(5 * 60 * 1000 * 1000),
1295            )
1296            .await
1297            .unwrap();
1298
1299        repo.save().await.unwrap();
1300
1301        // First check that the token is valid
1302        assert!(state.is_access_token_valid(&access_token).await);
1303
1304        // Now call the token endpoint to get an access token.
1305        let request =
1306            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1307                "grant_type": "refresh_token",
1308                "refresh_token": refresh_token,
1309                "client_id": client.client_id,
1310            }));
1311
1312        let first_response = state.request(request).await;
1313        first_response.assert_status(StatusCode::OK);
1314        let first_response: AccessTokenResponse = first_response.json();
1315
1316        // Call a second time, it should work, as we haven't done anything yet with the
1317        // token
1318        let request =
1319            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1320                "grant_type": "refresh_token",
1321                "refresh_token": refresh_token,
1322                "client_id": client.client_id,
1323            }));
1324
1325        let second_response = state.request(request).await;
1326        second_response.assert_status(StatusCode::OK);
1327        let second_response: AccessTokenResponse = second_response.json();
1328
1329        // Check that we got new tokens
1330        assert_ne!(first_response.access_token, second_response.access_token);
1331        assert_ne!(first_response.refresh_token, second_response.refresh_token);
1332
1333        // Check that the old-new token is invalid
1334        assert!(
1335            !state
1336                .is_access_token_valid(&first_response.access_token)
1337                .await
1338        );
1339
1340        // Check that the new-new token is valid
1341        assert!(
1342            state
1343                .is_access_token_valid(&second_response.access_token)
1344                .await
1345        );
1346
1347        // Do a third refresh, this one should not work, as we've used the new
1348        // access token
1349        let request =
1350            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1351                "grant_type": "refresh_token",
1352                "refresh_token": refresh_token,
1353                "client_id": client.client_id,
1354            }));
1355
1356        let third_response = state.request(request).await;
1357        third_response.assert_status(StatusCode::BAD_REQUEST);
1358
1359        // The other reason we consider a new refresh token to be 'used' is if
1360        // it was already used in a refresh
1361        // So, if we do a refresh with the second_response.refresh_token, then
1362        // another refresh with the result, redoing one with
1363        // second_response.refresh_token again should fail
1364        let request =
1365            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1366                "grant_type": "refresh_token",
1367                "refresh_token": second_response.refresh_token,
1368                "client_id": client.client_id,
1369            }));
1370
1371        // This one is fine
1372        let fourth_response = state.request(request).await;
1373        fourth_response.assert_status(StatusCode::OK);
1374        let fourth_response: AccessTokenResponse = fourth_response.json();
1375
1376        // Do another one, it should be fine as well
1377        let request =
1378            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1379                "grant_type": "refresh_token",
1380                "refresh_token": fourth_response.refresh_token,
1381                "client_id": client.client_id,
1382            }));
1383
1384        let fifth_response = state.request(request).await;
1385        fifth_response.assert_status(StatusCode::OK);
1386
1387        // But now, if we re-do with the second_response.refresh_token, it should
1388        // fail
1389        let request =
1390            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1391                "grant_type": "refresh_token",
1392                "refresh_token": second_response.refresh_token,
1393                "client_id": client.client_id,
1394            }));
1395
1396        let sixth_response = state.request(request).await;
1397        sixth_response.assert_status(StatusCode::BAD_REQUEST);
1398    }
1399
1400    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1401    async fn test_client_credentials(pool: PgPool) {
1402        setup();
1403        let state = TestState::from_pool(pool).await.unwrap();
1404
1405        // Provision a client
1406        let request =
1407            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1408                "client_uri": "https://example.com/",
1409                "token_endpoint_auth_method": "client_secret_post",
1410                "grant_types": ["client_credentials"],
1411            }));
1412
1413        let response = state.request(request).await;
1414        response.assert_status(StatusCode::CREATED);
1415
1416        let response: ClientRegistrationResponse = response.json();
1417        let client_id = response.client_id;
1418        let client_secret = response.client_secret.expect("to have a client secret");
1419
1420        // Call the token endpoint with an empty scope
1421        let request =
1422            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1423                "grant_type": "client_credentials",
1424                "client_id": client_id,
1425                "client_secret": client_secret,
1426            }));
1427
1428        let response = state.request(request).await;
1429        response.assert_status(StatusCode::OK);
1430
1431        let response: AccessTokenResponse = response.json();
1432        assert!(response.refresh_token.is_none());
1433        assert!(response.expires_in.is_some());
1434        assert!(response.scope.is_none());
1435
1436        // Revoke the token
1437        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1438            "token": response.access_token,
1439            "client_id": client_id,
1440            "client_secret": client_secret,
1441        }));
1442
1443        let response = state.request(request).await;
1444        response.assert_status(StatusCode::OK);
1445
1446        // We should be allowed to ask for the GraphQL API scope
1447        let request =
1448            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1449                "grant_type": "client_credentials",
1450                "client_id": client_id,
1451                "client_secret": client_secret,
1452                "scope": "urn:mas:graphql:*"
1453            }));
1454
1455        let response = state.request(request).await;
1456        response.assert_status(StatusCode::OK);
1457
1458        let response: AccessTokenResponse = response.json();
1459        assert!(response.refresh_token.is_none());
1460        assert!(response.expires_in.is_some());
1461        assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
1462
1463        // Revoke the token
1464        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1465            "token": response.access_token,
1466            "client_id": client_id,
1467            "client_secret": client_secret,
1468        }));
1469
1470        let response = state.request(request).await;
1471        response.assert_status(StatusCode::OK);
1472
1473        // We should be NOT allowed to ask for the MAS admin scope
1474        let request =
1475            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1476                "grant_type": "client_credentials",
1477                "client_id": client_id,
1478                "client_secret": client_secret,
1479                "scope": "urn:mas:admin"
1480            }));
1481
1482        let response = state.request(request).await;
1483        response.assert_status(StatusCode::FORBIDDEN);
1484
1485        let ClientError { error, .. } = response.json();
1486        assert_eq!(error, ClientErrorCode::InvalidScope);
1487
1488        // Now, if we add the client to the admin list in the policy, it should work
1489        let state = {
1490            let mut state = state;
1491            state.policy_factory = crate::test_utils::policy_factory(
1492                "example.com",
1493                serde_json::json!({
1494                    "admin_clients": [client_id]
1495                }),
1496            )
1497            .await
1498            .unwrap();
1499            state
1500        };
1501
1502        let request =
1503            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1504                "grant_type": "client_credentials",
1505                "client_id": client_id,
1506                "client_secret": client_secret,
1507                "scope": "urn:mas:admin"
1508            }));
1509
1510        let response = state.request(request).await;
1511        response.assert_status(StatusCode::OK);
1512
1513        let response: AccessTokenResponse = response.json();
1514        assert!(response.refresh_token.is_none());
1515        assert!(response.expires_in.is_some());
1516        assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
1517
1518        // Revoke the token
1519        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1520            "token": response.access_token,
1521            "client_id": client_id,
1522            "client_secret": client_secret,
1523        }));
1524
1525        let response = state.request(request).await;
1526        response.assert_status(StatusCode::OK);
1527    }
1528
1529    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1530    async fn test_device_code_grant(pool: PgPool) {
1531        setup();
1532        let state = TestState::from_pool(pool).await.unwrap();
1533
1534        // Provision a client
1535        let request =
1536            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1537                "client_uri": "https://example.com/",
1538                "token_endpoint_auth_method": "none",
1539                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
1540                "response_types": [],
1541            }));
1542
1543        let response = state.request(request).await;
1544        response.assert_status(StatusCode::CREATED);
1545
1546        let response: ClientRegistrationResponse = response.json();
1547        let client_id = response.client_id;
1548
1549        // Start a device code grant
1550        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1551            serde_json::json!({
1552                "client_id": client_id,
1553                "scope": "openid",
1554            }),
1555        );
1556        let response = state.request(request).await;
1557        response.assert_status(StatusCode::OK);
1558
1559        let device_grant: DeviceAuthorizationResponse = response.json();
1560
1561        // Poll the token endpoint, it should be pending
1562        let request =
1563            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1564                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1565                "device_code": device_grant.device_code,
1566                "client_id": client_id,
1567            }));
1568        let response = state.request(request).await;
1569        response.assert_status(StatusCode::FORBIDDEN);
1570
1571        let ClientError { error, .. } = response.json();
1572        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1573
1574        // Let's provision a user and create a browser session for them. This part is
1575        // hard to test with just HTTP requests, so we'll use the repository
1576        // directly.
1577        let mut repo = state.repository().await.unwrap();
1578
1579        let user = repo
1580            .user()
1581            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1582            .await
1583            .unwrap();
1584
1585        let browser_session = repo
1586            .browser_session()
1587            .add(&mut state.rng(), &state.clock, &user, None)
1588            .await
1589            .unwrap();
1590
1591        // Find the grant
1592        let grant = repo
1593            .oauth2_device_code_grant()
1594            .find_by_user_code(&device_grant.user_code)
1595            .await
1596            .unwrap()
1597            .unwrap();
1598
1599        // And fulfill it
1600        let grant = repo
1601            .oauth2_device_code_grant()
1602            .fulfill(&state.clock, grant, &browser_session)
1603            .await
1604            .unwrap();
1605
1606        repo.save().await.unwrap();
1607
1608        // Now call the token endpoint to get an access token.
1609        let request =
1610            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1611                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1612                "device_code": grant.device_code,
1613                "client_id": client_id,
1614            }));
1615
1616        let response = state.request(request).await;
1617        response.assert_status(StatusCode::OK);
1618
1619        let response: AccessTokenResponse = response.json();
1620
1621        // Check that the token is valid
1622        assert!(state.is_access_token_valid(&response.access_token).await);
1623        // We advertised the refresh token grant type, so we should have a refresh token
1624        assert!(response.refresh_token.is_some());
1625        // We asked for the openid scope, so we should have an ID token
1626        assert!(response.id_token.is_some());
1627
1628        // Calling it again should fail
1629        let request =
1630            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1631                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1632                "device_code": grant.device_code,
1633                "client_id": client_id,
1634            }));
1635        let response = state.request(request).await;
1636        response.assert_status(StatusCode::BAD_REQUEST);
1637
1638        let ClientError { error, .. } = response.json();
1639        assert_eq!(error, ClientErrorCode::InvalidGrant);
1640
1641        // Do another grant and make it expire
1642        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1643            serde_json::json!({
1644                "client_id": client_id,
1645                "scope": "openid",
1646            }),
1647        );
1648        let response = state.request(request).await;
1649        response.assert_status(StatusCode::OK);
1650
1651        let device_grant: DeviceAuthorizationResponse = response.json();
1652
1653        // Poll the token endpoint, it should be pending
1654        let request =
1655            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1656                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1657                "device_code": device_grant.device_code,
1658                "client_id": client_id,
1659            }));
1660        let response = state.request(request).await;
1661        response.assert_status(StatusCode::FORBIDDEN);
1662
1663        let ClientError { error, .. } = response.json();
1664        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1665
1666        state.clock.advance(Duration::try_hours(1).unwrap());
1667
1668        // Poll again, it should be expired
1669        let request =
1670            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1671                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1672                "device_code": device_grant.device_code,
1673                "client_id": client_id,
1674            }));
1675        let response = state.request(request).await;
1676        response.assert_status(StatusCode::FORBIDDEN);
1677
1678        let ClientError { error, .. } = response.json();
1679        assert_eq!(error, ClientErrorCode::ExpiredToken);
1680
1681        // Do another grant and reject it
1682        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1683            serde_json::json!({
1684                "client_id": client_id,
1685                "scope": "openid",
1686            }),
1687        );
1688        let response = state.request(request).await;
1689        response.assert_status(StatusCode::OK);
1690
1691        let device_grant: DeviceAuthorizationResponse = response.json();
1692
1693        // Find the grant and reject it
1694        let mut repo = state.repository().await.unwrap();
1695
1696        // Find the grant
1697        let grant = repo
1698            .oauth2_device_code_grant()
1699            .find_by_user_code(&device_grant.user_code)
1700            .await
1701            .unwrap()
1702            .unwrap();
1703
1704        // And reject it
1705        let grant = repo
1706            .oauth2_device_code_grant()
1707            .reject(&state.clock, grant, &browser_session)
1708            .await
1709            .unwrap();
1710
1711        repo.save().await.unwrap();
1712
1713        // Poll the token endpoint, it should be rejected
1714        let request =
1715            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1716                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1717                "device_code": grant.device_code,
1718                "client_id": client_id,
1719            }));
1720        let response = state.request(request).await;
1721        response.assert_status(StatusCode::FORBIDDEN);
1722
1723        let ClientError { error, .. } = response.json();
1724        assert_eq!(error, ClientErrorCode::AccessDenied);
1725    }
1726
1727    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1728    async fn test_unsupported_grant(pool: PgPool) {
1729        setup();
1730        let state = TestState::from_pool(pool).await.unwrap();
1731
1732        // Provision a client
1733        let request =
1734            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1735                "client_uri": "https://example.com/",
1736                "redirect_uris": ["https://example.com/callback"],
1737                "token_endpoint_auth_method": "client_secret_post",
1738                "grant_types": ["password"],
1739                "response_types": [],
1740            }));
1741
1742        let response = state.request(request).await;
1743        response.assert_status(StatusCode::CREATED);
1744
1745        let response: ClientRegistrationResponse = response.json();
1746        let client_id = response.client_id;
1747        let client_secret = response.client_secret.expect("to have a client secret");
1748
1749        // Call the token endpoint with an unsupported grant type
1750        let request =
1751            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1752                "grant_type": "password",
1753                "client_id": client_id,
1754                "client_secret": client_secret,
1755                "username": "john",
1756                "password": "hunter2",
1757            }));
1758
1759        let response = state.request(request).await;
1760        response.assert_status(StatusCode::BAD_REQUEST);
1761        let ClientError { error, .. } = response.json();
1762        assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
1763    }
1764}