mas_handlers/upstream_oauth2/
template.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, sync::Arc};
8
9use base64ct::{Base64, Base64Unpadded, Base64Url, Base64UrlUnpadded, Encoding};
10use minijinja::{
11    Environment, Error, ErrorKind, Value,
12    value::{Enumerator, Object},
13};
14
15/// Context passed to the attribute mapping template
16///
17/// The variables available in the template are:
18/// - `user`: claims for the user, merged from the ID token and userinfo
19///   endpoint
20/// - `id_token_claims`: claims from the ID token
21/// - `userinfo_claims`: claims from the userinfo endpoint
22/// - `extra_callback_parameters`: extra parameters passed to the callback
23#[derive(Debug, Default)]
24pub(crate) struct AttributeMappingContext {
25    id_token_claims: Option<HashMap<String, serde_json::Value>>,
26    extra_callback_parameters: Option<serde_json::Value>,
27    userinfo_claims: Option<serde_json::Value>,
28}
29
30impl AttributeMappingContext {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn with_id_token_claims(
36        mut self,
37        id_token_claims: HashMap<String, serde_json::Value>,
38    ) -> Self {
39        self.id_token_claims = Some(id_token_claims);
40        self
41    }
42
43    pub fn with_extra_callback_parameters(
44        mut self,
45        extra_callback_parameters: serde_json::Value,
46    ) -> Self {
47        self.extra_callback_parameters = Some(extra_callback_parameters);
48        self
49    }
50
51    pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
52        self.userinfo_claims = Some(userinfo_claims);
53        self
54    }
55
56    pub fn build(self) -> Value {
57        Value::from_object(self)
58    }
59}
60
61impl Object for AttributeMappingContext {
62    fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
63        match name.as_str()? {
64            "user" => {
65                if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
66                    return None;
67                }
68                let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
69                if let serde_json::Value::Object(userinfo) = self
70                    .userinfo_claims
71                    .clone()
72                    .unwrap_or(serde_json::Value::Null)
73                {
74                    merged_user.extend(userinfo);
75                }
76                if let Some(id_token) = self.id_token_claims.clone() {
77                    merged_user.extend(id_token);
78                }
79                Some(Value::from_serialize(merged_user))
80            }
81            "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
82            "userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
83            "extra_callback_parameters" => self
84                .extra_callback_parameters
85                .as_ref()
86                .map(Value::from_serialize),
87            _ => None,
88        }
89    }
90
91    fn enumerate(self: &Arc<Self>) -> Enumerator {
92        let mut attrs = Vec::new();
93        if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
94            attrs.push(minijinja::Value::from("user"));
95        }
96        if self.id_token_claims.is_some() {
97            attrs.push(minijinja::Value::from("id_token_claims"));
98        }
99        if self.userinfo_claims.is_some() {
100            attrs.push(minijinja::Value::from("userinfo_claims"));
101        }
102        if self.extra_callback_parameters.is_some() {
103            attrs.push(minijinja::Value::from("extra_callback_parameters"));
104        }
105        Enumerator::Values(attrs)
106    }
107}
108
109fn b64decode(value: &str) -> Result<Value, Error> {
110    // We're not too concerned about the performance of this filter, so we'll just
111    // try all the base64 variants when decoding
112    let bytes = Base64::decode_vec(value)
113        .or_else(|_| Base64Url::decode_vec(value))
114        .or_else(|_| Base64Unpadded::decode_vec(value))
115        .or_else(|_| Base64UrlUnpadded::decode_vec(value))
116        .map_err(|e| {
117            Error::new(
118                ErrorKind::InvalidOperation,
119                "Failed to decode base64 string",
120            )
121            .with_source(e)
122        })?;
123
124    // It is not obvious, but the cleanest way to get a Value stored as raw bytes is
125    // to wrap it in an Arc, because Value implements From<Arc<Vec<u8>>>
126    Ok(Value::from(Arc::new(bytes)))
127}
128
129fn b64encode(bytes: &[u8]) -> String {
130    Base64::encode_string(bytes)
131}
132
133/// Decode a Tag-Length-Value encoded byte array into a map of tag to value.
134fn tlvdecode(bytes: &[u8]) -> Result<HashMap<Value, Value>, Error> {
135    let mut iter = bytes.iter().copied();
136    let mut ret = HashMap::new();
137    loop {
138        // TODO: this assumes the tag and the length are both single bytes, which is not
139        // always the case with protobufs. We should properly decode varints
140        // here.
141        let Some(tag) = iter.next() else {
142            break;
143        };
144
145        let len = iter
146            .next()
147            .ok_or_else(|| Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding"))?;
148
149        let mut bytes = Vec::with_capacity(len.into());
150        for _ in 0..len {
151            bytes.push(
152                iter.next().ok_or_else(|| {
153                    Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding")
154                })?,
155            );
156        }
157
158        ret.insert(tag.into(), Value::from(Arc::new(bytes)));
159    }
160
161    Ok(ret)
162}
163
164fn string(value: &Value) -> String {
165    value.to_string()
166}
167
168fn from_json(value: &str) -> Result<Value, minijinja::Error> {
169    let value: serde_json::Value = serde_json::from_str(value).map_err(|e| {
170        minijinja::Error::new(
171            minijinja::ErrorKind::InvalidOperation,
172            "Failed to decode JSON",
173        )
174        .with_source(e)
175    })?;
176
177    Ok(Value::from_serialize(value))
178}
179
180pub fn environment() -> Environment<'static> {
181    let mut env = Environment::new();
182
183    minijinja_contrib::add_to_environment(&mut env);
184
185    env.add_filter("b64decode", b64decode);
186    env.add_filter("b64encode", b64encode);
187    env.add_filter("tlvdecode", tlvdecode);
188    env.add_filter("string", string);
189    env.add_filter("from_json", from_json);
190
191    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
192
193    env
194}
195
#[cfg(test)]
mod tests {
    use super::environment;

    /// The `split` filter comes from `minijinja_contrib`; this checks that
    /// the contrib additions are registered in the environment.
    #[test]
    fn test_split() {
        let env = environment();
        let res = env
            .render_str(r#"{{ 'foo, bar' | split(', ') | join(" | ") }}"#, ())
            .unwrap();
        assert_eq!(res, "foo | bar");
    }

    /// Decodes a two-entry TLV payload and reads both entries back by tag.
    #[test]
    fn test_tlvdecode() {
        let env = environment();
        // The division of a string by zero is a deliberate way to make the
        // rendering fail if the second entry is not the expected one
        let res = env
            .render_str(
                r#"
                    {%- set tlv = 'Cg0wLTM4NS0yODA4OS0wEgRtb2Nr' | b64decode | tlvdecode -%}
                    {%- if tlv[18]|string != 'mock' -%}
                        {{ "FAIL"/0 }}
                    {%- endif -%}
                    {{- tlv[10]|string -}}
                "#,
                (),
            )
            .unwrap();
        assert_eq!(res, "0-385-28089-0");
    }

    /// Both padded and unpadded inputs must be accepted by `b64decode`.
    #[test]
    fn test_base64_decode() {
        let env = environment();

        let res = env
            .render_str("{{ 'cGFkZGluZw==' | b64decode }}", ())
            .unwrap();
        assert_eq!(res, "padding");

        let res = env
            .render_str("{{ 'dW5wYWRkZWQ' | b64decode }}", ())
            .unwrap();
        assert_eq!(res, "unpadded");
    }

    /// Encoding the decoded bytes again must give back the padded input.
    #[test]
    fn test_base64_roundtrip() {
        let env = environment();
        let res = env
            .render_str("{{ 'cGFkZGluZw==' | b64decode | b64encode }}", ())
            .unwrap();
        assert_eq!(res, "cGFkZGluZw==");
    }

    /// A parsed JSON object must expose its members as attributes.
    #[test]
    fn test_from_json() {
        let env = environment();
        let res = env
            .render_str(
                r#"{%- set doc = '{"name": "alice"}' | from_json -%}{{ doc.name }}"#,
                (),
            )
            .unwrap();
        assert_eq!(res, "alice");
    }
}