1use axum::{
8 extract::{Form, State},
9 response::{Html, IntoResponse, Response},
10};
11use axum_extra::TypedHeader;
12use hyper::StatusCode;
13use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt, sentry::SentryEventID};
14use mas_data_model::{AuthorizationCode, Pkce};
15use mas_keystore::Keystore;
16use mas_policy::Policy;
17use mas_router::{PostAuthAction, UrlBuilder};
18use mas_storage::{
19 BoxClock, BoxRepository, BoxRng,
20 oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
21};
22use mas_templates::{PolicyViolationContext, TemplateContext, Templates};
23use oauth2_types::{
24 errors::{ClientError, ClientErrorCode},
25 pkce,
26 requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
27 response_type::ResponseType,
28};
29use rand::{Rng, distributions::Alphanumeric};
30use serde::Deserialize;
31use thiserror::Error;
32use tracing::warn;
33
34use self::{callback::CallbackDestination, complete::GrantCompletionError};
35use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
36
37mod callback;
38pub mod complete;
39
40#[derive(Debug, Error)]
41pub enum RouteError {
42 #[error(transparent)]
43 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
44
45 #[error("could not find client")]
46 ClientNotFound,
47
48 #[error("invalid response mode")]
49 InvalidResponseMode,
50
51 #[error("invalid parameters")]
52 IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
53
54 #[error("invalid redirect uri")]
55 UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
56}
57
58impl IntoResponse for RouteError {
59 fn into_response(self) -> axum::response::Response {
60 let event_id = sentry::capture_error(&self);
61 let response = match self {
63 RouteError::Internal(e) => {
64 (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
65 }
66 RouteError::ClientNotFound => {
67 (StatusCode::BAD_REQUEST, "could not find client").into_response()
68 }
69 RouteError::InvalidResponseMode => {
70 (StatusCode::BAD_REQUEST, "invalid response mode").into_response()
71 }
72 RouteError::IntoCallbackDestination(e) => {
73 (StatusCode::BAD_REQUEST, e.to_string()).into_response()
74 }
75 RouteError::UnknownRedirectUri(e) => (
76 StatusCode::BAD_REQUEST,
77 format!("Invalid redirect URI ({e})"),
78 )
79 .into_response(),
80 };
81
82 (SentryEventID::from(event_id), response).into_response()
83 }
84}
85
86impl_from_error_for_route!(mas_storage::RepositoryError);
87impl_from_error_for_route!(mas_templates::TemplateError);
88impl_from_error_for_route!(self::callback::CallbackDestinationError);
89impl_from_error_for_route!(mas_policy::LoadError);
90impl_from_error_for_route!(mas_policy::EvaluationError);
91
92#[derive(Deserialize)]
93pub(crate) struct Params {
94 #[serde(flatten)]
95 auth: AuthorizationRequest,
96
97 #[serde(flatten)]
98 pkce: Option<pkce::AuthorizationRequest>,
99}
100
101fn resolve_response_mode(
105 response_type: &ResponseType,
106 suggested_response_mode: Option<ResponseMode>,
107) -> Result<ResponseMode, RouteError> {
108 use ResponseMode as M;
109
110 if response_type.has_token() || response_type.has_id_token() {
114 match suggested_response_mode {
115 None => Ok(M::Fragment),
116 Some(M::Query) => Err(RouteError::InvalidResponseMode),
117 Some(mode) => Ok(mode),
118 }
119 } else {
120 Ok(suggested_response_mode.unwrap_or(M::Query))
122 }
123}
124
125#[tracing::instrument(
126 name = "handlers.oauth2.authorization.get",
127 fields(client.id = %params.auth.client_id),
128 skip_all,
129 err,
130)]
131#[allow(clippy::too_many_lines)]
132pub(crate) async fn get(
133 mut rng: BoxRng,
134 clock: BoxClock,
135 PreferredLanguage(locale): PreferredLanguage,
136 State(templates): State<Templates>,
137 State(key_store): State<Keystore>,
138 State(url_builder): State<UrlBuilder>,
139 policy: Policy,
140 user_agent: Option<TypedHeader<headers::UserAgent>>,
141 activity_tracker: BoundActivityTracker,
142 mut repo: BoxRepository,
143 cookie_jar: CookieJar,
144 Form(params): Form<Params>,
145) -> Result<Response, RouteError> {
146 let client = repo
148 .oauth2_client()
149 .find_by_client_id(¶ms.auth.client_id)
150 .await?
151 .ok_or(RouteError::ClientNotFound)?;
152
153 let redirect_uri = client
155 .resolve_redirect_uri(¶ms.auth.redirect_uri)?
156 .clone();
157 let response_type = params.auth.response_type;
158 let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?;
159
160 let callback_destination = CallbackDestination::try_new(
162 &response_mode,
163 redirect_uri.clone(),
164 params.auth.state.clone(),
165 )?;
166
167 let (session_info, cookie_jar) = cookie_jar.session_info();
169 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
170
171 let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
172
173 let res: Result<Response, RouteError> = ({
175 let templates = templates.clone();
176 let callback_destination = callback_destination.clone();
177 let locale = locale.clone();
178 async move {
179 let maybe_session = session_info.load_active_session(&mut repo).await?;
180 let prompt = params.auth.prompt.as_deref().unwrap_or_default();
181
182 if params.auth.request.is_some() {
185 return Ok(callback_destination
186 .go(
187 &templates,
188 &locale,
189 ClientError::from(ClientErrorCode::RequestNotSupported),
190 )
191 .await?);
192 }
193
194 if params.auth.request_uri.is_some() {
195 return Ok(callback_destination
196 .go(
197 &templates,
198 &locale,
199 ClientError::from(ClientErrorCode::RequestUriNotSupported),
200 )
201 .await?);
202 }
203
204 if response_type.has_token() {
207 return Ok(callback_destination
208 .go(
209 &templates,
210 &locale,
211 ClientError::from(ClientErrorCode::UnsupportedResponseType),
212 )
213 .await?);
214 }
215
216 if response_type.has_id_token() && !client.grant_types.contains(&GrantType::Implicit) {
219 return Ok(callback_destination
220 .go(
221 &templates,
222 &locale,
223 ClientError::from(ClientErrorCode::UnauthorizedClient),
224 )
225 .await?);
226 }
227
228 if params.auth.registration.is_some() {
229 return Ok(callback_destination
230 .go(
231 &templates,
232 &locale,
233 ClientError::from(ClientErrorCode::RegistrationNotSupported),
234 )
235 .await?);
236 }
237
238 if prompt.contains(&Prompt::None) && maybe_session.is_none() {
240 return Ok(callback_destination
241 .go(
242 &templates,
243 &locale,
244 ClientError::from(ClientErrorCode::LoginRequired),
245 )
246 .await?);
247 }
248
249 let code: Option<AuthorizationCode> = if response_type.has_code() {
250 if !client.grant_types.contains(&GrantType::AuthorizationCode) {
252 return Ok(callback_destination
253 .go(
254 &templates,
255 &locale,
256 ClientError::from(ClientErrorCode::UnauthorizedClient),
257 )
258 .await?);
259 }
260
261 let code: String = (&mut rng)
263 .sample_iter(&Alphanumeric)
264 .take(32)
265 .map(char::from)
266 .collect();
267
268 let pkce = params.pkce.map(|p| Pkce {
269 challenge: p.code_challenge,
270 challenge_method: p.code_challenge_method,
271 });
272
273 Some(AuthorizationCode { code, pkce })
274 } else {
275 if params.pkce.is_some() {
278 return Ok(callback_destination
279 .go(
280 &templates,
281 &locale,
282 ClientError::from(ClientErrorCode::InvalidRequest),
283 )
284 .await?);
285 }
286
287 None
288 };
289
290 let requires_consent = prompt.contains(&Prompt::Consent);
291
292 let grant = repo
293 .oauth2_authorization_grant()
294 .add(
295 &mut rng,
296 &clock,
297 &client,
298 redirect_uri.clone(),
299 params.auth.scope,
300 code,
301 params.auth.state.clone(),
302 params.auth.nonce,
303 params.auth.max_age,
304 response_mode,
305 response_type.has_id_token(),
306 requires_consent,
307 params.auth.login_hint,
308 )
309 .await?;
310 let continue_grant = PostAuthAction::continue_grant(grant.id);
311
312 let res = match maybe_session {
313 None if prompt.contains(&Prompt::None) => {
315 unreachable!();
317 }
318 None if prompt.contains(&Prompt::Create) => {
319 repo.save().await?;
321
322 url_builder.redirect(&mas_router::Register::and_then(continue_grant))
323 .into_response()
324 }
325 None => {
326 repo.save().await?;
328
329 url_builder.redirect(&mas_router::Login::and_then(continue_grant))
330 .into_response()
331 }
332
333 Some(session)
335 if prompt.contains(&Prompt::Login)
336 || prompt.contains(&Prompt::SelectAccount) =>
337 {
338 repo.save().await?;
340
341 activity_tracker.record_browser_session(&clock, &session).await;
342
343 url_builder.redirect(&mas_router::Reauth::and_then(continue_grant))
344 .into_response()
345 }
346
347 Some(user_session) if prompt.contains(&Prompt::None) => {
349 activity_tracker.record_browser_session(&clock, &user_session).await;
350
351 match self::complete::complete(
353 &mut rng,
354 &clock,
355 &activity_tracker,
356 user_agent,
357 repo,
358 key_store,
359 policy,
360 &url_builder,
361 grant,
362 &client,
363 &user_session,
364 )
365 .await
366 {
367 Ok(params) => callback_destination.go(&templates, &locale, params).await?,
368 Err(GrantCompletionError::RequiresConsent) => {
369 callback_destination
370 .go(
371 &templates,
372 &locale,
373 ClientError::from(ClientErrorCode::ConsentRequired),
374 )
375 .await?
376 }
377 Err(GrantCompletionError::RequiresReauth) => {
378 callback_destination
379 .go(
380 &templates,
381 &locale,
382 ClientError::from(ClientErrorCode::InteractionRequired),
383 )
384 .await?
385 }
386 Err(GrantCompletionError::PolicyViolation(_grant, _res)) => {
387 callback_destination
388 .go(&templates, &locale, ClientError::from(ClientErrorCode::AccessDenied))
389 .await?
390 }
391 Err(GrantCompletionError::Internal(e)) => {
392 return Err(RouteError::Internal(e))
393 }
394 Err(e @ GrantCompletionError::NotPending) => {
395 return Err(RouteError::Internal(Box::new(e)));
397 }
398 }
399 }
400 Some(user_session) => {
401 activity_tracker.record_browser_session(&clock, &user_session).await;
402
403 let grant_id = grant.id;
404 match self::complete::complete(
406 &mut rng,
407 &clock,
408 &activity_tracker,
409 user_agent,
410 repo,
411 key_store,
412 policy,
413 &url_builder,
414 grant,
415 &client,
416 &user_session,
417 )
418 .await
419 {
420 Ok(params) => callback_destination.go(&templates, &locale, params).await?,
421 Err(GrantCompletionError::RequiresConsent) => {
422 url_builder.redirect(&mas_router::Consent(grant_id)).into_response()
423 }
424 Err(GrantCompletionError::PolicyViolation(grant, res)) => {
425 warn!(violation = ?res, "Authorization grant for client {} denied by policy", client.id);
426
427 let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
428 .with_session(user_session)
429 .with_csrf(csrf_token.form_value())
430 .with_language(locale);
431
432 let content = templates.render_policy_violation(&ctx)?;
433 Html(content).into_response()
434 }
435 Err(GrantCompletionError::RequiresReauth) => {
436 url_builder.redirect(&mas_router::Reauth::and_then(continue_grant))
437 .into_response()
438 }
439 Err(GrantCompletionError::Internal(e)) => {
440 return Err(RouteError::Internal(e))
441 }
442 Err(e @ GrantCompletionError::NotPending) => {
443 return Err(RouteError::Internal(Box::new(e)));
445 }
446 }
447 }
448 };
449
450 Ok(res)
451 }
452 })
453 .await;
454
455 let response = match res {
456 Ok(r) => r,
457 Err(err) => {
458 tracing::error!(%err);
459 callback_destination
460 .go(
461 &templates,
462 &locale,
463 ClientError::from(ClientErrorCode::ServerError),
464 )
465 .await?
466 }
467 };
468
469 Ok((cookie_jar, response).into_response())
470}