1use std::sync::{Arc, LazyLock};
9
10use axum::{Json, extract::State, response::IntoResponse};
11use axum_extra::typed_header::TypedHeader;
12use chrono::Duration;
13use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
14use hyper::StatusCode;
15use mas_axum_utils::{
16 client_authorization::{ClientAuthorization, CredentialsVerificationError},
17 record_error,
18};
19use mas_data_model::{
20 AuthorizationGrantStage, BoxClock, BoxRng, Client, Clock, Device, DeviceCodeGrantState,
21 SiteConfig, TokenType,
22};
23use mas_i18n::DataLocale;
24use mas_keystore::{Encrypter, Keystore};
25use mas_matrix::HomeserverConnection;
26use mas_oidc_client::types::scope::ScopeToken;
27use mas_policy::Policy;
28use mas_router::UrlBuilder;
29use mas_storage::{
30 BoxRepository, RepositoryAccess,
31 oauth2::{
32 OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
33 OAuth2RefreshTokenRepository, OAuth2SessionRepository,
34 },
35 user::BrowserSessionRepository,
36};
37use mas_templates::{DeviceNameContext, TemplateContext, Templates};
38use oauth2_types::{
39 errors::{ClientError, ClientErrorCode},
40 pkce::CodeChallengeError,
41 requests::{
42 AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
43 DeviceCodeGrant, GrantType, RefreshTokenGrant,
44 },
45 scope,
46};
47use opentelemetry::{Key, KeyValue, metrics::Counter};
48use thiserror::Error;
49use tracing::{debug, info, warn};
50use ulid::Ulid;
51
52use super::{generate_id_token, generate_token_pair};
53use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
54
55static TOKEN_REQUEST_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
56 METER
57 .u64_counter("mas.oauth2.token_request")
58 .with_description("How many OAuth 2.0 token requests have gone through")
59 .with_unit("{request}")
60 .build()
61});
62const GRANT_TYPE: Key = Key::from_static_str("grant_type");
63const RESULT: Key = Key::from_static_str("successful");
64
65#[derive(Debug, Error)]
66pub(crate) enum RouteError {
67 #[error(transparent)]
68 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
69
70 #[error("bad request")]
71 BadRequest,
72
73 #[error("pkce verification failed")]
74 PkceVerification(#[from] CodeChallengeError),
75
76 #[error("client not found")]
77 ClientNotFound,
78
79 #[error("client not allowed to use the token endpoint: {0}")]
80 ClientNotAllowed(Ulid),
81
82 #[error("invalid client credentials for client {client_id}")]
83 InvalidClientCredentials {
84 client_id: Ulid,
85 #[source]
86 source: CredentialsVerificationError,
87 },
88
89 #[error("could not verify client credentials for client {client_id}")]
90 ClientCredentialsVerification {
91 client_id: Ulid,
92 #[source]
93 source: CredentialsVerificationError,
94 },
95
96 #[error("grant not found")]
97 GrantNotFound,
98
99 #[error("invalid grant {0}")]
100 InvalidGrant(Ulid),
101
102 #[error("refresh token not found")]
103 RefreshTokenNotFound,
104
105 #[error("refresh token {0} is invalid")]
106 RefreshTokenInvalid(Ulid),
107
108 #[error("session {0} is invalid")]
109 SessionInvalid(Ulid),
110
111 #[error("client id mismatch: expected {expected}, got {actual}")]
112 ClientIDMismatch { expected: Ulid, actual: Ulid },
113
114 #[error("policy denied the request: {0}")]
115 DeniedByPolicy(mas_policy::EvaluationResult),
116
117 #[error("unsupported grant type")]
118 UnsupportedGrantType,
119
120 #[error("client {0} is not authorized to use this grant type")]
121 UnauthorizedClient(Ulid),
122
123 #[error("unexpected client {was} (expected {expected})")]
124 UnexptectedClient { was: Ulid, expected: Ulid },
125
126 #[error("failed to load browser session {0}")]
127 NoSuchBrowserSession(Ulid),
128
129 #[error("failed to load oauth session {0}")]
130 NoSuchOAuthSession(Ulid),
131
132 #[error(
133 "failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
134 )]
135 NoSuchNextRefreshToken { next: Ulid, previous: Ulid },
136
137 #[error(
138 "failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})"
139 )]
140 NoSuchNextAccessToken {
141 access_token: Ulid,
142 refresh_token: Ulid,
143 },
144
145 #[error("device code grant expired")]
146 DeviceCodeExpired,
147
148 #[error("device code grant is still pending")]
149 DeviceCodePending,
150
151 #[error("device code grant was rejected")]
152 DeviceCodeRejected,
153
154 #[error("device code grant was already exchanged")]
155 DeviceCodeExchanged,
156
157 #[error("failed to provision device")]
158 ProvisionDeviceFailed(#[source] anyhow::Error),
159}
160
161impl IntoResponse for RouteError {
162 fn into_response(self) -> axum::response::Response {
163 let sentry_event_id = record_error!(
164 self,
165 Self::Internal(_)
166 | Self::ClientCredentialsVerification { .. }
167 | Self::NoSuchBrowserSession(_)
168 | Self::NoSuchOAuthSession(_)
169 | Self::ProvisionDeviceFailed(_)
170 | Self::NoSuchNextRefreshToken { .. }
171 | Self::NoSuchNextAccessToken { .. }
172 );
173
174 TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
175
176 let response = match self {
177 Self::Internal(_)
178 | Self::ClientCredentialsVerification { .. }
179 | Self::NoSuchBrowserSession(_)
180 | Self::NoSuchOAuthSession(_)
181 | Self::ProvisionDeviceFailed(_)
182 | Self::NoSuchNextRefreshToken { .. }
183 | Self::NoSuchNextAccessToken { .. } => (
184 StatusCode::INTERNAL_SERVER_ERROR,
185 Json(ClientError::from(ClientErrorCode::ServerError)),
186 ),
187
188 Self::BadRequest => (
189 StatusCode::BAD_REQUEST,
190 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
191 ),
192
193 Self::PkceVerification(err) => (
194 StatusCode::BAD_REQUEST,
195 Json(
196 ClientError::from(ClientErrorCode::InvalidGrant)
197 .with_description(format!("PKCE verification failed: {err}")),
198 ),
199 ),
200
201 Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
202 StatusCode::UNAUTHORIZED,
203 Json(ClientError::from(ClientErrorCode::InvalidClient)),
204 ),
205
206 Self::ClientNotAllowed(_)
207 | Self::UnauthorizedClient(_)
208 | Self::UnexptectedClient { .. } => (
209 StatusCode::UNAUTHORIZED,
210 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
211 ),
212
213 Self::DeniedByPolicy(evaluation) => (
214 StatusCode::FORBIDDEN,
215 Json(
216 ClientError::from(ClientErrorCode::InvalidScope).with_description(
217 evaluation
218 .violations
219 .into_iter()
220 .map(|violation| violation.msg)
221 .collect::<Vec<_>>()
222 .join(", "),
223 ),
224 ),
225 ),
226
227 Self::DeviceCodeRejected => (
228 StatusCode::FORBIDDEN,
229 Json(ClientError::from(ClientErrorCode::AccessDenied)),
230 ),
231
232 Self::DeviceCodeExpired => (
233 StatusCode::FORBIDDEN,
234 Json(ClientError::from(ClientErrorCode::ExpiredToken)),
235 ),
236
237 Self::DeviceCodePending => (
238 StatusCode::FORBIDDEN,
239 Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
240 ),
241
242 Self::InvalidGrant(_)
243 | Self::DeviceCodeExchanged
244 | Self::RefreshTokenNotFound
245 | Self::RefreshTokenInvalid(_)
246 | Self::SessionInvalid(_)
247 | Self::ClientIDMismatch { .. }
248 | Self::GrantNotFound => (
249 StatusCode::BAD_REQUEST,
250 Json(ClientError::from(ClientErrorCode::InvalidGrant)),
251 ),
252
253 Self::UnsupportedGrantType => (
254 StatusCode::BAD_REQUEST,
255 Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
256 ),
257 };
258
259 (sentry_event_id, response).into_response()
260 }
261}
262
263impl_from_error_for_route!(mas_i18n::DataError);
264impl_from_error_for_route!(mas_templates::TemplateError);
265impl_from_error_for_route!(mas_storage::RepositoryError);
266impl_from_error_for_route!(mas_policy::EvaluationError);
267impl_from_error_for_route!(super::IdTokenSignatureError);
268
269#[tracing::instrument(
270 name = "handlers.oauth2.token.post",
271 fields(client.id = client_authorization.client_id()),
272 skip_all,
273)]
274pub(crate) async fn post(
275 mut rng: BoxRng,
276 clock: BoxClock,
277 State(http_client): State<reqwest::Client>,
278 State(key_store): State<Keystore>,
279 State(url_builder): State<UrlBuilder>,
280 activity_tracker: BoundActivityTracker,
281 mut repo: BoxRepository,
282 State(homeserver): State<Arc<dyn HomeserverConnection>>,
283 State(site_config): State<SiteConfig>,
284 State(encrypter): State<Encrypter>,
285 State(templates): State<Templates>,
286 policy: Policy,
287 user_agent: Option<TypedHeader<headers::UserAgent>>,
288 client_authorization: ClientAuthorization<AccessTokenRequest>,
289) -> Result<impl IntoResponse, RouteError> {
290 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
291 let client = client_authorization
292 .credentials
293 .fetch(&mut repo)
294 .await?
295 .ok_or(RouteError::ClientNotFound)?;
296
297 let method = client
298 .token_endpoint_auth_method
299 .as_ref()
300 .ok_or(RouteError::ClientNotAllowed(client.id))?;
301
302 client_authorization
303 .credentials
304 .verify(&http_client, &encrypter, method, &client)
305 .await
306 .map_err(|err| {
307 if err.is_internal() {
310 RouteError::ClientCredentialsVerification {
311 client_id: client.id,
312 source: err,
313 }
314 } else {
315 RouteError::InvalidClientCredentials {
316 client_id: client.id,
317 source: err,
318 }
319 }
320 })?;
321
322 let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
323
324 let grant_type = form.grant_type();
325
326 let (reply, repo) = match form {
327 AccessTokenRequest::AuthorizationCode(grant) => {
328 authorization_code_grant(
329 &mut rng,
330 &clock,
331 &activity_tracker,
332 &grant,
333 &client,
334 &key_store,
335 &url_builder,
336 &site_config,
337 repo,
338 &homeserver,
339 &templates,
340 user_agent,
341 )
342 .await?
343 }
344 AccessTokenRequest::RefreshToken(grant) => {
345 refresh_token_grant(
346 &mut rng,
347 &clock,
348 &activity_tracker,
349 &grant,
350 &client,
351 &site_config,
352 repo,
353 user_agent,
354 )
355 .await?
356 }
357 AccessTokenRequest::ClientCredentials(grant) => {
358 client_credentials_grant(
359 &mut rng,
360 &clock,
361 &activity_tracker,
362 &grant,
363 &client,
364 &site_config,
365 repo,
366 policy,
367 user_agent,
368 )
369 .await?
370 }
371 AccessTokenRequest::DeviceCode(grant) => {
372 device_code_grant(
373 &mut rng,
374 &clock,
375 &activity_tracker,
376 &grant,
377 &client,
378 &key_store,
379 &url_builder,
380 &site_config,
381 repo,
382 &homeserver,
383 user_agent,
384 )
385 .await?
386 }
387 _ => {
388 return Err(RouteError::UnsupportedGrantType);
389 }
390 };
391
392 repo.save().await?;
393
394 TOKEN_REQUEST_COUNTER.add(
395 1,
396 &[
397 KeyValue::new(GRANT_TYPE, grant_type),
398 KeyValue::new(RESULT, "success"),
399 ],
400 );
401
402 let mut headers = HeaderMap::new();
403 headers.typed_insert(CacheControl::new().with_no_store());
404 headers.typed_insert(Pragma::no_cache());
405
406 Ok((headers, Json(reply)))
407}
408
409async fn authorization_code_grant(
410 mut rng: &mut BoxRng,
411 clock: &impl Clock,
412 activity_tracker: &BoundActivityTracker,
413 grant: &AuthorizationCodeGrant,
414 client: &Client,
415 key_store: &Keystore,
416 url_builder: &UrlBuilder,
417 site_config: &SiteConfig,
418 mut repo: BoxRepository,
419 homeserver: &Arc<dyn HomeserverConnection>,
420 templates: &Templates,
421 user_agent: Option<String>,
422) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
423 if !client.grant_types.contains(&GrantType::AuthorizationCode) {
425 return Err(RouteError::UnauthorizedClient(client.id));
426 }
427
428 let authz_grant = repo
429 .oauth2_authorization_grant()
430 .find_by_code(&grant.code)
431 .await?
432 .ok_or(RouteError::GrantNotFound)?;
433
434 let now = clock.now();
435
436 let session_id = match authz_grant.stage {
437 AuthorizationGrantStage::Cancelled { cancelled_at } => {
438 debug!(%cancelled_at, "Authorization grant was cancelled");
439 return Err(RouteError::InvalidGrant(authz_grant.id));
440 }
441 AuthorizationGrantStage::Exchanged {
442 exchanged_at,
443 fulfilled_at,
444 session_id,
445 } => {
446 warn!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");
447
448 if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
450 warn!(oauth_session.id = %session_id, "Ending potentially compromised session");
451 let session = repo
452 .oauth2_session()
453 .lookup(session_id)
454 .await?
455 .ok_or(RouteError::NoSuchOAuthSession(session_id))?;
456
457 repo.oauth2_session().finish(clock, session).await?;
459 repo.save().await?;
460 }
462
463 return Err(RouteError::InvalidGrant(authz_grant.id));
464 }
465 AuthorizationGrantStage::Pending => {
466 warn!("Authorization grant has not been fulfilled yet");
467 return Err(RouteError::InvalidGrant(authz_grant.id));
468 }
469 AuthorizationGrantStage::Fulfilled {
470 session_id,
471 fulfilled_at,
472 } => {
473 if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
474 warn!("Code exchange took more than 10 minutes");
475 return Err(RouteError::InvalidGrant(authz_grant.id));
476 }
477
478 session_id
479 }
480 };
481
482 let mut session = repo
483 .oauth2_session()
484 .lookup(session_id)
485 .await?
486 .ok_or(RouteError::NoSuchOAuthSession(session_id))?;
487
488 let lang: DataLocale = authz_grant.locale.as_deref().unwrap_or("en").parse()?;
490 let ctx = DeviceNameContext::new(client.clone(), user_agent.clone()).with_language(lang);
491 let device_name = templates.render_device_name(&ctx)?;
492
493 if let Some(user_agent) = user_agent {
494 session = repo
495 .oauth2_session()
496 .record_user_agent(session, user_agent)
497 .await?;
498 }
499
500 let code = authz_grant
502 .code
503 .as_ref()
504 .ok_or(RouteError::InvalidGrant(authz_grant.id))?;
505
506 if client.id != session.client_id {
507 return Err(RouteError::UnexptectedClient {
508 was: client.id,
509 expected: session.client_id,
510 });
511 }
512
513 match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
514 (None, None) => {}
515 (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest),
517 (Some(pkce), Some(verifier)) => {
519 pkce.verify(verifier)?;
520 }
521 }
522
523 let Some(user_session_id) = session.user_session_id else {
524 tracing::warn!("No user session associated with this OAuth2 session");
525 return Err(RouteError::InvalidGrant(authz_grant.id));
526 };
527
528 let browser_session = repo
529 .browser_session()
530 .lookup(user_session_id)
531 .await?
532 .ok_or(RouteError::NoSuchBrowserSession(user_session_id))?;
533
534 let last_authentication = repo
535 .browser_session()
536 .get_last_authentication(&browser_session)
537 .await?;
538
539 let ttl = site_config.access_token_ttl;
540 let (access_token, refresh_token) =
541 generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;
542
543 let id_token = if session.scope.contains(&scope::OPENID) {
544 Some(generate_id_token(
545 &mut rng,
546 clock,
547 url_builder,
548 key_store,
549 client,
550 Some(&authz_grant),
551 &browser_session,
552 Some(&access_token),
553 last_authentication.as_ref(),
554 )?)
555 } else {
556 None
557 };
558
559 let mut params = AccessTokenResponse::new(access_token.access_token)
560 .with_expires_in(ttl)
561 .with_refresh_token(refresh_token.refresh_token)
562 .with_scope(session.scope.clone());
563
564 if let Some(id_token) = id_token {
565 params = params.with_id_token(id_token);
566 }
567
568 repo.user()
570 .acquire_lock_for_sync(&browser_session.user)
571 .await?;
572
573 for scope in &*session.scope {
575 if let Some(device) = Device::from_scope_token(scope) {
576 homeserver
577 .upsert_device(
578 &browser_session.user.username,
579 device.as_str(),
580 Some(&device_name),
581 )
582 .await
583 .map_err(RouteError::ProvisionDeviceFailed)?;
584 }
585 }
586
587 repo.oauth2_authorization_grant()
588 .exchange(clock, authz_grant)
589 .await?;
590
591 activity_tracker
595 .record_oauth2_session(clock, &session)
596 .await;
597
598 Ok((params, repo))
599}
600
/// Handles the `refresh_token` grant by exchanging a refresh token for a new
/// access token / refresh token pair.
///
/// The presented refresh token is consumed and linked to its replacement. The
/// access token it was issued with is revoked, if it exists and is not already
/// revoked.
///
/// Presenting a refresh token that is no longer valid is tolerated in one
/// case, to recover from a lost response. The token must have a successor
/// refresh token that is still valid, and the successor's access token (if
/// any) must not have been used yet. Both are then revoked and a fresh pair is
/// issued. In every other case the request is rejected.
///
/// # Errors
///
/// - [`RouteError::UnauthorizedClient`] if the client may not use this grant
///   type;
/// - [`RouteError::RefreshTokenNotFound`] if the refresh token is unknown;
/// - [`RouteError::SessionInvalid`] if the session is no longer valid;
/// - [`RouteError::ClientIDMismatch`] if another client owns the session;
/// - [`RouteError::RefreshTokenInvalid`] if the token can be neither used nor
///   recovered;
/// - [`RouteError::NoSuchOAuthSession`], [`RouteError::NoSuchNextRefreshToken`]
///   and [`RouteError::NoSuchNextAccessToken`] on dangling references;
/// - any repository error.
async fn refresh_token_grant(
    rng: &mut BoxRng,
    clock: &impl Clock,
    activity_tracker: &BoundActivityTracker,
    grant: &RefreshTokenGrant,
    client: &Client,
    site_config: &SiteConfig,
    mut repo: BoxRepository,
    user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
    // Check that the client is allowed to use this grant type
    if !client.grant_types.contains(&GrantType::RefreshToken) {
        return Err(RouteError::UnauthorizedClient(client.id));
    }

    let refresh_token = repo
        .oauth2_refresh_token()
        .find_by_token(&grant.refresh_token)
        .await?
        .ok_or(RouteError::RefreshTokenNotFound)?;

    let mut session = repo
        .oauth2_session()
        .lookup(refresh_token.session_id)
        .await?
        .ok_or(RouteError::NoSuchOAuthSession(refresh_token.session_id))?;

    // Record the user agent on every refresh. This happens before the checks
    // below. If one of them fails, the change is presumably discarded because
    // the caller only commits the repository transaction on success.
    if let Some(user_agent) = user_agent {
        session = repo
            .oauth2_session()
            .record_user_agent(session, user_agent)
            .await?;
    }

    if !session.is_valid() {
        return Err(RouteError::SessionInvalid(session.id));
    }

    // The refresh token must be presented by the client it was issued to
    if client.id != session.client_id {
        return Err(RouteError::ClientIDMismatch {
            expected: session.client_id,
            actual: client.id,
        });
    }

    // The refresh token is no longer valid (e.g. it was already consumed).
    // This is either a replay or a client that lost the response to a previous
    // refresh. Only the latter is tolerated: the token must have a successor
    // that is still valid, and whose access token was never used.
    if !refresh_token.is_valid() {
        // No successor: the token cannot be recovered
        let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
            return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
        };

        let Some(next_refresh_token) = repo
            .oauth2_refresh_token()
            .lookup(next_refresh_token_id)
            .await?
        else {
            return Err(RouteError::NoSuchNextRefreshToken {
                next: next_refresh_token_id,
                previous: refresh_token.id,
            });
        };

        // The successor itself was already used or revoked: refuse
        if !next_refresh_token.is_valid() {
            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
        }

        if let Some(access_token_id) = next_refresh_token.access_token_id {
            let next_access_token = repo
                .oauth2_access_token()
                .lookup(access_token_id)
                .await?
                .ok_or(RouteError::NoSuchNextAccessToken {
                    access_token: access_token_id,
                    refresh_token: next_refresh_token_id,
                })?;

            // The client did receive the previous response, since it used
            // the access token from it: refuse the replay
            if next_access_token.is_used() {
                return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
            }

            // The unused access token is presumed lost: revoke it
            repo.oauth2_access_token()
                .revoke(clock, next_access_token)
                .await?;
        }

        info!(
            oauth2_session.id = %session.id,
            oauth2_client.id = %client.id,
            %refresh_token.id,
            "Refresh token already used, but issued refresh and access tokens are unused. Assuming those were lost; revoking those and reissuing new ones."
        );

        // The unused successor refresh token is presumed lost: revoke it
        repo.oauth2_refresh_token()
            .revoke(clock, next_refresh_token)
            .await?;
    }

    // NOTE: the activity is recorded before the caller commits the repository
    // transaction.
    activity_tracker
        .record_oauth2_session(clock, &session)
        .await;

    let ttl = site_config.access_token_ttl;
    let (new_access_token, new_refresh_token) =
        generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;

    // Mark the presented refresh token as consumed, pointing to its replacement
    let refresh_token = repo
        .oauth2_refresh_token()
        .consume(clock, refresh_token, &new_refresh_token)
        .await?;

    // Revoke the access token issued alongside the consumed refresh token
    if let Some(access_token_id) = refresh_token.access_token_id {
        let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
        if let Some(access_token) = access_token {
            // It may already be revoked (e.g. after a recovered replay)
            if !access_token.state.is_revoked() {
                repo.oauth2_access_token()
                    .revoke(clock, access_token)
                    .await?;
            }
        }
    }

    let params = AccessTokenResponse::new(new_access_token.access_token)
        .with_expires_in(ttl)
        .with_refresh_token(new_refresh_token.refresh_token)
        .with_scope(session.scope);

    Ok((params, repo))
}
772
773async fn client_credentials_grant(
774 rng: &mut BoxRng,
775 clock: &impl Clock,
776 activity_tracker: &BoundActivityTracker,
777 grant: &ClientCredentialsGrant,
778 client: &Client,
779 site_config: &SiteConfig,
780 mut repo: BoxRepository,
781 mut policy: Policy,
782 user_agent: Option<String>,
783) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
784 if !client.grant_types.contains(&GrantType::ClientCredentials) {
786 return Err(RouteError::UnauthorizedClient(client.id));
787 }
788
789 let scope = grant
791 .scope
792 .clone()
793 .unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
794
795 let res = policy
797 .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
798 user: None,
799 client,
800 session_counts: None,
801 scope: &scope,
802 grant_type: mas_policy::GrantType::ClientCredentials,
803 requester: mas_policy::Requester {
804 ip_address: activity_tracker.ip(),
805 user_agent: user_agent.clone(),
806 },
807 })
808 .await?;
809 if !res.valid() {
810 return Err(RouteError::DeniedByPolicy(res));
811 }
812
813 let mut session = repo
815 .oauth2_session()
816 .add_from_client_credentials(rng, clock, client, scope)
817 .await?;
818
819 if let Some(user_agent) = user_agent {
820 session = repo
821 .oauth2_session()
822 .record_user_agent(session, user_agent)
823 .await?;
824 }
825
826 let ttl = site_config.access_token_ttl;
827 let access_token_str = TokenType::AccessToken.generate(rng);
828
829 let access_token = repo
830 .oauth2_access_token()
831 .add(rng, clock, &session, access_token_str, Some(ttl))
832 .await?;
833
834 let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
835
836 activity_tracker
840 .record_oauth2_session(clock, &session)
841 .await;
842
843 if !session.scope.is_empty() {
844 params = params.with_scope(session.scope);
846 }
847
848 Ok((params, repo))
849}
850
851async fn device_code_grant(
852 rng: &mut BoxRng,
853 clock: &impl Clock,
854 activity_tracker: &BoundActivityTracker,
855 grant: &DeviceCodeGrant,
856 client: &Client,
857 key_store: &Keystore,
858 url_builder: &UrlBuilder,
859 site_config: &SiteConfig,
860 mut repo: BoxRepository,
861 homeserver: &Arc<dyn HomeserverConnection>,
862 user_agent: Option<String>,
863) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
864 if !client.grant_types.contains(&GrantType::DeviceCode) {
866 return Err(RouteError::UnauthorizedClient(client.id));
867 }
868
869 let grant = repo
870 .oauth2_device_code_grant()
871 .find_by_device_code(&grant.device_code)
872 .await?
873 .ok_or(RouteError::GrantNotFound)?;
874
875 if client.id != grant.client_id {
877 return Err(RouteError::ClientIDMismatch {
878 expected: grant.client_id,
879 actual: client.id,
880 });
881 }
882
883 if grant.expires_at < clock.now() {
884 return Err(RouteError::DeviceCodeExpired);
885 }
886
887 let browser_session_id = match &grant.state {
888 DeviceCodeGrantState::Pending => {
889 return Err(RouteError::DeviceCodePending);
890 }
891 DeviceCodeGrantState::Rejected { .. } => {
892 return Err(RouteError::DeviceCodeRejected);
893 }
894 DeviceCodeGrantState::Exchanged { .. } => {
895 return Err(RouteError::DeviceCodeExchanged);
896 }
897 DeviceCodeGrantState::Fulfilled {
898 browser_session_id, ..
899 } => *browser_session_id,
900 };
901
902 let browser_session = repo
903 .browser_session()
904 .lookup(browser_session_id)
905 .await?
906 .ok_or(RouteError::NoSuchBrowserSession(browser_session_id))?;
907
908 let mut session = repo
910 .oauth2_session()
911 .add_from_browser_session(rng, clock, client, &browser_session, grant.scope.clone())
912 .await?;
913
914 repo.oauth2_device_code_grant()
915 .exchange(clock, grant, &session)
916 .await?;
917
918 if let Some(user_agent) = user_agent {
920 session = repo
921 .oauth2_session()
922 .record_user_agent(session, user_agent)
923 .await?;
924 }
925
926 let ttl = site_config.access_token_ttl;
927 let access_token_str = TokenType::AccessToken.generate(rng);
928
929 let access_token = repo
930 .oauth2_access_token()
931 .add(rng, clock, &session, access_token_str, Some(ttl))
932 .await?;
933
934 let mut params =
935 AccessTokenResponse::new(access_token.access_token.clone()).with_expires_in(ttl);
936
937 if client.grant_types.contains(&GrantType::RefreshToken) {
940 let refresh_token_str = TokenType::RefreshToken.generate(rng);
941
942 let refresh_token = repo
943 .oauth2_refresh_token()
944 .add(rng, clock, &session, &access_token, refresh_token_str)
945 .await?;
946
947 params = params.with_refresh_token(refresh_token.refresh_token);
948 }
949
950 if session.scope.contains(&scope::OPENID) {
952 let id_token = generate_id_token(
953 rng,
954 clock,
955 url_builder,
956 key_store,
957 client,
958 None,
959 &browser_session,
960 Some(&access_token),
961 None,
962 )?;
963
964 params = params.with_id_token(id_token);
965 }
966
967 repo.user()
969 .acquire_lock_for_sync(&browser_session.user)
970 .await?;
971
972 for scope in &*session.scope {
974 if let Some(device) = Device::from_scope_token(scope) {
975 homeserver
976 .upsert_device(&browser_session.user.username, device.as_str(), None)
977 .await
978 .map_err(RouteError::ProvisionDeviceFailed)?;
979 }
980 }
981
982 activity_tracker
986 .record_oauth2_session(clock, &session)
987 .await;
988
989 if !session.scope.is_empty() {
990 params = params.with_scope(session.scope);
992 }
993
994 Ok((params, repo))
995}
996
997#[cfg(test)]
998mod tests {
999 use hyper::Request;
1000 use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
1001 use mas_router::SimpleRoute;
1002 use oauth2_types::{
1003 registration::ClientRegistrationResponse,
1004 requests::{DeviceAuthorizationResponse, ResponseMode},
1005 scope::{OPENID, Scope},
1006 };
1007 use sqlx::PgPool;
1008
1009 use super::*;
1010 use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
1011
    /// Exercises the authorization code grant: a successful exchange, a
    /// tolerated immediate replay, a late replay that ends the session, and a
    /// code presented after the 10-minute exchange window.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_auth_code_grant(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a public client (no client authentication) that may only use
        // the authorization code grant.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/callback"],
                "token_endpoint_auth_method": "none",
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let ClientRegistrationResponse { client_id, .. } = response.json();

        // Provision a user, a browser session and a fulfilled authorization
        // grant directly through the repository.
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();

        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        // Create an authorization grant without PKCE for this client
        let code = "thisisaverysecurecode";
        let grant = repo
            .oauth2_authorization_grant()
            .add(
                &mut state.rng(),
                &state.clock,
                &client,
                "https://example.com/redirect".parse().unwrap(),
                Scope::from_iter([OPENID]),
                Some(AuthorizationCode {
                    code: code.to_owned(),
                    pkce: None,
                }),
                Some("state".to_owned()),
                Some("nonce".to_owned()),
                ResponseMode::Query,
                false,
                None,
                None,
            )
            .await
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                grant.scope.clone(),
            )
            .await
            .unwrap();

        // Fulfill the grant, so that the code can be exchanged
        let grant = repo
            .oauth2_authorization_grant()
            .fulfill(&state.clock, &session, grant)
            .await
            .unwrap();

        repo.save().await.unwrap();

        // The first exchange succeeds and yields a valid access token
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": grant.redirect_uri,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let AccessTokenResponse { access_token, .. } = response.json();

        assert!(state.is_access_token_valid(&access_token).await);

        // Replaying the code right away fails, but is within the 20-second
        // tolerance window, so the session is left alone
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": grant.redirect_uri,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = response.json();
        assert_eq!(error.error, ClientErrorCode::InvalidGrant);

        // The access token is still valid
        assert!(state.is_access_token_valid(&access_token).await);

        // Move past the 20-second tolerance window
        state.clock.advance(Duration::try_minutes(1).unwrap());

        // A late replay still fails, and ends the session created from the code
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": grant.redirect_uri,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let error: ClientError = response.json();
        assert_eq!(error.error, ClientErrorCode::InvalidGrant);

        // The access token issued from that session is no longer valid
        assert!(!state.is_access_token_valid(&access_token).await);

        // Create another fulfilled grant, with a different code
        let mut repo = state.repository().await.unwrap();
        let code = "thisisanothercode";
        let grant = repo
            .oauth2_authorization_grant()
            .add(
                &mut state.rng(),
                &state.clock,
                &client,
                "https://example.com/redirect".parse().unwrap(),
                Scope::from_iter([OPENID]),
                Some(AuthorizationCode {
                    code: code.to_owned(),
                    pkce: None,
                }),
                Some("state".to_owned()),
                Some("nonce".to_owned()),
                ResponseMode::Query,
                false,
                None,
                None,
            )
            .await
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                grant.scope.clone(),
            )
            .await
            .unwrap();

        let grant = repo
            .oauth2_authorization_grant()
            .fulfill(&state.clock, &session, grant)
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Wait 15 minutes: longer than the 10-minute exchange window
        state
            .clock
            .advance(Duration::microseconds(15 * 60 * 1000 * 1000));

        // Exchanging the code now is rejected
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": grant.redirect_uri,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let ClientError { error, .. } = response.json();
        assert_eq!(error, ClientErrorCode::InvalidGrant);
    }
1220
    /// Exercises the refresh token grant: a successful refresh, a rejected
    /// replay of the old refresh token, and a subsequent refresh with the new
    /// one.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_refresh_token_grant(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a public client that may use both the authorization code
        // and the refresh token grants
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/callback"],
                "token_endpoint_auth_method": "none",
                "response_types": ["code"],
                "grant_types": ["authorization_code", "refresh_token"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let ClientRegistrationResponse { client_id, .. } = response.json();

        // Provision a user, a browser session, an OAuth 2.0 session and an
        // initial token pair directly through the repository
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();

        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        // Issue an access token valid for 5 minutes, with its refresh token
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Use the refresh token to get a new token pair
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let old_access_token = access_token;
        let old_refresh_token = refresh_token;
        let response: AccessTokenResponse = response.json();
        let access_token = response.access_token;
        let refresh_token = response.refresh_token.expect("to have a refresh token");

        // The new access token is valid
        assert!(state.is_access_token_valid(&access_token).await);

        // The old access token was revoked by the refresh
        assert!(!state.is_access_token_valid(&old_access_token).await);

        // Replaying the old refresh token is rejected.
        // NOTE(review): this presumably relies on the validity check above
        // counting as a use of the new access token, which prevents the
        // lost-response recovery in `refresh_token_grant` — confirm.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": old_refresh_token,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let ClientError { error, .. } = response.json();
        assert_eq!(error, ClientErrorCode::InvalidGrant);

        // The new refresh token is still usable
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client.client_id,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let _: AccessTokenResponse = response.json();
    }
1342
1343 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1344 async fn test_double_refresh(pool: PgPool) {
1345 setup();
1346 let state = TestState::from_pool(pool).await.unwrap();
1347
1348 let request =
1350 Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1351 "client_uri": "https://example.com/",
1352 "redirect_uris": ["https://example.com/callback"],
1353 "token_endpoint_auth_method": "none",
1354 "response_types": ["code"],
1355 "grant_types": ["authorization_code", "refresh_token"],
1356 }));
1357
1358 let response = state.request(request).await;
1359 response.assert_status(StatusCode::CREATED);
1360
1361 let ClientRegistrationResponse { client_id, .. } = response.json();
1362
1363 let mut repo = state.repository().await.unwrap();
1366
1367 let user = repo
1368 .user()
1369 .add(&mut state.rng(), &state.clock, "alice".to_owned())
1370 .await
1371 .unwrap();
1372
1373 let browser_session = repo
1374 .browser_session()
1375 .add(&mut state.rng(), &state.clock, &user, None)
1376 .await
1377 .unwrap();
1378
1379 let client = repo
1381 .oauth2_client()
1382 .find_by_client_id(&client_id)
1383 .await
1384 .unwrap()
1385 .unwrap();
1386
1387 let session = repo
1389 .oauth2_session()
1390 .add_from_browser_session(
1391 &mut state.rng(),
1392 &state.clock,
1393 &client,
1394 &browser_session,
1395 Scope::from_iter([OPENID]),
1396 )
1397 .await
1398 .unwrap();
1399
1400 let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1401 generate_token_pair(
1402 &mut state.rng(),
1403 &state.clock,
1404 &mut repo,
1405 &session,
1406 Duration::microseconds(5 * 60 * 1000 * 1000),
1407 )
1408 .await
1409 .unwrap();
1410
1411 repo.save().await.unwrap();
1412
1413 assert!(state.is_access_token_valid(&access_token).await);
1415
1416 let request =
1418 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1419 "grant_type": "refresh_token",
1420 "refresh_token": refresh_token,
1421 "client_id": client.client_id,
1422 }));
1423
1424 let first_response = state.request(request).await;
1425 first_response.assert_status(StatusCode::OK);
1426 let first_response: AccessTokenResponse = first_response.json();
1427
1428 let request =
1431 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1432 "grant_type": "refresh_token",
1433 "refresh_token": refresh_token,
1434 "client_id": client.client_id,
1435 }));
1436
1437 let second_response = state.request(request).await;
1438 second_response.assert_status(StatusCode::OK);
1439 let second_response: AccessTokenResponse = second_response.json();
1440
1441 assert_ne!(first_response.access_token, second_response.access_token);
1443 assert_ne!(first_response.refresh_token, second_response.refresh_token);
1444
1445 assert!(
1447 !state
1448 .is_access_token_valid(&first_response.access_token)
1449 .await
1450 );
1451
1452 assert!(
1454 state
1455 .is_access_token_valid(&second_response.access_token)
1456 .await
1457 );
1458
1459 let request =
1462 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1463 "grant_type": "refresh_token",
1464 "refresh_token": refresh_token,
1465 "client_id": client.client_id,
1466 }));
1467
1468 let third_response = state.request(request).await;
1469 third_response.assert_status(StatusCode::BAD_REQUEST);
1470
1471 let request =
1477 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1478 "grant_type": "refresh_token",
1479 "refresh_token": second_response.refresh_token,
1480 "client_id": client.client_id,
1481 }));
1482
1483 let fourth_response = state.request(request).await;
1485 fourth_response.assert_status(StatusCode::OK);
1486 let fourth_response: AccessTokenResponse = fourth_response.json();
1487
1488 let request =
1490 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1491 "grant_type": "refresh_token",
1492 "refresh_token": fourth_response.refresh_token,
1493 "client_id": client.client_id,
1494 }));
1495
1496 let fifth_response = state.request(request).await;
1497 fifth_response.assert_status(StatusCode::OK);
1498 let fifth_response: AccessTokenResponse = fifth_response.json();
1499
1500 let request =
1503 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1504 "grant_type": "refresh_token",
1505 "refresh_token": second_response.refresh_token,
1506 "client_id": client.client_id,
1507 }));
1508
1509 let sixth_response = state.request(request).await;
1510 sixth_response.assert_status(StatusCode::BAD_REQUEST);
1511
1512 assert!(
1519 state
1520 .is_access_token_valid(&fifth_response.access_token)
1521 .await
1522 );
1523
1524 state.run_jobs_in_queue().await;
1528 state.clock.advance(Duration::days(31));
1529 state.run_jobs_in_queue().await;
1530
1531 let request =
1534 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1535 "grant_type": "refresh_token",
1536 "refresh_token": fourth_response.refresh_token,
1537 "client_id": client.client_id,
1538 }));
1539
1540 let seventh_response = state.request(request).await;
1541 seventh_response.assert_status(StatusCode::OK);
1542
1543 let request =
1545 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1546 "grant_type": "refresh_token",
1547 "refresh_token": fifth_response.refresh_token,
1548 "client_id": client.client_id,
1549 }));
1550
1551 let eighth_response = state.request(request).await;
1552 eighth_response.assert_status(StatusCode::BAD_REQUEST);
1553 }
1554
1555 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1556 async fn test_client_credentials(pool: PgPool) {
1557 setup();
1558 let state = TestState::from_pool(pool).await.unwrap();
1559
1560 let request =
1562 Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1563 "client_uri": "https://example.com/",
1564 "token_endpoint_auth_method": "client_secret_post",
1565 "grant_types": ["client_credentials"],
1566 }));
1567
1568 let response = state.request(request).await;
1569 response.assert_status(StatusCode::CREATED);
1570
1571 let response: ClientRegistrationResponse = response.json();
1572 let client_id = response.client_id;
1573 let client_secret = response.client_secret.expect("to have a client secret");
1574
1575 let request =
1577 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1578 "grant_type": "client_credentials",
1579 "client_id": client_id,
1580 "client_secret": client_secret,
1581 }));
1582
1583 let response = state.request(request).await;
1584 response.assert_status(StatusCode::OK);
1585
1586 let response: AccessTokenResponse = response.json();
1587 assert!(response.refresh_token.is_none());
1588 assert!(response.expires_in.is_some());
1589 assert!(response.scope.is_none());
1590
1591 let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1593 "token": response.access_token,
1594 "client_id": client_id,
1595 "client_secret": client_secret,
1596 }));
1597
1598 let response = state.request(request).await;
1599 response.assert_status(StatusCode::OK);
1600
1601 let request =
1603 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1604 "grant_type": "client_credentials",
1605 "client_id": client_id,
1606 "client_secret": client_secret,
1607 "scope": "urn:mas:graphql:*"
1608 }));
1609
1610 let response = state.request(request).await;
1611 response.assert_status(StatusCode::OK);
1612
1613 let response: AccessTokenResponse = response.json();
1614 assert!(response.refresh_token.is_none());
1615 assert!(response.expires_in.is_some());
1616 assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
1617
1618 let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1620 "token": response.access_token,
1621 "client_id": client_id,
1622 "client_secret": client_secret,
1623 }));
1624
1625 let response = state.request(request).await;
1626 response.assert_status(StatusCode::OK);
1627
1628 let request =
1630 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1631 "grant_type": "client_credentials",
1632 "client_id": client_id,
1633 "client_secret": client_secret,
1634 "scope": "urn:mas:admin"
1635 }));
1636
1637 let response = state.request(request).await;
1638 response.assert_status(StatusCode::FORBIDDEN);
1639
1640 let ClientError { error, .. } = response.json();
1641 assert_eq!(error, ClientErrorCode::InvalidScope);
1642
1643 let state = {
1645 let mut state = state;
1646 state.policy_factory = crate::test_utils::policy_factory(
1647 "example.com",
1648 serde_json::json!({
1649 "admin_clients": [client_id]
1650 }),
1651 )
1652 .await
1653 .unwrap();
1654 state
1655 };
1656
1657 let request =
1658 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1659 "grant_type": "client_credentials",
1660 "client_id": client_id,
1661 "client_secret": client_secret,
1662 "scope": "urn:mas:admin"
1663 }));
1664
1665 let response = state.request(request).await;
1666 response.assert_status(StatusCode::OK);
1667
1668 let response: AccessTokenResponse = response.json();
1669 assert!(response.refresh_token.is_none());
1670 assert!(response.expires_in.is_some());
1671 assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
1672
1673 let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1675 "token": response.access_token,
1676 "client_id": client_id,
1677 "client_secret": client_secret,
1678 }));
1679
1680 let response = state.request(request).await;
1681 response.assert_status(StatusCode::OK);
1682 }
1683
1684 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1685 async fn test_device_code_grant(pool: PgPool) {
1686 setup();
1687 let state = TestState::from_pool(pool).await.unwrap();
1688
1689 let request =
1691 Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1692 "client_uri": "https://example.com/",
1693 "token_endpoint_auth_method": "none",
1694 "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
1695 "response_types": [],
1696 }));
1697
1698 let response = state.request(request).await;
1699 response.assert_status(StatusCode::CREATED);
1700
1701 let response: ClientRegistrationResponse = response.json();
1702 let client_id = response.client_id;
1703
1704 let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1706 serde_json::json!({
1707 "client_id": client_id,
1708 "scope": "openid",
1709 }),
1710 );
1711 let response = state.request(request).await;
1712 response.assert_status(StatusCode::OK);
1713
1714 let device_grant: DeviceAuthorizationResponse = response.json();
1715
1716 let request =
1718 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1719 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1720 "device_code": device_grant.device_code,
1721 "client_id": client_id,
1722 }));
1723 let response = state.request(request).await;
1724 response.assert_status(StatusCode::FORBIDDEN);
1725
1726 let ClientError { error, .. } = response.json();
1727 assert_eq!(error, ClientErrorCode::AuthorizationPending);
1728
1729 let mut repo = state.repository().await.unwrap();
1733
1734 let user = repo
1735 .user()
1736 .add(&mut state.rng(), &state.clock, "alice".to_owned())
1737 .await
1738 .unwrap();
1739
1740 let browser_session = repo
1741 .browser_session()
1742 .add(&mut state.rng(), &state.clock, &user, None)
1743 .await
1744 .unwrap();
1745
1746 let grant = repo
1748 .oauth2_device_code_grant()
1749 .find_by_user_code(&device_grant.user_code)
1750 .await
1751 .unwrap()
1752 .unwrap();
1753
1754 let grant = repo
1756 .oauth2_device_code_grant()
1757 .fulfill(&state.clock, grant, &browser_session)
1758 .await
1759 .unwrap();
1760
1761 repo.save().await.unwrap();
1762
1763 let request =
1765 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1766 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1767 "device_code": grant.device_code,
1768 "client_id": client_id,
1769 }));
1770
1771 let response = state.request(request).await;
1772 response.assert_status(StatusCode::OK);
1773
1774 let response: AccessTokenResponse = response.json();
1775
1776 assert!(state.is_access_token_valid(&response.access_token).await);
1778 assert!(response.refresh_token.is_some());
1780 assert!(response.id_token.is_some());
1782
1783 let request =
1785 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1786 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1787 "device_code": grant.device_code,
1788 "client_id": client_id,
1789 }));
1790 let response = state.request(request).await;
1791 response.assert_status(StatusCode::BAD_REQUEST);
1792
1793 let ClientError { error, .. } = response.json();
1794 assert_eq!(error, ClientErrorCode::InvalidGrant);
1795
1796 let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1798 serde_json::json!({
1799 "client_id": client_id,
1800 "scope": "openid",
1801 }),
1802 );
1803 let response = state.request(request).await;
1804 response.assert_status(StatusCode::OK);
1805
1806 let device_grant: DeviceAuthorizationResponse = response.json();
1807
1808 let request =
1810 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1811 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1812 "device_code": device_grant.device_code,
1813 "client_id": client_id,
1814 }));
1815 let response = state.request(request).await;
1816 response.assert_status(StatusCode::FORBIDDEN);
1817
1818 let ClientError { error, .. } = response.json();
1819 assert_eq!(error, ClientErrorCode::AuthorizationPending);
1820
1821 state.clock.advance(Duration::try_hours(1).unwrap());
1822
1823 let request =
1825 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1826 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1827 "device_code": device_grant.device_code,
1828 "client_id": client_id,
1829 }));
1830 let response = state.request(request).await;
1831 response.assert_status(StatusCode::FORBIDDEN);
1832
1833 let ClientError { error, .. } = response.json();
1834 assert_eq!(error, ClientErrorCode::ExpiredToken);
1835
1836 let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1838 serde_json::json!({
1839 "client_id": client_id,
1840 "scope": "openid",
1841 }),
1842 );
1843 let response = state.request(request).await;
1844 response.assert_status(StatusCode::OK);
1845
1846 let device_grant: DeviceAuthorizationResponse = response.json();
1847
1848 let mut repo = state.repository().await.unwrap();
1850
1851 let grant = repo
1853 .oauth2_device_code_grant()
1854 .find_by_user_code(&device_grant.user_code)
1855 .await
1856 .unwrap()
1857 .unwrap();
1858
1859 let grant = repo
1861 .oauth2_device_code_grant()
1862 .reject(&state.clock, grant, &browser_session)
1863 .await
1864 .unwrap();
1865
1866 repo.save().await.unwrap();
1867
1868 let request =
1870 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1871 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1872 "device_code": grant.device_code,
1873 "client_id": client_id,
1874 }));
1875 let response = state.request(request).await;
1876 response.assert_status(StatusCode::FORBIDDEN);
1877
1878 let ClientError { error, .. } = response.json();
1879 assert_eq!(error, ClientErrorCode::AccessDenied);
1880 }
1881
1882 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1883 async fn test_unsupported_grant(pool: PgPool) {
1884 setup();
1885 let state = TestState::from_pool(pool).await.unwrap();
1886
1887 let request =
1889 Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1890 "client_uri": "https://example.com/",
1891 "redirect_uris": ["https://example.com/callback"],
1892 "token_endpoint_auth_method": "client_secret_post",
1893 "grant_types": ["password"],
1894 "response_types": [],
1895 }));
1896
1897 let response = state.request(request).await;
1898 response.assert_status(StatusCode::CREATED);
1899
1900 let response: ClientRegistrationResponse = response.json();
1901 let client_id = response.client_id;
1902 let client_secret = response.client_secret.expect("to have a client secret");
1903
1904 let request =
1906 Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1907 "grant_type": "password",
1908 "client_id": client_id,
1909 "client_secret": client_secret,
1910 "username": "john",
1911 "password": "hunter2",
1912 }));
1913
1914 let response = state.request(request).await;
1915 response.assert_status(StatusCode::BAD_REQUEST);
1916 let ClientError { error, .. } = response.json();
1917 assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
1918 }
1919}