1use std::{collections::HashMap, error::Error};
8
9use axum::{
10 extract::{
11 Form, FromRequest, FromRequestParts,
12 rejection::{FailedToDeserializeForm, FormRejection},
13 },
14 response::{IntoResponse, Response},
15};
16use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
17use headers::{Authorization, Header, HeaderMapExt, HeaderName, authorization::Bearer};
18use http::{HeaderMap, HeaderValue, Request, StatusCode, header::WWW_AUTHENTICATE};
19use mas_data_model::Session;
20use mas_storage::{
21 Clock, RepositoryAccess,
22 oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
23};
24use serde::{Deserialize, de::DeserializeOwned};
25use thiserror::Error;
26
27#[derive(Debug, Deserialize)]
28struct AuthorizedForm<F> {
29 #[serde(default)]
30 access_token: Option<String>,
31
32 #[serde(flatten)]
33 inner: F,
34}
35
/// Where the request's bearer access token was found, if anywhere.
///
/// `Debug` is implemented by hand so that the raw token, a bearer credential,
/// is never written out through `{:?}` formatting (for example when a
/// `UserAuthorization` is logged).
enum AccessToken {
    /// Token supplied through the `access_token` form field.
    Form(String),
    /// Token supplied in the `Authorization: Bearer` header.
    Header(String),
    /// No token was supplied.
    None,
}

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the token's origin is shown; its value is redacted.
        match self {
            Self::Form(_) => f.write_str("Form(<redacted>)"),
            Self::Header(_) => f.write_str("Header(<redacted>)"),
            Self::None => f.write_str("None"),
        }
    }
}
42
43impl AccessToken {
44 async fn fetch<E>(
45 &self,
46 repo: &mut impl RepositoryAccess<Error = E>,
47 ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
48 let token = match self {
49 AccessToken::Form(t) | AccessToken::Header(t) => t,
50 AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
51 };
52
53 let token = repo
54 .oauth2_access_token()
55 .find_by_token(token.as_str())
56 .await?
57 .ok_or(AuthorizationVerificationError::InvalidToken)?;
58
59 let session = repo
60 .oauth2_session()
61 .lookup(token.session_id)
62 .await?
63 .ok_or(AuthorizationVerificationError::InvalidToken)?;
64
65 Ok((token, session))
66 }
67}
68
69#[derive(Debug)]
70pub struct UserAuthorization<F = ()> {
71 access_token: AccessToken,
72 form: Option<F>,
73}
74
75impl<F: Send> UserAuthorization<F> {
76 pub async fn protected_form<E>(
85 self,
86 repo: &mut impl RepositoryAccess<Error = E>,
87 clock: &impl Clock,
88 ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
89 let Some(form) = self.form else {
90 return Err(AuthorizationVerificationError::MissingForm);
91 };
92
93 let (token, session) = self.access_token.fetch(repo).await?;
94
95 if !token.is_valid(clock.now()) || !session.is_valid() {
96 return Err(AuthorizationVerificationError::InvalidToken);
97 }
98
99 Ok((session, form))
100 }
101
102 pub async fn protected<E>(
109 self,
110 repo: &mut impl RepositoryAccess<Error = E>,
111 clock: &impl Clock,
112 ) -> Result<Session, AuthorizationVerificationError<E>> {
113 let (token, session) = self.access_token.fetch(repo).await?;
114
115 if !token.is_valid(clock.now()) || !session.is_valid() {
116 return Err(AuthorizationVerificationError::InvalidToken);
117 }
118
119 if !token.is_used() {
120 repo.oauth2_access_token().mark_used(clock, token).await?;
122 }
123
124 Ok(session)
125 }
126}
127
128pub enum UserAuthorizationError {
129 InvalidHeader,
130 TokenInFormAndHeader,
131 BadForm(FailedToDeserializeForm),
132 Internal(Box<dyn Error>),
133}
134
135#[derive(Debug, Error)]
136pub enum AuthorizationVerificationError<E> {
137 #[error("missing token")]
138 MissingToken,
139
140 #[error("invalid token")]
141 InvalidToken,
142
143 #[error("missing form")]
144 MissingForm,
145
146 #[error(transparent)]
147 Internal(#[from] E),
148}
149
150enum BearerError {
151 InvalidRequest,
152 InvalidToken,
153 #[allow(dead_code)]
154 InsufficientScope {
155 scope: Option<HeaderValue>,
156 },
157}
158
159impl BearerError {
160 fn error(&self) -> HeaderValue {
161 match self {
162 BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"),
163 BearerError::InvalidToken => HeaderValue::from_static("invalid_token"),
164 BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"),
165 }
166 }
167
168 fn params(&self) -> HashMap<&'static str, HeaderValue> {
169 match self {
170 BearerError::InsufficientScope { scope: Some(scope) } => {
171 let mut m = HashMap::new();
172 m.insert("scope", scope.clone());
173 m
174 }
175 _ => HashMap::new(),
176 }
177 }
178}
179
180enum WwwAuthenticate {
181 #[allow(dead_code)]
182 Basic { realm: HeaderValue },
183 Bearer {
184 realm: Option<HeaderValue>,
185 error: BearerError,
186 error_description: Option<HeaderValue>,
187 },
188}
189
190impl Header for WwwAuthenticate {
191 fn name() -> &'static HeaderName {
192 &WWW_AUTHENTICATE
193 }
194
195 fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
196 where
197 Self: Sized,
198 I: Iterator<Item = &'i http::HeaderValue>,
199 {
200 Err(headers::Error::invalid())
201 }
202
203 fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
204 let (scheme, params) = match self {
205 WwwAuthenticate::Basic { realm } => {
206 let mut params = HashMap::new();
207 params.insert("realm", realm.clone());
208 ("Basic", params)
209 }
210 WwwAuthenticate::Bearer {
211 realm,
212 error,
213 error_description,
214 } => {
215 let mut params = error.params();
216 params.insert("error", error.error());
217
218 if let Some(realm) = realm {
219 params.insert("realm", realm.clone());
220 }
221
222 if let Some(error_description) = error_description {
223 params.insert("error_description", error_description.clone());
224 }
225
226 ("Bearer", params)
227 }
228 };
229
230 let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
231 let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
232 let value = HeaderValue::from_str(&value).unwrap();
233 values.extend(std::iter::once(value));
234 }
235}
236
237impl IntoResponse for UserAuthorizationError {
238 fn into_response(self) -> Response {
239 match self {
240 Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => {
241 let mut headers = HeaderMap::new();
242
243 headers.typed_insert(WwwAuthenticate::Bearer {
244 realm: None,
245 error: BearerError::InvalidRequest,
246 error_description: None,
247 });
248 (StatusCode::BAD_REQUEST, headers).into_response()
249 }
250 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
251 }
252 }
253}
254
255impl<E> IntoResponse for AuthorizationVerificationError<E>
256where
257 E: ToString,
258{
259 fn into_response(self) -> Response {
260 match self {
261 Self::MissingForm | Self::MissingToken => {
262 let mut headers = HeaderMap::new();
263
264 headers.typed_insert(WwwAuthenticate::Bearer {
265 realm: None,
266 error: BearerError::InvalidRequest,
267 error_description: None,
268 });
269 (StatusCode::BAD_REQUEST, headers).into_response()
270 }
271 Self::InvalidToken => {
272 let mut headers = HeaderMap::new();
273
274 headers.typed_insert(WwwAuthenticate::Bearer {
275 realm: None,
276 error: BearerError::InvalidToken,
277 error_description: None,
278 });
279 (StatusCode::BAD_REQUEST, headers).into_response()
280 }
281 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
282 }
283 }
284}
285
286impl<S, F> FromRequest<S> for UserAuthorization<F>
287where
288 F: DeserializeOwned,
289 S: Send + Sync,
290{
291 type Rejection = UserAuthorizationError;
292
293 async fn from_request(
294 req: Request<axum::body::Body>,
295 state: &S,
296 ) -> Result<Self, Self::Rejection> {
297 let (mut parts, body) = req.into_parts();
298 let header =
299 TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
300
301 let token_from_header = match header {
303 Ok(header) => Some(header.token().to_owned()),
304 Err(err) => match err.reason() {
305 TypedHeaderRejectionReason::Missing => None,
307 _ => return Err(UserAuthorizationError::InvalidHeader),
309 },
310 };
311
312 let req = Request::from_parts(parts, body);
313
314 let (token_from_form, form) =
316 match Form::<AuthorizedForm<F>>::from_request(req, state).await {
317 Ok(Form(form)) => (form.access_token, Some(form.inner)),
318 Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
320 Err(FormRejection::FailedToDeserializeForm(err)) => {
322 return Err(UserAuthorizationError::BadForm(err));
323 }
324 Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
326 };
327
328 let access_token = match (token_from_header, token_from_form) {
329 (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
331 (Some(t), None) => AccessToken::Header(t),
332 (None, Some(t)) => AccessToken::Form(t),
333 (None, None) => AccessToken::None,
334 };
335
336 Ok(UserAuthorization { access_token, form })
337 }
338}