mas_axum_utils/
user_authorization.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, error::Error};
8
9use axum::{
10    extract::{
11        Form, FromRequest, FromRequestParts,
12        rejection::{FailedToDeserializeForm, FormRejection},
13    },
14    response::{IntoResponse, Response},
15};
16use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
17use headers::{Authorization, Header, HeaderMapExt, HeaderName, authorization::Bearer};
18use http::{HeaderMap, HeaderValue, Request, StatusCode, header::WWW_AUTHENTICATE};
19use mas_data_model::Session;
20use mas_storage::{
21    Clock, RepositoryAccess,
22    oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
23};
24use serde::{Deserialize, de::DeserializeOwned};
25use thiserror::Error;
26
27#[derive(Debug, Deserialize)]
28struct AuthorizedForm<F> {
29    #[serde(default)]
30    access_token: Option<String>,
31
32    #[serde(flatten)]
33    inner: F,
34}
35
36#[derive(Debug)]
37enum AccessToken {
38    Form(String),
39    Header(String),
40    None,
41}
42
43impl AccessToken {
44    async fn fetch<E>(
45        &self,
46        repo: &mut impl RepositoryAccess<Error = E>,
47    ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
48        let token = match self {
49            AccessToken::Form(t) | AccessToken::Header(t) => t,
50            AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
51        };
52
53        let token = repo
54            .oauth2_access_token()
55            .find_by_token(token.as_str())
56            .await?
57            .ok_or(AuthorizationVerificationError::InvalidToken)?;
58
59        let session = repo
60            .oauth2_session()
61            .lookup(token.session_id)
62            .await?
63            .ok_or(AuthorizationVerificationError::InvalidToken)?;
64
65        Ok((token, session))
66    }
67}
68
69#[derive(Debug)]
70pub struct UserAuthorization<F = ()> {
71    access_token: AccessToken,
72    form: Option<F>,
73}
74
75impl<F: Send> UserAuthorization<F> {
76    // TODO: take scopes to validate as parameter
77    /// Verify a user authorization and return the session and the protected
78    /// form value
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if the token is invalid, if the user session ended or
83    /// if the form is missing
84    pub async fn protected_form<E>(
85        self,
86        repo: &mut impl RepositoryAccess<Error = E>,
87        clock: &impl Clock,
88    ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
89        let Some(form) = self.form else {
90            return Err(AuthorizationVerificationError::MissingForm);
91        };
92
93        let (token, session) = self.access_token.fetch(repo).await?;
94
95        if !token.is_valid(clock.now()) || !session.is_valid() {
96            return Err(AuthorizationVerificationError::InvalidToken);
97        }
98
99        Ok((session, form))
100    }
101
102    // TODO: take scopes to validate as parameter
103    /// Verify a user authorization and return the session
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the token is invalid or if the user session ended
108    pub async fn protected<E>(
109        self,
110        repo: &mut impl RepositoryAccess<Error = E>,
111        clock: &impl Clock,
112    ) -> Result<Session, AuthorizationVerificationError<E>> {
113        let (token, session) = self.access_token.fetch(repo).await?;
114
115        if !token.is_valid(clock.now()) || !session.is_valid() {
116            return Err(AuthorizationVerificationError::InvalidToken);
117        }
118
119        if !token.is_used() {
120            // Mark the token as used
121            repo.oauth2_access_token().mark_used(clock, token).await?;
122        }
123
124        Ok(session)
125    }
126}
127
128pub enum UserAuthorizationError {
129    InvalidHeader,
130    TokenInFormAndHeader,
131    BadForm(FailedToDeserializeForm),
132    Internal(Box<dyn Error>),
133}
134
135#[derive(Debug, Error)]
136pub enum AuthorizationVerificationError<E> {
137    #[error("missing token")]
138    MissingToken,
139
140    #[error("invalid token")]
141    InvalidToken,
142
143    #[error("missing form")]
144    MissingForm,
145
146    #[error(transparent)]
147    Internal(#[from] E),
148}
149
150enum BearerError {
151    InvalidRequest,
152    InvalidToken,
153    #[allow(dead_code)]
154    InsufficientScope {
155        scope: Option<HeaderValue>,
156    },
157}
158
159impl BearerError {
160    fn error(&self) -> HeaderValue {
161        match self {
162            BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"),
163            BearerError::InvalidToken => HeaderValue::from_static("invalid_token"),
164            BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"),
165        }
166    }
167
168    fn params(&self) -> HashMap<&'static str, HeaderValue> {
169        match self {
170            BearerError::InsufficientScope { scope: Some(scope) } => {
171                let mut m = HashMap::new();
172                m.insert("scope", scope.clone());
173                m
174            }
175            _ => HashMap::new(),
176        }
177    }
178}
179
180enum WwwAuthenticate {
181    #[allow(dead_code)]
182    Basic { realm: HeaderValue },
183    Bearer {
184        realm: Option<HeaderValue>,
185        error: BearerError,
186        error_description: Option<HeaderValue>,
187    },
188}
189
190impl Header for WwwAuthenticate {
191    fn name() -> &'static HeaderName {
192        &WWW_AUTHENTICATE
193    }
194
195    fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
196    where
197        Self: Sized,
198        I: Iterator<Item = &'i http::HeaderValue>,
199    {
200        Err(headers::Error::invalid())
201    }
202
203    fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
204        let (scheme, params) = match self {
205            WwwAuthenticate::Basic { realm } => {
206                let mut params = HashMap::new();
207                params.insert("realm", realm.clone());
208                ("Basic", params)
209            }
210            WwwAuthenticate::Bearer {
211                realm,
212                error,
213                error_description,
214            } => {
215                let mut params = error.params();
216                params.insert("error", error.error());
217
218                if let Some(realm) = realm {
219                    params.insert("realm", realm.clone());
220                }
221
222                if let Some(error_description) = error_description {
223                    params.insert("error_description", error_description.clone());
224                }
225
226                ("Bearer", params)
227            }
228        };
229
230        let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
231        let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
232        let value = HeaderValue::from_str(&value).unwrap();
233        values.extend(std::iter::once(value));
234    }
235}
236
237impl IntoResponse for UserAuthorizationError {
238    fn into_response(self) -> Response {
239        match self {
240            Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => {
241                let mut headers = HeaderMap::new();
242
243                headers.typed_insert(WwwAuthenticate::Bearer {
244                    realm: None,
245                    error: BearerError::InvalidRequest,
246                    error_description: None,
247                });
248                (StatusCode::BAD_REQUEST, headers).into_response()
249            }
250            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
251        }
252    }
253}
254
255impl<E> IntoResponse for AuthorizationVerificationError<E>
256where
257    E: ToString,
258{
259    fn into_response(self) -> Response {
260        match self {
261            Self::MissingForm | Self::MissingToken => {
262                let mut headers = HeaderMap::new();
263
264                headers.typed_insert(WwwAuthenticate::Bearer {
265                    realm: None,
266                    error: BearerError::InvalidRequest,
267                    error_description: None,
268                });
269                (StatusCode::BAD_REQUEST, headers).into_response()
270            }
271            Self::InvalidToken => {
272                let mut headers = HeaderMap::new();
273
274                headers.typed_insert(WwwAuthenticate::Bearer {
275                    realm: None,
276                    error: BearerError::InvalidToken,
277                    error_description: None,
278                });
279                (StatusCode::BAD_REQUEST, headers).into_response()
280            }
281            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
282        }
283    }
284}
285
286impl<S, F> FromRequest<S> for UserAuthorization<F>
287where
288    F: DeserializeOwned,
289    S: Send + Sync,
290{
291    type Rejection = UserAuthorizationError;
292
293    async fn from_request(
294        req: Request<axum::body::Body>,
295        state: &S,
296    ) -> Result<Self, Self::Rejection> {
297        let (mut parts, body) = req.into_parts();
298        let header =
299            TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
300
301        // Take the Authorization header
302        let token_from_header = match header {
303            Ok(header) => Some(header.token().to_owned()),
304            Err(err) => match err.reason() {
305                // If it's missing it is fine
306                TypedHeaderRejectionReason::Missing => None,
307                // If the header could not be parsed, return the error
308                _ => return Err(UserAuthorizationError::InvalidHeader),
309            },
310        };
311
312        let req = Request::from_parts(parts, body);
313
314        // Take the form value
315        let (token_from_form, form) =
316            match Form::<AuthorizedForm<F>>::from_request(req, state).await {
317                Ok(Form(form)) => (form.access_token, Some(form.inner)),
318                // If it is not a form, continue
319                Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
320                // If the form could not be read, return a Bad Request error
321                Err(FormRejection::FailedToDeserializeForm(err)) => {
322                    return Err(UserAuthorizationError::BadForm(err));
323                }
324                // Other errors (body read twice, byte stream broke) return an internal error
325                Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
326            };
327
328        let access_token = match (token_from_header, token_from_form) {
329            // Ensure the token should not be in both the form and the access token
330            (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
331            (Some(t), None) => AccessToken::Header(t),
332            (None, Some(t)) => AccessToken::Form(t),
333            (None, None) => AccessToken::None,
334        };
335
336        Ok(UserAuthorization { access_token, form })
337    }
338}