mas_handlers/oauth2/authorization/callback.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7#![allow(clippy::module_name_repetitions)]
8
9use std::collections::HashMap;
10
11use axum::response::{Html, IntoResponse, Redirect, Response};
12use mas_data_model::AuthorizationGrant;
13use mas_i18n::DataLocale;
14use mas_templates::{FormPostContext, Templates};
15use oauth2_types::requests::ResponseMode;
16use serde::Serialize;
17use thiserror::Error;
18use url::Url;
19
20#[derive(Debug, Clone)]
21enum CallbackDestinationMode {
22    Query {
23        existing_params: HashMap<String, String>,
24    },
25    Fragment,
26    FormPost,
27}
28
29#[derive(Debug, Clone)]
30pub struct CallbackDestination {
31    mode: CallbackDestinationMode,
32    safe_redirect_uri: Url,
33    state: Option<String>,
34}
35
36#[derive(Debug, Error)]
37pub enum IntoCallbackDestinationError {
38    #[error("Redirect URI can't have a fragment")]
39    RedirectUriFragmentNotAllowed,
40
41    #[error("Existing query parameters are not valid")]
42    RedirectUriInvalidQueryParams(#[from] serde_urlencoded::de::Error),
43
44    #[error("Requested response_mode is not supported")]
45    UnsupportedResponseMode,
46}
47
48#[derive(Debug, Error)]
49pub enum CallbackDestinationError {
50    #[error("Failed to render the form_post template")]
51    FormPostRender(#[from] mas_templates::TemplateError),
52
53    #[error("Failed to serialize parameters query string")]
54    ParamsSerialization(#[from] serde_urlencoded::ser::Error),
55}
56
57impl TryFrom<&AuthorizationGrant> for CallbackDestination {
58    type Error = IntoCallbackDestinationError;
59
60    fn try_from(value: &AuthorizationGrant) -> Result<Self, Self::Error> {
61        Self::try_new(
62            &value.response_mode,
63            value.redirect_uri.clone(),
64            value.state.clone(),
65        )
66    }
67}
68
69impl CallbackDestination {
70    pub fn try_new(
71        mode: &ResponseMode,
72        mut redirect_uri: Url,
73        state: Option<String>,
74    ) -> Result<Self, IntoCallbackDestinationError> {
75        if redirect_uri.fragment().is_some() {
76            return Err(IntoCallbackDestinationError::RedirectUriFragmentNotAllowed);
77        }
78
79        let mode = match mode {
80            ResponseMode::Query => {
81                let existing_params = redirect_uri
82                    .query()
83                    .map(serde_urlencoded::from_str)
84                    .transpose()?
85                    .unwrap_or_default();
86
87                // Remove the query from the URL
88                redirect_uri.set_query(None);
89
90                CallbackDestinationMode::Query { existing_params }
91            }
92            ResponseMode::Fragment => CallbackDestinationMode::Fragment,
93            ResponseMode::FormPost => CallbackDestinationMode::FormPost,
94            _ => return Err(IntoCallbackDestinationError::UnsupportedResponseMode),
95        };
96
97        Ok(Self {
98            mode,
99            safe_redirect_uri: redirect_uri,
100            state,
101        })
102    }
103
104    pub async fn go<T: Serialize + Send + Sync>(
105        self,
106        templates: &Templates,
107        locale: &DataLocale,
108        params: T,
109    ) -> Result<Response, CallbackDestinationError> {
110        #[derive(Serialize)]
111        struct AllParams<'s, T> {
112            #[serde(flatten, skip_serializing_if = "Option::is_none")]
113            existing: Option<&'s HashMap<String, String>>,
114
115            #[serde(skip_serializing_if = "Option::is_none")]
116            state: Option<String>,
117
118            #[serde(flatten)]
119            params: T,
120        }
121
122        let mut redirect_uri = self.safe_redirect_uri;
123        let state = self.state;
124
125        match self.mode {
126            CallbackDestinationMode::Query { existing_params } => {
127                let merged = AllParams {
128                    existing: Some(&existing_params),
129                    state,
130                    params,
131                };
132
133                let new_qs = serde_urlencoded::to_string(merged)?;
134
135                redirect_uri.set_query(Some(&new_qs));
136
137                Ok(Redirect::to(redirect_uri.as_str()).into_response())
138            }
139
140            CallbackDestinationMode::Fragment => {
141                let merged = AllParams {
142                    existing: None,
143                    state,
144                    params,
145                };
146
147                let new_qs = serde_urlencoded::to_string(merged)?;
148
149                redirect_uri.set_fragment(Some(&new_qs));
150
151                Ok(Redirect::to(redirect_uri.as_str()).into_response())
152            }
153
154            CallbackDestinationMode::FormPost => {
155                let merged = AllParams {
156                    existing: None,
157                    state,
158                    params,
159                };
160                let ctx = FormPostContext::new_for_url(redirect_uri, merged).with_language(locale);
161                let rendered = templates.render_form_post(&ctx)?;
162                Ok(Html(rendered).into_response())
163            }
164        }
165    }
166}