1use std::{sync::Arc, time::Duration};
8
9use axum::{
10 extract::{Form, Path, State},
11 response::{Html, IntoResponse, Response},
12};
13use axum_extra::TypedHeader;
14use hyper::StatusCode;
15use mas_axum_utils::{
16 GenericError, InternalError,
17 cookies::CookieJar,
18 csrf::{CsrfExt, ProtectedForm},
19};
20use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser};
21use mas_keystore::Keystore;
22use mas_matrix::HomeserverConnection;
23use mas_policy::Policy;
24use mas_router::{PostAuthAction, UrlBuilder};
25use mas_storage::{
26 BoxRepository,
27 oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
28};
29use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
30use oauth2_types::requests::AuthorizationResponse;
31use thiserror::Error;
32use ulid::Ulid;
33
34use super::callback::CallbackDestination;
35use crate::{
36 BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
37 oauth2::generate_id_token,
38 session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
39};
40
41#[derive(Debug, Error)]
42pub enum RouteError {
43 #[error(transparent)]
44 Internal(Box<dyn std::error::Error + Send + Sync>),
45
46 #[error(transparent)]
47 Csrf(#[from] mas_axum_utils::csrf::CsrfError),
48
49 #[error("Authorization grant not found")]
50 GrantNotFound,
51
52 #[error("Authorization grant {0} already used")]
53 GrantNotPending(Ulid),
54
55 #[error("Failed to load client {0}")]
56 NoSuchClient(Ulid),
57}
58
59impl_from_error_for_route!(mas_templates::TemplateError);
60impl_from_error_for_route!(mas_storage::RepositoryError);
61impl_from_error_for_route!(mas_policy::LoadError);
62impl_from_error_for_route!(mas_policy::EvaluationError);
63impl_from_error_for_route!(crate::session::SessionLoadError);
64impl_from_error_for_route!(crate::oauth2::IdTokenSignatureError);
65impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
66impl_from_error_for_route!(super::callback::CallbackDestinationError);
67
68impl IntoResponse for RouteError {
69 fn into_response(self) -> axum::response::Response {
70 match self {
71 Self::Internal(e) => InternalError::new(e).into_response(),
72 e @ Self::NoSuchClient(_) => InternalError::new(Box::new(e)).into_response(),
73 e @ Self::GrantNotFound => GenericError::new(StatusCode::NOT_FOUND, e).into_response(),
74 e @ Self::GrantNotPending(_) => {
75 GenericError::new(StatusCode::CONFLICT, e).into_response()
76 }
77 e @ Self::Csrf(_) => GenericError::new(StatusCode::BAD_REQUEST, e).into_response(),
78 }
79 }
80}
81
82#[tracing::instrument(
83 name = "handlers.oauth2.authorization.consent.get",
84 fields(grant.id = %grant_id),
85 skip_all,
86)]
87pub(crate) async fn get(
88 mut rng: BoxRng,
89 clock: BoxClock,
90 PreferredLanguage(locale): PreferredLanguage,
91 State(templates): State<Templates>,
92 State(url_builder): State<UrlBuilder>,
93 State(homeserver): State<Arc<dyn HomeserverConnection>>,
94 mut policy: Policy,
95 mut repo: BoxRepository,
96 activity_tracker: BoundActivityTracker,
97 user_agent: Option<TypedHeader<headers::UserAgent>>,
98 cookie_jar: CookieJar,
99 Path(grant_id): Path<Ulid>,
100) -> Result<Response, RouteError> {
101 let (cookie_jar, maybe_session) = match load_session_or_fallback(
102 cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
103 )
104 .await?
105 {
106 SessionOrFallback::MaybeSession {
107 cookie_jar,
108 maybe_session,
109 ..
110 } => (cookie_jar, maybe_session),
111 SessionOrFallback::Fallback { response } => return Ok(response),
112 };
113
114 let user_agent = user_agent.map(|ua| ua.to_string());
115
116 let grant = repo
117 .oauth2_authorization_grant()
118 .lookup(grant_id)
119 .await?
120 .ok_or(RouteError::GrantNotFound)?;
121
122 let client = repo
123 .oauth2_client()
124 .lookup(grant.client_id)
125 .await?
126 .ok_or(RouteError::NoSuchClient(grant.client_id))?;
127
128 if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
129 return Err(RouteError::GrantNotPending(grant.id));
130 }
131
132 let Some(session) = maybe_session else {
133 let login = mas_router::Login::and_continue_grant(grant_id);
134 return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
135 };
136
137 activity_tracker
138 .record_browser_session(&clock, &session)
139 .await;
140
141 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
142
143 let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
144
145 repo.save().await?;
147
148 let res = policy
149 .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
150 user: Some(&session.user),
151 client: &client,
152 session_counts: Some(session_counts),
153 scope: &grant.scope,
154 grant_type: mas_policy::GrantType::AuthorizationCode,
155 requester: mas_policy::Requester {
156 ip_address: activity_tracker.ip(),
157 user_agent,
158 },
159 })
160 .await?;
161 if !res.valid() {
162 let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
163 .with_session(session)
164 .with_csrf(csrf_token.form_value())
165 .with_language(locale);
166
167 let content = templates.render_policy_violation(&ctx)?;
168
169 return Ok((cookie_jar, Html(content)).into_response());
170 }
171
172 let localpart = &session.user.username;
176 let display_name = match tokio::time::timeout(
177 Duration::from_secs(1),
178 homeserver.query_user(localpart),
179 )
180 .await
181 {
182 Ok(Ok(user)) => user.displayname,
183 Ok(Err(err)) => {
184 tracing::warn!(
185 error = &*err as &dyn std::error::Error,
186 localpart,
187 "Failed to query user"
188 );
189 None
190 }
191 Err(_) => {
192 tracing::warn!(localpart, "Timed out while querying user");
193 None
194 }
195 };
196
197 let matrix_user = MatrixUser {
198 mxid: homeserver.mxid(localpart),
199 display_name,
200 };
201
202 let ctx = ConsentContext::new(grant, client, matrix_user)
203 .with_session(session)
204 .with_csrf(csrf_token.form_value())
205 .with_language(locale);
206
207 let content = templates.render_consent(&ctx)?;
208
209 Ok((cookie_jar, Html(content)).into_response())
210}
211
212#[tracing::instrument(
213 name = "handlers.oauth2.authorization.consent.post",
214 fields(grant.id = %grant_id),
215 skip_all,
216)]
217pub(crate) async fn post(
218 mut rng: BoxRng,
219 clock: BoxClock,
220 PreferredLanguage(locale): PreferredLanguage,
221 State(templates): State<Templates>,
222 State(key_store): State<Keystore>,
223 mut policy: Policy,
224 mut repo: BoxRepository,
225 activity_tracker: BoundActivityTracker,
226 user_agent: Option<TypedHeader<headers::UserAgent>>,
227 cookie_jar: CookieJar,
228 State(url_builder): State<UrlBuilder>,
229 Path(grant_id): Path<Ulid>,
230 Form(form): Form<ProtectedForm<()>>,
231) -> Result<Response, RouteError> {
232 cookie_jar.verify_form(&clock, form)?;
233
234 let (cookie_jar, maybe_session) = match load_session_or_fallback(
235 cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
236 )
237 .await?
238 {
239 SessionOrFallback::MaybeSession {
240 cookie_jar,
241 maybe_session,
242 ..
243 } => (cookie_jar, maybe_session),
244 SessionOrFallback::Fallback { response } => return Ok(response),
245 };
246
247 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
248
249 let user_agent = user_agent.map(|ua| ua.to_string());
250
251 let grant = repo
252 .oauth2_authorization_grant()
253 .lookup(grant_id)
254 .await?
255 .ok_or(RouteError::GrantNotFound)?;
256 let callback_destination = CallbackDestination::try_from(&grant)?;
257
258 let Some(browser_session) = maybe_session else {
259 let next = PostAuthAction::continue_grant(grant_id);
260 let login = mas_router::Login::and_then(next);
261 return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
262 };
263
264 activity_tracker
265 .record_browser_session(&clock, &browser_session)
266 .await;
267
268 let client = repo
269 .oauth2_client()
270 .lookup(grant.client_id)
271 .await?
272 .ok_or(RouteError::NoSuchClient(grant.client_id))?;
273
274 if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
275 return Err(RouteError::GrantNotPending(grant.id));
276 }
277
278 let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?;
279
280 let res = policy
281 .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
282 user: Some(&browser_session.user),
283 client: &client,
284 session_counts: Some(session_counts),
285 scope: &grant.scope,
286 grant_type: mas_policy::GrantType::AuthorizationCode,
287 requester: mas_policy::Requester {
288 ip_address: activity_tracker.ip(),
289 user_agent,
290 },
291 })
292 .await?;
293
294 if !res.valid() {
295 let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
296 .with_session(browser_session)
297 .with_csrf(csrf_token.form_value())
298 .with_language(locale);
299
300 let content = templates.render_policy_violation(&ctx)?;
301
302 return Ok((cookie_jar, Html(content)).into_response());
303 }
304
305 let session = repo
307 .oauth2_session()
308 .add_from_browser_session(
309 &mut rng,
310 &clock,
311 &client,
312 &browser_session,
313 grant.scope.clone(),
314 )
315 .await?;
316
317 let grant = repo
318 .oauth2_authorization_grant()
319 .fulfill(&clock, &session, grant)
320 .await?;
321
322 let mut params = AuthorizationResponse::default();
323
324 if grant.response_type_id_token {
326 let last_authentication = repo
328 .browser_session()
329 .get_last_authentication(&browser_session)
330 .await?;
331
332 params.id_token = Some(generate_id_token(
333 &mut rng,
334 &clock,
335 &url_builder,
336 &key_store,
337 &client,
338 Some(&grant),
339 &browser_session,
340 None,
341 last_authentication.as_ref(),
342 )?);
343 }
344
345 if let Some(code) = grant.code {
347 params.code = Some(code.code);
348 }
349
350 repo.save().await?;
351
352 activity_tracker
353 .record_oauth2_session(&clock, &session)
354 .await;
355
356 Ok((
357 cookie_jar,
358 callback_destination.go(&templates, &locale, params)?,
359 )
360 .into_response())
361}