mas_handlers/oauth2/device/consent.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use anyhow::Context;
8use axum::{
9    Form,
10    extract::{Path, State},
11    response::{Html, IntoResponse, Response},
12};
13use axum_extra::TypedHeader;
14use mas_axum_utils::{
15    InternalError,
16    cookies::CookieJar,
17    csrf::{CsrfExt, ProtectedForm},
18};
19use mas_data_model::{BoxClock, BoxRng};
20use mas_policy::Policy;
21use mas_router::UrlBuilder;
22use mas_storage::BoxRepository;
23use mas_templates::{DeviceConsentContext, PolicyViolationContext, TemplateContext, Templates};
24use serde::Deserialize;
25use tracing::warn;
26use ulid::Ulid;
27
28use crate::{
29    BoundActivityTracker, PreferredLanguage,
30    session::{SessionOrFallback, load_session_or_fallback},
31};
32
33#[derive(Deserialize, Debug)]
34#[serde(rename_all = "lowercase")]
35enum Action {
36    Consent,
37    Reject,
38}
39
40#[derive(Deserialize, Debug)]
41pub(crate) struct ConsentForm {
42    action: Action,
43}
44
45#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
46pub(crate) async fn get(
47    mut rng: BoxRng,
48    clock: BoxClock,
49    PreferredLanguage(locale): PreferredLanguage,
50    State(templates): State<Templates>,
51    State(url_builder): State<UrlBuilder>,
52    mut repo: BoxRepository,
53    mut policy: Policy,
54    activity_tracker: BoundActivityTracker,
55    user_agent: Option<TypedHeader<headers::UserAgent>>,
56    cookie_jar: CookieJar,
57    Path(grant_id): Path<Ulid>,
58) -> Result<Response, InternalError> {
59    let (cookie_jar, maybe_session) = match load_session_or_fallback(
60        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
61    )
62    .await?
63    {
64        SessionOrFallback::MaybeSession {
65            cookie_jar,
66            maybe_session,
67            ..
68        } => (cookie_jar, maybe_session),
69        SessionOrFallback::Fallback { response } => return Ok(response),
70    };
71
72    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
73
74    let user_agent = user_agent.map(|ua| ua.to_string());
75
76    let Some(session) = maybe_session else {
77        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
78        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
79    };
80
81    activity_tracker
82        .record_browser_session(&clock, &session)
83        .await;
84
85    // TODO: better error handling
86    let grant = repo
87        .oauth2_device_code_grant()
88        .lookup(grant_id)
89        .await?
90        .context("Device grant not found")
91        .map_err(InternalError::from_anyhow)?;
92
93    if grant.expires_at < clock.now() {
94        return Err(InternalError::from_anyhow(anyhow::anyhow!(
95            "Grant is expired"
96        )));
97    }
98
99    let client = repo
100        .oauth2_client()
101        .lookup(grant.client_id)
102        .await?
103        .context("Client not found")
104        .map_err(InternalError::from_anyhow)?;
105
106    // Evaluate the policy
107    let res = policy
108        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
109            grant_type: mas_policy::GrantType::DeviceCode,
110            client: &client,
111            scope: &grant.scope,
112            user: Some(&session.user),
113            requester: mas_policy::Requester {
114                ip_address: activity_tracker.ip(),
115                user_agent,
116            },
117        })
118        .await?;
119    if !res.valid() {
120        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
121
122        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
123        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
124            .with_session(session)
125            .with_csrf(csrf_token.form_value())
126            .with_language(locale);
127
128        let content = templates.render_policy_violation(&ctx)?;
129
130        return Ok((cookie_jar, Html(content)).into_response());
131    }
132
133    let ctx = DeviceConsentContext::new(grant, client)
134        .with_session(session)
135        .with_csrf(csrf_token.form_value())
136        .with_language(locale);
137
138    let rendered = templates
139        .render_device_consent(&ctx)
140        .context("Failed to render template")
141        .map_err(InternalError::from_anyhow)?;
142
143    Ok((cookie_jar, Html(rendered)).into_response())
144}
145
146#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
147pub(crate) async fn post(
148    mut rng: BoxRng,
149    clock: BoxClock,
150    PreferredLanguage(locale): PreferredLanguage,
151    State(templates): State<Templates>,
152    State(url_builder): State<UrlBuilder>,
153    mut repo: BoxRepository,
154    mut policy: Policy,
155    activity_tracker: BoundActivityTracker,
156    user_agent: Option<TypedHeader<headers::UserAgent>>,
157    cookie_jar: CookieJar,
158    Path(grant_id): Path<Ulid>,
159    Form(form): Form<ProtectedForm<ConsentForm>>,
160) -> Result<Response, InternalError> {
161    let form = cookie_jar.verify_form(&clock, form)?;
162    let (cookie_jar, maybe_session) = match load_session_or_fallback(
163        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
164    )
165    .await?
166    {
167        SessionOrFallback::MaybeSession {
168            cookie_jar,
169            maybe_session,
170            ..
171        } => (cookie_jar, maybe_session),
172        SessionOrFallback::Fallback { response } => return Ok(response),
173    };
174    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
175
176    let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
177
178    let Some(session) = maybe_session else {
179        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
180        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
181    };
182
183    activity_tracker
184        .record_browser_session(&clock, &session)
185        .await;
186
187    // TODO: better error handling
188    let grant = repo
189        .oauth2_device_code_grant()
190        .lookup(grant_id)
191        .await?
192        .context("Device grant not found")
193        .map_err(InternalError::from_anyhow)?;
194
195    if grant.expires_at < clock.now() {
196        return Err(InternalError::from_anyhow(anyhow::anyhow!(
197            "Grant is expired"
198        )));
199    }
200
201    let client = repo
202        .oauth2_client()
203        .lookup(grant.client_id)
204        .await?
205        .context("Client not found")
206        .map_err(InternalError::from_anyhow)?;
207
208    // Evaluate the policy
209    let res = policy
210        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
211            grant_type: mas_policy::GrantType::DeviceCode,
212            client: &client,
213            scope: &grant.scope,
214            user: Some(&session.user),
215            requester: mas_policy::Requester {
216                ip_address: activity_tracker.ip(),
217                user_agent,
218            },
219        })
220        .await?;
221    if !res.valid() {
222        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
223
224        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
225        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
226            .with_session(session)
227            .with_csrf(csrf_token.form_value())
228            .with_language(locale);
229
230        let content = templates.render_policy_violation(&ctx)?;
231
232        return Ok((cookie_jar, Html(content)).into_response());
233    }
234
235    let grant = if grant.is_pending() {
236        match form.action {
237            Action::Consent => {
238                repo.oauth2_device_code_grant()
239                    .fulfill(&clock, grant, &session)
240                    .await?
241            }
242            Action::Reject => {
243                repo.oauth2_device_code_grant()
244                    .reject(&clock, grant, &session)
245                    .await?
246            }
247        }
248    } else {
249        // XXX: In case we're not pending, let's just return the grant as-is
250        // since it might just be a form resubmission, and feedback is nice enough
251        warn!(
252            oauth2_device_code.id = %grant.id,
253            browser_session.id = %session.id,
254            user.id = %session.user.id,
255            "Grant is not pending",
256        );
257        grant
258    };
259
260    repo.save().await?;
261
262    let ctx = DeviceConsentContext::new(grant, client)
263        .with_session(session)
264        .with_csrf(csrf_token.form_value())
265        .with_language(locale);
266
267    let rendered = templates
268        .render_device_consent(&ctx)
269        .context("Failed to render template")
270        .map_err(InternalError::from_anyhow)?;
271
272    Ok((cookie_jar, Html(rendered)).into_response())
273}