// mas_handlers/oauth2/authorization/consent.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use axum::{
8    extract::{Form, Path, State},
9    response::{Html, IntoResponse, Response},
10};
11use axum_extra::TypedHeader;
12use hyper::StatusCode;
13use mas_axum_utils::{
14    cookies::CookieJar,
15    csrf::{CsrfExt, ProtectedForm},
16    sentry::SentryEventID,
17};
18use mas_data_model::AuthorizationGrantStage;
19use mas_keystore::Keystore;
20use mas_policy::Policy;
21use mas_router::{PostAuthAction, UrlBuilder};
22use mas_storage::{
23    BoxClock, BoxRepository, BoxRng,
24    oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
25};
26use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
27use oauth2_types::requests::AuthorizationResponse;
28use thiserror::Error;
29use ulid::Ulid;
30
31use super::callback::CallbackDestination;
32use crate::{
33    BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
34    oauth2::generate_id_token,
35    session::{SessionOrFallback, load_session_or_fallback},
36};
37
38#[derive(Debug, Error)]
39pub enum RouteError {
40    #[error(transparent)]
41    Internal(Box<dyn std::error::Error + Send + Sync>),
42
43    #[error(transparent)]
44    Csrf(#[from] mas_axum_utils::csrf::CsrfError),
45
46    #[error("Authorization grant not found")]
47    GrantNotFound,
48
49    #[error("Authorization grant already used")]
50    GrantNotPending,
51
52    #[error("Failed to load client")]
53    NoSuchClient,
54}
55
56impl_from_error_for_route!(mas_templates::TemplateError);
57impl_from_error_for_route!(mas_storage::RepositoryError);
58impl_from_error_for_route!(mas_policy::LoadError);
59impl_from_error_for_route!(mas_policy::EvaluationError);
60impl_from_error_for_route!(crate::session::SessionLoadError);
61impl_from_error_for_route!(crate::oauth2::IdTokenSignatureError);
62impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
63impl_from_error_for_route!(super::callback::CallbackDestinationError);
64
65impl IntoResponse for RouteError {
66    fn into_response(self) -> axum::response::Response {
67        let event_id = sentry::capture_error(&self);
68        (
69            StatusCode::INTERNAL_SERVER_ERROR,
70            SentryEventID::from(event_id),
71            self.to_string(),
72        )
73            .into_response()
74    }
75}
76
77#[tracing::instrument(
78    name = "handlers.oauth2.authorization.consent.get",
79    fields(grant.id = %grant_id),
80    skip_all,
81    err,
82)]
83pub(crate) async fn get(
84    mut rng: BoxRng,
85    clock: BoxClock,
86    PreferredLanguage(locale): PreferredLanguage,
87    State(templates): State<Templates>,
88    State(url_builder): State<UrlBuilder>,
89    mut policy: Policy,
90    mut repo: BoxRepository,
91    activity_tracker: BoundActivityTracker,
92    user_agent: Option<TypedHeader<headers::UserAgent>>,
93    cookie_jar: CookieJar,
94    Path(grant_id): Path<Ulid>,
95) -> Result<Response, RouteError> {
96    let (cookie_jar, maybe_session) = match load_session_or_fallback(
97        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
98    )
99    .await?
100    {
101        SessionOrFallback::MaybeSession {
102            cookie_jar,
103            maybe_session,
104            ..
105        } => (cookie_jar, maybe_session),
106        SessionOrFallback::Fallback { response } => return Ok(response),
107    };
108
109    let user_agent = user_agent.map(|ua| ua.to_string());
110
111    let grant = repo
112        .oauth2_authorization_grant()
113        .lookup(grant_id)
114        .await?
115        .ok_or(RouteError::GrantNotFound)?;
116
117    let client = repo
118        .oauth2_client()
119        .lookup(grant.client_id)
120        .await?
121        .ok_or(RouteError::NoSuchClient)?;
122
123    if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
124        return Err(RouteError::GrantNotPending);
125    }
126
127    let Some(session) = maybe_session else {
128        let login = mas_router::Login::and_continue_grant(grant_id);
129        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
130    };
131
132    activity_tracker
133        .record_browser_session(&clock, &session)
134        .await;
135
136    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
137
138    let res = policy
139        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
140            user: Some(&session.user),
141            client: &client,
142            scope: &grant.scope,
143            grant_type: mas_policy::GrantType::AuthorizationCode,
144            requester: mas_policy::Requester {
145                ip_address: activity_tracker.ip(),
146                user_agent,
147            },
148        })
149        .await?;
150    if !res.valid() {
151        let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
152            .with_session(session)
153            .with_csrf(csrf_token.form_value())
154            .with_language(locale);
155
156        let content = templates.render_policy_violation(&ctx)?;
157
158        return Ok((cookie_jar, Html(content)).into_response());
159    }
160
161    let ctx = ConsentContext::new(grant, client)
162        .with_session(session)
163        .with_csrf(csrf_token.form_value())
164        .with_language(locale);
165
166    let content = templates.render_consent(&ctx)?;
167
168    Ok((cookie_jar, Html(content)).into_response())
169}
170
171#[tracing::instrument(
172    name = "handlers.oauth2.authorization.consent.post",
173    fields(grant.id = %grant_id),
174    skip_all,
175    err,
176)]
177pub(crate) async fn post(
178    mut rng: BoxRng,
179    clock: BoxClock,
180    PreferredLanguage(locale): PreferredLanguage,
181    State(templates): State<Templates>,
182    State(key_store): State<Keystore>,
183    mut policy: Policy,
184    mut repo: BoxRepository,
185    activity_tracker: BoundActivityTracker,
186    user_agent: Option<TypedHeader<headers::UserAgent>>,
187    cookie_jar: CookieJar,
188    State(url_builder): State<UrlBuilder>,
189    Path(grant_id): Path<Ulid>,
190    Form(form): Form<ProtectedForm<()>>,
191) -> Result<Response, RouteError> {
192    cookie_jar.verify_form(&clock, form)?;
193
194    let (cookie_jar, maybe_session) = match load_session_or_fallback(
195        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
196    )
197    .await?
198    {
199        SessionOrFallback::MaybeSession {
200            cookie_jar,
201            maybe_session,
202            ..
203        } => (cookie_jar, maybe_session),
204        SessionOrFallback::Fallback { response } => return Ok(response),
205    };
206
207    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
208
209    let user_agent = user_agent.map(|ua| ua.to_string());
210
211    let grant = repo
212        .oauth2_authorization_grant()
213        .lookup(grant_id)
214        .await?
215        .ok_or(RouteError::GrantNotFound)?;
216    let callback_destination = CallbackDestination::try_from(&grant)?;
217
218    let Some(browser_session) = maybe_session else {
219        let next = PostAuthAction::continue_grant(grant_id);
220        let login = mas_router::Login::and_then(next);
221        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
222    };
223
224    activity_tracker
225        .record_browser_session(&clock, &browser_session)
226        .await;
227
228    let client = repo
229        .oauth2_client()
230        .lookup(grant.client_id)
231        .await?
232        .ok_or(RouteError::NoSuchClient)?;
233
234    let res = policy
235        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
236            user: Some(&browser_session.user),
237            client: &client,
238            scope: &grant.scope,
239            grant_type: mas_policy::GrantType::AuthorizationCode,
240            requester: mas_policy::Requester {
241                ip_address: activity_tracker.ip(),
242                user_agent,
243            },
244        })
245        .await?;
246
247    if !res.valid() {
248        let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
249            .with_session(browser_session)
250            .with_csrf(csrf_token.form_value())
251            .with_language(locale);
252
253        let content = templates.render_policy_violation(&ctx)?;
254
255        return Ok((cookie_jar, Html(content)).into_response());
256    }
257
258    // All good, let's start the session
259    let session = repo
260        .oauth2_session()
261        .add_from_browser_session(
262            &mut rng,
263            &clock,
264            &client,
265            &browser_session,
266            grant.scope.clone(),
267        )
268        .await?;
269
270    let grant = repo
271        .oauth2_authorization_grant()
272        .fulfill(&clock, &session, grant)
273        .await?;
274
275    let mut params = AuthorizationResponse::default();
276
277    // Did they request an ID token?
278    if grant.response_type_id_token {
279        // Fetch the last authentication
280        let last_authentication = repo
281            .browser_session()
282            .get_last_authentication(&browser_session)
283            .await?;
284
285        params.id_token = Some(generate_id_token(
286            &mut rng,
287            &clock,
288            &url_builder,
289            &key_store,
290            &client,
291            Some(&grant),
292            &browser_session,
293            None,
294            last_authentication.as_ref(),
295        )?);
296    }
297
298    // Did they request an auth code?
299    if let Some(code) = grant.code {
300        params.code = Some(code.code);
301    }
302
303    repo.save().await?;
304
305    activity_tracker
306        .record_oauth2_session(&clock, &session)
307        .await;
308
309    Ok((
310        cookie_jar,
311        callback_destination.go(&templates, &locale, params)?,
312    )
313        .into_response())
314}