mas_handlers/oauth2/authorization/
callback.rs1#![allow(clippy::module_name_repetitions)]
8
9use std::collections::HashMap;
10
11use axum::response::{Html, IntoResponse, Redirect, Response};
12use mas_data_model::AuthorizationGrant;
13use mas_i18n::DataLocale;
14use mas_templates::{FormPostContext, Templates};
15use oauth2_types::requests::ResponseMode;
16use serde::Serialize;
17use thiserror::Error;
18use url::Url;
19
/// How the authorization response parameters are attached when they are sent
/// back to the client.
#[derive(Debug, Clone)]
enum CallbackDestinationMode {
    /// Parameters go into the redirect URI's query string.
    Query {
        /// Query parameters that were already present on the client's
        /// redirect URI; they are serialized ahead of the response parameters.
        existing_params: HashMap<String, String>,
    },

    /// Parameters go into the redirect URI's fragment.
    Fragment,

    /// Parameters are delivered through an HTML page rendered from the
    /// form-post template, targeting the redirect URI.
    FormPost,
}
28
29#[derive(Debug, Clone)]
30pub struct CallbackDestination {
31 mode: CallbackDestinationMode,
32 safe_redirect_uri: Url,
33 state: Option<String>,
34}
35
36#[derive(Debug, Error)]
37pub enum IntoCallbackDestinationError {
38 #[error("Redirect URI can't have a fragment")]
39 RedirectUriFragmentNotAllowed,
40
41 #[error("Existing query parameters are not valid")]
42 RedirectUriInvalidQueryParams(#[from] serde_urlencoded::de::Error),
43
44 #[error("Requested response_mode is not supported")]
45 UnsupportedResponseMode,
46}
47
48#[derive(Debug, Error)]
49pub enum CallbackDestinationError {
50 #[error("Failed to render the form_post template")]
51 FormPostRender(#[from] mas_templates::TemplateError),
52
53 #[error("Failed to serialize parameters query string")]
54 ParamsSerialization(#[from] serde_urlencoded::ser::Error),
55}
56
57impl TryFrom<&AuthorizationGrant> for CallbackDestination {
58 type Error = IntoCallbackDestinationError;
59
60 fn try_from(value: &AuthorizationGrant) -> Result<Self, Self::Error> {
61 Self::try_new(
62 &value.response_mode,
63 value.redirect_uri.clone(),
64 value.state.clone(),
65 )
66 }
67}
68
69impl CallbackDestination {
70 pub fn try_new(
71 mode: &ResponseMode,
72 mut redirect_uri: Url,
73 state: Option<String>,
74 ) -> Result<Self, IntoCallbackDestinationError> {
75 if redirect_uri.fragment().is_some() {
76 return Err(IntoCallbackDestinationError::RedirectUriFragmentNotAllowed);
77 }
78
79 let mode = match mode {
80 ResponseMode::Query => {
81 let existing_params = redirect_uri
82 .query()
83 .map(serde_urlencoded::from_str)
84 .transpose()?
85 .unwrap_or_default();
86
87 redirect_uri.set_query(None);
89
90 CallbackDestinationMode::Query { existing_params }
91 }
92 ResponseMode::Fragment => CallbackDestinationMode::Fragment,
93 ResponseMode::FormPost => CallbackDestinationMode::FormPost,
94 _ => return Err(IntoCallbackDestinationError::UnsupportedResponseMode),
95 };
96
97 Ok(Self {
98 mode,
99 safe_redirect_uri: redirect_uri,
100 state,
101 })
102 }
103
104 pub async fn go<T: Serialize + Send + Sync>(
105 self,
106 templates: &Templates,
107 locale: &DataLocale,
108 params: T,
109 ) -> Result<Response, CallbackDestinationError> {
110 #[derive(Serialize)]
111 struct AllParams<'s, T> {
112 #[serde(flatten, skip_serializing_if = "Option::is_none")]
113 existing: Option<&'s HashMap<String, String>>,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
116 state: Option<String>,
117
118 #[serde(flatten)]
119 params: T,
120 }
121
122 let mut redirect_uri = self.safe_redirect_uri;
123 let state = self.state;
124
125 match self.mode {
126 CallbackDestinationMode::Query { existing_params } => {
127 let merged = AllParams {
128 existing: Some(&existing_params),
129 state,
130 params,
131 };
132
133 let new_qs = serde_urlencoded::to_string(merged)?;
134
135 redirect_uri.set_query(Some(&new_qs));
136
137 Ok(Redirect::to(redirect_uri.as_str()).into_response())
138 }
139
140 CallbackDestinationMode::Fragment => {
141 let merged = AllParams {
142 existing: None,
143 state,
144 params,
145 };
146
147 let new_qs = serde_urlencoded::to_string(merged)?;
148
149 redirect_uri.set_fragment(Some(&new_qs));
150
151 Ok(Redirect::to(redirect_uri.as_str()).into_response())
152 }
153
154 CallbackDestinationMode::FormPost => {
155 let merged = AllParams {
156 existing: None,
157 state,
158 params,
159 };
160 let ctx = FormPostContext::new_for_url(redirect_uri, merged).with_language(locale);
161 let rendered = templates.render_form_post(&ctx)?;
162 Ok(Html(rendered).into_response())
163 }
164 }
165 }
166}