mas_handlers/upstream_oauth2/
template.rs
1use std::{collections::HashMap, sync::Arc};
8
9use base64ct::{Base64, Base64Unpadded, Base64Url, Base64UrlUnpadded, Encoding};
10use minijinja::{
11 Environment, Error, ErrorKind, Value,
12 value::{Enumerator, Object},
13};
14
/// Data made available to attribute-mapping templates.
///
/// Rendered through the [`Object`] implementation, which exposes the
/// variables `user`, `id_token_claims`, `userinfo_claims` and
/// `extra_callback_parameters`. Each source is optional, and an absent source
/// is simply not visible to the template.
#[derive(Debug, Default)]
pub(crate) struct AttributeMappingContext {
    /// Claims from the ID token, if any. When the merged `user` variable is
    /// built, these take precedence over `userinfo_claims`.
    id_token_claims: Option<HashMap<String, serde_json::Value>>,
    /// Extra parameters received on the callback, exposed to templates as-is.
    extra_callback_parameters: Option<serde_json::Value>,
    /// Claims from the userinfo response, if any. These contribute to `user`
    /// only when they are a JSON object.
    userinfo_claims: Option<serde_json::Value>,
}
29
30impl AttributeMappingContext {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn with_id_token_claims(
36 mut self,
37 id_token_claims: HashMap<String, serde_json::Value>,
38 ) -> Self {
39 self.id_token_claims = Some(id_token_claims);
40 self
41 }
42
43 pub fn with_extra_callback_parameters(
44 mut self,
45 extra_callback_parameters: serde_json::Value,
46 ) -> Self {
47 self.extra_callback_parameters = Some(extra_callback_parameters);
48 self
49 }
50
51 pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
52 self.userinfo_claims = Some(userinfo_claims);
53 self
54 }
55
56 pub fn build(self) -> Value {
57 Value::from_object(self)
58 }
59}
60
61impl Object for AttributeMappingContext {
62 fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
63 match name.as_str()? {
64 "user" => {
65 if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
66 return None;
67 }
68 let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
69 if let serde_json::Value::Object(userinfo) = self
70 .userinfo_claims
71 .clone()
72 .unwrap_or(serde_json::Value::Null)
73 {
74 merged_user.extend(userinfo);
75 }
76 if let Some(id_token) = self.id_token_claims.clone() {
77 merged_user.extend(id_token);
78 }
79 Some(Value::from_serialize(merged_user))
80 }
81 "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
82 "userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
83 "extra_callback_parameters" => self
84 .extra_callback_parameters
85 .as_ref()
86 .map(Value::from_serialize),
87 _ => None,
88 }
89 }
90
91 fn enumerate(self: &Arc<Self>) -> Enumerator {
92 let mut attrs = Vec::new();
93 if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
94 attrs.push(minijinja::Value::from("user"));
95 }
96 if self.id_token_claims.is_some() {
97 attrs.push(minijinja::Value::from("id_token_claims"));
98 }
99 if self.userinfo_claims.is_some() {
100 attrs.push(minijinja::Value::from("userinfo_claims"));
101 }
102 if self.extra_callback_parameters.is_some() {
103 attrs.push(minijinja::Value::from("extra_callback_parameters"));
104 }
105 Enumerator::Values(attrs)
106 }
107}
108
109fn b64decode(value: &str) -> Result<Value, Error> {
110 let bytes = Base64::decode_vec(value)
113 .or_else(|_| Base64Url::decode_vec(value))
114 .or_else(|_| Base64Unpadded::decode_vec(value))
115 .or_else(|_| Base64UrlUnpadded::decode_vec(value))
116 .map_err(|e| {
117 Error::new(
118 ErrorKind::InvalidOperation,
119 "Failed to decode base64 string",
120 )
121 .with_source(e)
122 })?;
123
124 Ok(Value::from(Arc::new(bytes)))
127}
128
129fn b64encode(bytes: &[u8]) -> String {
130 Base64::encode_string(bytes)
131}
132
133fn tlvdecode(bytes: &[u8]) -> Result<HashMap<Value, Value>, Error> {
135 let mut iter = bytes.iter().copied();
136 let mut ret = HashMap::new();
137 loop {
138 let Some(tag) = iter.next() else {
142 break;
143 };
144
145 let len = iter
146 .next()
147 .ok_or_else(|| Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding"))?;
148
149 let mut bytes = Vec::with_capacity(len.into());
150 for _ in 0..len {
151 bytes.push(
152 iter.next().ok_or_else(|| {
153 Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding")
154 })?,
155 );
156 }
157
158 ret.insert(tag.into(), Value::from(Arc::new(bytes)));
159 }
160
161 Ok(ret)
162}
163
164fn string(value: &Value) -> String {
165 value.to_string()
166}
167
168fn from_json(value: &str) -> Result<Value, minijinja::Error> {
169 let value: serde_json::Value = serde_json::from_str(value).map_err(|e| {
170 minijinja::Error::new(
171 minijinja::ErrorKind::InvalidOperation,
172 "Failed to decode JSON",
173 )
174 .with_source(e)
175 })?;
176
177 Ok(Value::from_serialize(value))
178}
179
180pub fn environment() -> Environment<'static> {
181 let mut env = Environment::new();
182
183 minijinja_contrib::add_to_environment(&mut env);
184
185 env.add_filter("b64decode", b64decode);
186 env.add_filter("b64encode", b64encode);
187 env.add_filter("tlvdecode", tlvdecode);
188 env.add_filter("string", string);
189 env.add_filter("from_json", from_json);
190
191 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
192
193 env
194}
195
#[cfg(test)]
mod tests {
    use super::environment;

    /// Renders `template` with an empty context, panicking on template errors.
    fn render(template: &str) -> String {
        environment().render_str(template, ()).unwrap()
    }

    #[test]
    fn test_split() {
        let rendered = render(r#"{{ 'foo, bar' | split(', ') | join(" | ") }}"#);
        assert_eq!(rendered, "foo | bar");
    }

    #[test]
    fn test_ilvdecode() {
        // The payload holds tag 10 -> "0-385-28089-0" and tag 18 -> "mock".
        // If tag 18 does not match, dividing a string by zero forces a render
        // error, and `unwrap` turns that into a test failure.
        let rendered = render(
            r#"
            {%- set tlv = 'Cg0wLTM4NS0yODA4OS0wEgRtb2Nr' | b64decode | tlvdecode -%}
            {%- if tlv[18]|string != 'mock' -%}
                {{ "FAIL"/0 }}
            {%- endif -%}
            {{- tlv[10]|string -}}
            "#,
        );
        assert_eq!(rendered, "0-385-28089-0");
    }

    #[test]
    fn test_base64_decode() {
        let cases = [
            ("{{ 'cGFkZGluZw==' | b64decode }}", "padding"),
            ("{{ 'dW5wYWRkZWQ' | b64decode }}", "unpadded"),
        ];
        for (template, expected) in cases {
            assert_eq!(render(template), expected);
        }
    }
}
241}