mas_handlers/oauth2/authorization/
mod.rs1use axum::{
8 extract::{Form, State},
9 response::{IntoResponse, Response},
10};
11use hyper::StatusCode;
12use mas_axum_utils::{GenericError, InternalError, SessionInfoExt, cookies::CookieJar};
13use mas_data_model::{AuthorizationCode, BoxClock, BoxRng, Pkce};
14use mas_router::{PostAuthAction, UrlBuilder};
15use mas_storage::{
16 BoxRepository,
17 oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
18};
19use mas_templates::Templates;
20use oauth2_types::{
21 errors::{ClientError, ClientErrorCode},
22 pkce,
23 requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
24 response_type::ResponseType,
25};
26use rand::{Rng, distributions::Alphanumeric};
27use serde::Deserialize;
28use thiserror::Error;
29
30use self::callback::CallbackDestination;
31use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
32
33mod callback;
34pub(crate) mod consent;
35
36#[derive(Debug, Error)]
37pub enum RouteError {
38 #[error(transparent)]
39 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
40
41 #[error("could not find client")]
42 ClientNotFound,
43
44 #[error("invalid response mode")]
45 InvalidResponseMode,
46
47 #[error("invalid parameters")]
48 IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
49
50 #[error("invalid redirect uri")]
51 UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
52}
53
54impl IntoResponse for RouteError {
55 fn into_response(self) -> axum::response::Response {
56 match self {
57 Self::Internal(e) => InternalError::new(e).into_response(),
58 e @ (Self::ClientNotFound
59 | Self::InvalidResponseMode
60 | Self::IntoCallbackDestination(_)
61 | Self::UnknownRedirectUri(_)) => {
62 GenericError::new(StatusCode::BAD_REQUEST, e).into_response()
63 }
64 }
65 }
66}
67
68impl_from_error_for_route!(mas_storage::RepositoryError);
69impl_from_error_for_route!(mas_templates::TemplateError);
70impl_from_error_for_route!(self::callback::CallbackDestinationError);
71impl_from_error_for_route!(mas_policy::LoadError);
72impl_from_error_for_route!(mas_policy::EvaluationError);
73
74#[derive(Deserialize)]
75pub(crate) struct Params {
76 #[serde(flatten)]
77 auth: AuthorizationRequest,
78
79 #[serde(flatten)]
80 pkce: Option<pkce::AuthorizationRequest>,
81}
82
83fn resolve_response_mode(
87 response_type: &ResponseType,
88 suggested_response_mode: Option<ResponseMode>,
89) -> Result<ResponseMode, RouteError> {
90 use ResponseMode as M;
91
92 if response_type.has_token() || response_type.has_id_token() {
96 match suggested_response_mode {
97 None => Ok(M::Fragment),
98 Some(M::Query) => Err(RouteError::InvalidResponseMode),
99 Some(mode) => Ok(mode),
100 }
101 } else {
102 Ok(suggested_response_mode.unwrap_or(M::Query))
104 }
105}
106
107#[tracing::instrument(
108 name = "handlers.oauth2.authorization.get",
109 fields(client.id = %params.auth.client_id),
110 skip_all,
111)]
112pub(crate) async fn get(
113 mut rng: BoxRng,
114 clock: BoxClock,
115 PreferredLanguage(locale): PreferredLanguage,
116 State(templates): State<Templates>,
117 State(url_builder): State<UrlBuilder>,
118 activity_tracker: BoundActivityTracker,
119 mut repo: BoxRepository,
120 cookie_jar: CookieJar,
121 Form(params): Form<Params>,
122) -> Result<Response, RouteError> {
123 let client = repo
125 .oauth2_client()
126 .find_by_client_id(¶ms.auth.client_id)
127 .await?
128 .ok_or(RouteError::ClientNotFound)?;
129
130 let redirect_uri = client
132 .resolve_redirect_uri(¶ms.auth.redirect_uri)?
133 .clone();
134 let response_type = params.auth.response_type;
135 let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?;
136
137 let callback_destination = CallbackDestination::try_new(
139 &response_mode,
140 redirect_uri.clone(),
141 params.auth.state.clone(),
142 )?;
143
144 let (session_info, cookie_jar) = cookie_jar.session_info();
146
147 let res: Result<Response, RouteError> = ({
149 let templates = templates.clone();
150 let callback_destination = callback_destination.clone();
151 let locale = locale.clone();
152 async move {
153 let maybe_session = session_info.load_active_session(&mut repo).await?;
154 let prompt = params.auth.prompt.as_deref().unwrap_or_default();
155
156 if params.auth.request.is_some() {
159 return Ok(callback_destination.go(
160 &templates,
161 &locale,
162 ClientError::from(ClientErrorCode::RequestNotSupported),
163 )?);
164 }
165
166 if params.auth.request_uri.is_some() {
167 return Ok(callback_destination.go(
168 &templates,
169 &locale,
170 ClientError::from(ClientErrorCode::RequestUriNotSupported),
171 )?);
172 }
173
174 if response_type.has_token() {
177 return Ok(callback_destination.go(
178 &templates,
179 &locale,
180 ClientError::from(ClientErrorCode::UnsupportedResponseType),
181 )?);
182 }
183
184 if response_type.has_id_token() && !client.grant_types.contains(&GrantType::Implicit) {
187 return Ok(callback_destination.go(
188 &templates,
189 &locale,
190 ClientError::from(ClientErrorCode::UnauthorizedClient),
191 )?);
192 }
193
194 if params.auth.registration.is_some() {
195 return Ok(callback_destination.go(
196 &templates,
197 &locale,
198 ClientError::from(ClientErrorCode::RegistrationNotSupported),
199 )?);
200 }
201
202 if prompt.contains(&Prompt::None) {
204 return Ok(callback_destination.go(
205 &templates,
206 &locale,
207 ClientError::from(ClientErrorCode::LoginRequired),
208 )?);
209 }
210
211 let code: Option<AuthorizationCode> = if response_type.has_code() {
212 if !client.grant_types.contains(&GrantType::AuthorizationCode) {
214 return Ok(callback_destination.go(
215 &templates,
216 &locale,
217 ClientError::from(ClientErrorCode::UnauthorizedClient),
218 )?);
219 }
220
221 let code: String = (&mut rng)
223 .sample_iter(&Alphanumeric)
224 .take(32)
225 .map(char::from)
226 .collect();
227
228 let pkce = params.pkce.map(|p| Pkce {
229 challenge: p.code_challenge,
230 challenge_method: p.code_challenge_method,
231 });
232
233 Some(AuthorizationCode { code, pkce })
234 } else {
235 if params.pkce.is_some() {
238 return Ok(callback_destination.go(
239 &templates,
240 &locale,
241 ClientError::from(ClientErrorCode::InvalidRequest),
242 )?);
243 }
244
245 None
246 };
247
248 let grant = repo
249 .oauth2_authorization_grant()
250 .add(
251 &mut rng,
252 &clock,
253 &client,
254 redirect_uri.clone(),
255 params.auth.scope,
256 code,
257 params.auth.state.clone(),
258 params.auth.nonce,
259 response_mode,
260 response_type.has_id_token(),
261 params.auth.login_hint,
262 Some(locale.to_string()),
263 )
264 .await?;
265 let continue_grant = PostAuthAction::continue_grant(grant.id);
266
267 let res = match maybe_session {
268 None if prompt.contains(&Prompt::Create) => {
269 repo.save().await?;
271
272 url_builder
273 .redirect(&mas_router::Register::and_then(continue_grant))
274 .into_response()
275 }
276
277 None => {
278 repo.save().await?;
280
281 url_builder
282 .redirect(&mas_router::Login::and_then(continue_grant))
283 .into_response()
284 }
285
286 Some(user_session) => {
287 repo.save().await?;
289
290 activity_tracker
291 .record_browser_session(&clock, &user_session)
292 .await;
293 url_builder
294 .redirect(&mas_router::Consent(grant.id))
295 .into_response()
296 }
297 };
298
299 Ok(res)
300 }
301 })
302 .await;
303
304 let response = match res {
305 Ok(r) => r,
306 Err(err) => {
307 tracing::error!(message = &err as &dyn std::error::Error);
308 callback_destination.go(
309 &templates,
310 &locale,
311 ClientError::from(ClientErrorCode::ServerError),
312 )?
313 }
314 };
315
316 Ok((cookie_jar, response).into_response())
317}