1use std::sync::LazyLock;
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 http::Method,
13 response::{Html, IntoResponse, Response},
14};
15use hyper::StatusCode;
16use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
17use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
18use mas_jose::claims::TokenHash;
19use mas_keystore::{Encrypter, Keystore};
20use mas_oidc_client::requests::jose::JwtVerificationData;
21use mas_router::UrlBuilder;
22use mas_storage::{
23 BoxClock, BoxRepository, BoxRng, Clock,
24 upstream_oauth2::{
25 UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
26 UpstreamOAuthSessionRepository,
27 },
28};
29use mas_templates::{FormPostContext, Templates};
30use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
31use opentelemetry::{Key, KeyValue, metrics::Counter};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34use thiserror::Error;
35use ulid::Ulid;
36
37use super::{
38 UpstreamSessionsCookie,
39 cache::LazyProviderInfos,
40 client_credentials_for_provider,
41 template::{AttributeMappingContext, environment},
42};
43use crate::{
44 METER, PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
45};
46
47static CALLBACK_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
48 METER
49 .u64_counter("mas.upstream_oauth2.callback")
50 .with_description("Number of requests to the upstream OAuth2 callback endpoint")
51 .build()
52});
53const PROVIDER: Key = Key::from_static_str("provider");
54const RESULT: Key = Key::from_static_str("result");
55
/// Parameters sent back by the upstream provider to the callback endpoint.
///
/// They are decoded by the `Form` extractor in `handler`, from the query string
/// for `GET` or the body for `POST`. They are serialized again when the handler
/// re-posts the form to itself, so field order and serde attributes shape that
/// re-posted form.
#[derive(Serialize, Deserialize)]
pub struct Params {
    /// The `state` value, matched against the upstream session cookie and the
    /// stored upstream session.
    #[serde(skip_serializing_if = "Option::is_none")]
    state: Option<String>,

    /// Set to `true` by MAS on the form it re-posts to itself, so that the
    /// re-post happens at most once. Defaults to `false` when absent.
    #[serde(default)]
    did_mas_repost_to_itself: bool,

    /// Authorization code to exchange at the provider's token endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    code: Option<String>,

    /// Error code reported by the provider instead of a code.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<ClientErrorCode>,
    /// The provider's `error_description` parameter, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_description: Option<String>,
    /// The provider's `error_uri` parameter, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_uri: Option<String>,

    /// Every parameter not matched by a field above, collected through
    /// `#[serde(flatten)]`. It is passed to the attribute-mapping context and
    /// stored on the completed upstream session.
    #[serde(flatten)]
    extra_callback_parameters: Option<serde_json::Value>,
}
79
80impl Params {
81 pub fn is_empty(&self) -> bool {
83 self.state.is_none()
84 && self.code.is_none()
85 && self.error.is_none()
86 && self.error_description.is_none()
87 && self.error_uri.is_none()
88 }
89}
90
91#[derive(Debug, Error)]
92pub(crate) enum RouteError {
93 #[error("Session not found")]
94 SessionNotFound,
95
96 #[error("Provider not found")]
97 ProviderNotFound,
98
99 #[error("Provider mismatch")]
100 ProviderMismatch,
101
102 #[error("Session already completed")]
103 AlreadyCompleted,
104
105 #[error("State parameter mismatch")]
106 StateMismatch,
107
108 #[error("Missing state parameter")]
109 MissingState,
110
111 #[error("Missing code parameter")]
112 MissingCode,
113
114 #[error("Could not extract subject from ID token")]
115 ExtractSubject(#[source] minijinja::Error),
116
117 #[error("Subject is empty")]
118 EmptySubject,
119
120 #[error("Error from the provider: {error}")]
121 ClientError {
122 error: ClientErrorCode,
123 error_description: Option<String>,
124 },
125
126 #[error("Missing session cookie")]
127 MissingCookie,
128
129 #[error("Missing query parameters")]
130 MissingQueryParams,
131
132 #[error("Missing form parameters")]
133 MissingFormParams,
134
135 #[error("Invalid response mode, expected '{expected}'")]
136 InvalidResponseMode {
137 expected: UpstreamOAuthProviderResponseMode,
138 },
139
140 #[error(transparent)]
141 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
142}
143
144impl_from_error_for_route!(mas_templates::TemplateError);
145impl_from_error_for_route!(mas_storage::RepositoryError);
146impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
147impl_from_error_for_route!(mas_oidc_client::error::JwksError);
148impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
149impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
150impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
151impl_from_error_for_route!(super::ProviderCredentialsError);
152impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
153
154impl IntoResponse for RouteError {
155 fn into_response(self) -> axum::response::Response {
156 let event_id = sentry::capture_error(&self);
157 let response = match self {
158 Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
159 Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
160 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
161 e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
162 };
163
164 (SentryEventID::from(event_id), response).into_response()
165 }
166}
167
168#[tracing::instrument(
169 name = "handlers.upstream_oauth2.callback.handler",
170 fields(upstream_oauth_provider.id = %provider_id),
171 skip_all,
172 err,
173)]
174#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
175pub(crate) async fn handler(
176 mut rng: BoxRng,
177 clock: BoxClock,
178 State(metadata_cache): State<MetadataCache>,
179 mut repo: BoxRepository,
180 State(url_builder): State<UrlBuilder>,
181 State(encrypter): State<Encrypter>,
182 State(keystore): State<Keystore>,
183 State(client): State<reqwest::Client>,
184 State(templates): State<Templates>,
185 method: Method,
186 PreferredLanguage(locale): PreferredLanguage,
187 cookie_jar: CookieJar,
188 Path(provider_id): Path<Ulid>,
189 Form(params): Form<Params>,
190) -> Result<Response, RouteError> {
191 let provider = repo
192 .upstream_oauth_provider()
193 .lookup(provider_id)
194 .await?
195 .filter(UpstreamOAuthProvider::enabled)
196 .ok_or(RouteError::ProviderNotFound)?;
197
198 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
199
200 if params.is_empty() {
201 if let Method::GET = method {
202 return Err(RouteError::MissingQueryParams);
203 }
204
205 return Err(RouteError::MissingFormParams);
206 }
207
208 match (provider.response_mode, method) {
212 (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
213 if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
219 let params = Params {
220 did_mas_repost_to_itself: true,
221 ..params
222 };
223 let context = FormPostContext::new_for_current_url(params).with_language(&locale);
224 let html = templates.render_form_post(&context)?;
225 return Ok(Html(html).into_response());
226 }
227 }
228 (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
229 (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
230 }
231
232 if let Some(error) = params.error {
233 CALLBACK_COUNTER.add(
234 1,
235 &[
236 KeyValue::new(PROVIDER, provider_id.to_string()),
237 KeyValue::new(RESULT, "error"),
238 ],
239 );
240
241 return Err(RouteError::ClientError {
242 error,
243 error_description: params.error_description.clone(),
244 });
245 }
246
247 let Some(state) = params.state else {
248 return Err(RouteError::MissingState);
249 };
250
251 let (session_id, _post_auth_action) = sessions_cookie
252 .find_session(provider_id, &state)
253 .map_err(|_| RouteError::MissingCookie)?;
254
255 let session = repo
256 .upstream_oauth_session()
257 .lookup(session_id)
258 .await?
259 .ok_or(RouteError::SessionNotFound)?;
260
261 if provider.id != session.provider_id {
262 return Err(RouteError::ProviderMismatch);
264 }
265
266 if state != session.state_str {
267 return Err(RouteError::StateMismatch);
269 }
270
271 if !session.is_pending() {
272 return Err(RouteError::AlreadyCompleted);
274 }
275
276 let Some(code) = params.code else {
278 return Err(RouteError::MissingCode);
279 };
280
281 CALLBACK_COUNTER.add(
282 1,
283 &[
284 KeyValue::new(PROVIDER, provider_id.to_string()),
285 KeyValue::new(RESULT, "success"),
286 ],
287 );
288
289 let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
290
291 let client_credentials = client_credentials_for_provider(
293 &provider,
294 lazy_metadata.token_endpoint().await?,
295 &keystore,
296 &encrypter,
297 )?;
298
299 let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
300
301 let token_response = mas_oidc_client::requests::token::request_access_token(
302 &client,
303 client_credentials,
304 lazy_metadata.token_endpoint().await?,
305 AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
306 code: code.clone(),
307 redirect_uri: Some(redirect_uri),
308 code_verifier: session.code_challenge_verifier.clone(),
309 }),
310 clock.now(),
311 &mut rng,
312 )
313 .await?;
314
315 let mut jwks = None;
316
317 let mut context = AttributeMappingContext::new();
318 if let Some(id_token) = token_response.id_token.as_ref() {
319 jwks = Some(
320 mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
321 .await?,
322 );
323
324 let id_token_verification_data = JwtVerificationData {
325 issuer: provider.issuer.as_deref(),
326 jwks: jwks.as_ref().unwrap(),
327 signing_algorithm: &provider.id_token_signed_response_alg,
328 client_id: &provider.client_id,
329 };
330
331 let id_token = mas_oidc_client::requests::jose::verify_id_token(
333 id_token,
334 id_token_verification_data,
335 None,
336 clock.now(),
337 )?;
338
339 let (_headers, mut claims) = id_token.into_parts();
340
341 mas_jose::claims::AT_HASH
343 .extract_optional_with_options(
344 &mut claims,
345 TokenHash::new(
346 id_token_verification_data.signing_algorithm,
347 &token_response.access_token,
348 ),
349 )
350 .map_err(mas_oidc_client::error::IdTokenError::from)?;
351
352 mas_jose::claims::C_HASH
354 .extract_optional_with_options(
355 &mut claims,
356 TokenHash::new(id_token_verification_data.signing_algorithm, &code),
357 )
358 .map_err(mas_oidc_client::error::IdTokenError::from)?;
359
360 mas_jose::claims::NONCE
362 .extract_required_with_options(&mut claims, session.nonce.as_str())
363 .map_err(mas_oidc_client::error::IdTokenError::from)?;
364
365 context = context.with_id_token_claims(claims);
366 }
367
368 if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
369 context = context.with_extra_callback_parameters(extra_callback_parameters);
370 }
371
372 let userinfo = if provider.fetch_userinfo {
373 Some(json!(match &provider.userinfo_signed_response_alg {
374 Some(signing_algorithm) => {
375 let jwks = match jwks {
376 Some(jwks) => jwks,
377 None => {
378 mas_oidc_client::requests::jose::fetch_jwks(
379 &client,
380 lazy_metadata.jwks_uri().await?,
381 )
382 .await?
383 }
384 };
385
386 mas_oidc_client::requests::userinfo::fetch_userinfo(
387 &client,
388 lazy_metadata.userinfo_endpoint().await?,
389 token_response.access_token.as_str(),
390 Some(JwtVerificationData {
391 issuer: provider.issuer.as_deref(),
392 jwks: &jwks,
393 signing_algorithm,
394 client_id: &provider.client_id,
395 }),
396 )
397 .await?
398 }
399 None => {
400 mas_oidc_client::requests::userinfo::fetch_userinfo(
401 &client,
402 lazy_metadata.userinfo_endpoint().await?,
403 token_response.access_token.as_str(),
404 None,
405 )
406 .await?
407 }
408 }))
409 } else {
410 None
411 };
412
413 if let Some(userinfo) = userinfo.clone() {
414 context = context.with_userinfo_claims(userinfo);
415 }
416
417 let context = context.build();
418
419 let env = environment();
420
421 let template = provider
422 .claims_imports
423 .subject
424 .template
425 .as_deref()
426 .unwrap_or("{{ user.sub }}");
427 let subject = env
428 .render_str(template, context.clone())
429 .map_err(RouteError::ExtractSubject)?;
430
431 if subject.is_empty() {
432 return Err(RouteError::EmptySubject);
433 }
434
435 let maybe_link = repo
437 .upstream_oauth_link()
438 .find_by_subject(&provider, &subject)
439 .await?;
440
441 let link = if let Some(link) = maybe_link {
442 link
443 } else {
444 let human_account_name = provider
447 .claims_imports
448 .account_name
449 .template
450 .as_deref()
451 .and_then(|template| match env.render_str(template, context) {
452 Ok(name) => Some(name),
453 Err(e) => {
454 tracing::warn!(
455 error = &e as &dyn std::error::Error,
456 "Failed to render account name"
457 );
458 None
459 }
460 });
461
462 repo.upstream_oauth_link()
463 .add(&mut rng, &clock, &provider, subject, human_account_name)
464 .await?
465 };
466
467 let session = repo
468 .upstream_oauth_session()
469 .complete_with_link(
470 &clock,
471 session,
472 &link,
473 token_response.id_token,
474 params.extra_callback_parameters,
475 userinfo,
476 )
477 .await?;
478
479 let cookie_jar = sessions_cookie
480 .add_link_to_session(session.id, link.id)?
481 .save(cookie_jar, &clock);
482
483 repo.save().await?;
484
485 Ok((
486 cookie_jar,
487 url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
488 )
489 .into_response())
490}