// mas_handlers/oauth2/authorization/consent.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{sync::Arc, time::Duration};
8
9use axum::{
10    extract::{Form, Path, State},
11    response::{Html, IntoResponse, Response},
12};
13use axum_extra::TypedHeader;
14use hyper::StatusCode;
15use mas_axum_utils::{
16    GenericError, InternalError,
17    cookies::CookieJar,
18    csrf::{CsrfExt, ProtectedForm},
19};
20use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser};
21use mas_keystore::Keystore;
22use mas_matrix::HomeserverConnection;
23use mas_policy::Policy;
24use mas_router::{PostAuthAction, UrlBuilder};
25use mas_storage::{
26    BoxRepository,
27    oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
28};
29use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Templates};
30use oauth2_types::requests::AuthorizationResponse;
31use thiserror::Error;
32use ulid::Ulid;
33
34use super::callback::CallbackDestination;
35use crate::{
36    BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
37    oauth2::generate_id_token,
38    session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
39};
40
41#[derive(Debug, Error)]
42pub enum RouteError {
43    #[error(transparent)]
44    Internal(Box<dyn std::error::Error + Send + Sync>),
45
46    #[error(transparent)]
47    Csrf(#[from] mas_axum_utils::csrf::CsrfError),
48
49    #[error("Authorization grant not found")]
50    GrantNotFound,
51
52    #[error("Authorization grant {0} already used")]
53    GrantNotPending(Ulid),
54
55    #[error("Failed to load client {0}")]
56    NoSuchClient(Ulid),
57}
58
59impl_from_error_for_route!(mas_templates::TemplateError);
60impl_from_error_for_route!(mas_storage::RepositoryError);
61impl_from_error_for_route!(mas_policy::LoadError);
62impl_from_error_for_route!(mas_policy::EvaluationError);
63impl_from_error_for_route!(crate::session::SessionLoadError);
64impl_from_error_for_route!(crate::oauth2::IdTokenSignatureError);
65impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
66impl_from_error_for_route!(super::callback::CallbackDestinationError);
67
68impl IntoResponse for RouteError {
69    fn into_response(self) -> axum::response::Response {
70        match self {
71            Self::Internal(e) => InternalError::new(e).into_response(),
72            e @ Self::NoSuchClient(_) => InternalError::new(Box::new(e)).into_response(),
73            e @ Self::GrantNotFound => GenericError::new(StatusCode::NOT_FOUND, e).into_response(),
74            e @ Self::GrantNotPending(_) => {
75                GenericError::new(StatusCode::CONFLICT, e).into_response()
76            }
77            e @ Self::Csrf(_) => GenericError::new(StatusCode::BAD_REQUEST, e).into_response(),
78        }
79    }
80}
81
82#[tracing::instrument(
83    name = "handlers.oauth2.authorization.consent.get",
84    fields(grant.id = %grant_id),
85    skip_all,
86)]
87pub(crate) async fn get(
88    mut rng: BoxRng,
89    clock: BoxClock,
90    PreferredLanguage(locale): PreferredLanguage,
91    State(templates): State<Templates>,
92    State(url_builder): State<UrlBuilder>,
93    State(homeserver): State<Arc<dyn HomeserverConnection>>,
94    mut policy: Policy,
95    mut repo: BoxRepository,
96    activity_tracker: BoundActivityTracker,
97    user_agent: Option<TypedHeader<headers::UserAgent>>,
98    cookie_jar: CookieJar,
99    Path(grant_id): Path<Ulid>,
100) -> Result<Response, RouteError> {
101    let (cookie_jar, maybe_session) = match load_session_or_fallback(
102        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
103    )
104    .await?
105    {
106        SessionOrFallback::MaybeSession {
107            cookie_jar,
108            maybe_session,
109            ..
110        } => (cookie_jar, maybe_session),
111        SessionOrFallback::Fallback { response } => return Ok(response),
112    };
113
114    let user_agent = user_agent.map(|ua| ua.to_string());
115
116    let grant = repo
117        .oauth2_authorization_grant()
118        .lookup(grant_id)
119        .await?
120        .ok_or(RouteError::GrantNotFound)?;
121
122    let client = repo
123        .oauth2_client()
124        .lookup(grant.client_id)
125        .await?
126        .ok_or(RouteError::NoSuchClient(grant.client_id))?;
127
128    if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
129        return Err(RouteError::GrantNotPending(grant.id));
130    }
131
132    let Some(session) = maybe_session else {
133        let login = mas_router::Login::and_continue_grant(grant_id);
134        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
135    };
136
137    activity_tracker
138        .record_browser_session(&clock, &session)
139        .await;
140
141    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
142
143    let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
144
145    // We can close the repository early, we don't need it at this point
146    repo.save().await?;
147
148    let res = policy
149        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
150            user: Some(&session.user),
151            client: &client,
152            session_counts: Some(session_counts),
153            scope: &grant.scope,
154            grant_type: mas_policy::GrantType::AuthorizationCode,
155            requester: mas_policy::Requester {
156                ip_address: activity_tracker.ip(),
157                user_agent,
158            },
159        })
160        .await?;
161    if !res.valid() {
162        let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
163            .with_session(session)
164            .with_csrf(csrf_token.form_value())
165            .with_language(locale);
166
167        let content = templates.render_policy_violation(&ctx)?;
168
169        return Ok((cookie_jar, Html(content)).into_response());
170    }
171
172    // Fetch informations about the user. This is purely cosmetic, so we let it
173    // fail and put a 1s timeout to it in case we fail to query it
174    // XXX: we're likely to need this in other places
175    let localpart = &session.user.username;
176    let display_name = match tokio::time::timeout(
177        Duration::from_secs(1),
178        homeserver.query_user(localpart),
179    )
180    .await
181    {
182        Ok(Ok(user)) => user.displayname,
183        Ok(Err(err)) => {
184            tracing::warn!(
185                error = &*err as &dyn std::error::Error,
186                localpart,
187                "Failed to query user"
188            );
189            None
190        }
191        Err(_) => {
192            tracing::warn!(localpart, "Timed out while querying user");
193            None
194        }
195    };
196
197    let matrix_user = MatrixUser {
198        mxid: homeserver.mxid(localpart),
199        display_name,
200    };
201
202    let ctx = ConsentContext::new(grant, client, matrix_user)
203        .with_session(session)
204        .with_csrf(csrf_token.form_value())
205        .with_language(locale);
206
207    let content = templates.render_consent(&ctx)?;
208
209    Ok((cookie_jar, Html(content)).into_response())
210}
211
212#[tracing::instrument(
213    name = "handlers.oauth2.authorization.consent.post",
214    fields(grant.id = %grant_id),
215    skip_all,
216)]
217pub(crate) async fn post(
218    mut rng: BoxRng,
219    clock: BoxClock,
220    PreferredLanguage(locale): PreferredLanguage,
221    State(templates): State<Templates>,
222    State(key_store): State<Keystore>,
223    mut policy: Policy,
224    mut repo: BoxRepository,
225    activity_tracker: BoundActivityTracker,
226    user_agent: Option<TypedHeader<headers::UserAgent>>,
227    cookie_jar: CookieJar,
228    State(url_builder): State<UrlBuilder>,
229    Path(grant_id): Path<Ulid>,
230    Form(form): Form<ProtectedForm<()>>,
231) -> Result<Response, RouteError> {
232    cookie_jar.verify_form(&clock, form)?;
233
234    let (cookie_jar, maybe_session) = match load_session_or_fallback(
235        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
236    )
237    .await?
238    {
239        SessionOrFallback::MaybeSession {
240            cookie_jar,
241            maybe_session,
242            ..
243        } => (cookie_jar, maybe_session),
244        SessionOrFallback::Fallback { response } => return Ok(response),
245    };
246
247    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
248
249    let user_agent = user_agent.map(|ua| ua.to_string());
250
251    let grant = repo
252        .oauth2_authorization_grant()
253        .lookup(grant_id)
254        .await?
255        .ok_or(RouteError::GrantNotFound)?;
256    let callback_destination = CallbackDestination::try_from(&grant)?;
257
258    let Some(browser_session) = maybe_session else {
259        let next = PostAuthAction::continue_grant(grant_id);
260        let login = mas_router::Login::and_then(next);
261        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
262    };
263
264    activity_tracker
265        .record_browser_session(&clock, &browser_session)
266        .await;
267
268    let client = repo
269        .oauth2_client()
270        .lookup(grant.client_id)
271        .await?
272        .ok_or(RouteError::NoSuchClient(grant.client_id))?;
273
274    if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
275        return Err(RouteError::GrantNotPending(grant.id));
276    }
277
278    let session_counts = count_user_sessions_for_limiting(&mut repo, &browser_session.user).await?;
279
280    let res = policy
281        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
282            user: Some(&browser_session.user),
283            client: &client,
284            session_counts: Some(session_counts),
285            scope: &grant.scope,
286            grant_type: mas_policy::GrantType::AuthorizationCode,
287            requester: mas_policy::Requester {
288                ip_address: activity_tracker.ip(),
289                user_agent,
290            },
291        })
292        .await?;
293
294    if !res.valid() {
295        let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
296            .with_session(browser_session)
297            .with_csrf(csrf_token.form_value())
298            .with_language(locale);
299
300        let content = templates.render_policy_violation(&ctx)?;
301
302        return Ok((cookie_jar, Html(content)).into_response());
303    }
304
305    // All good, let's start the session
306    let session = repo
307        .oauth2_session()
308        .add_from_browser_session(
309            &mut rng,
310            &clock,
311            &client,
312            &browser_session,
313            grant.scope.clone(),
314        )
315        .await?;
316
317    let grant = repo
318        .oauth2_authorization_grant()
319        .fulfill(&clock, &session, grant)
320        .await?;
321
322    let mut params = AuthorizationResponse::default();
323
324    // Did they request an ID token?
325    if grant.response_type_id_token {
326        // Fetch the last authentication
327        let last_authentication = repo
328            .browser_session()
329            .get_last_authentication(&browser_session)
330            .await?;
331
332        params.id_token = Some(generate_id_token(
333            &mut rng,
334            &clock,
335            &url_builder,
336            &key_store,
337            &client,
338            Some(&grant),
339            &browser_session,
340            None,
341            last_authentication.as_ref(),
342        )?);
343    }
344
345    // Did they request an auth code?
346    if let Some(code) = grant.code {
347        params.code = Some(code.code);
348    }
349
350    repo.save().await?;
351
352    activity_tracker
353        .record_oauth2_session(&clock, &session)
354        .await;
355
356    Ok((
357        cookie_jar,
358        callback_destination.go(&templates, &locale, params)?,
359    )
360        .into_response())
361}