mas_templates/
functions.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7// This is needed to make the Environment::add* functions work
8#![allow(clippy::needless_pass_by_value)]
9
10//! Additional functions, tests and filters used in templates
11
12use std::{
13    collections::HashMap,
14    fmt::Formatter,
15    str::FromStr,
16    sync::{Arc, atomic::AtomicUsize},
17};
18
19use camino::Utf8Path;
20use mas_i18n::{Argument, ArgumentList, DataLocale, Translator, sprintf::FormattedMessagePart};
21use mas_router::UrlBuilder;
22use mas_spa::ViteManifest;
23use minijinja::{
24    Error, ErrorKind, State, Value, escape_formatter,
25    machinery::make_string_output,
26    value::{Kwargs, Object, ViaDeserialize, from_args},
27};
28use url::Url;
29
30pub fn register(
31    env: &mut minijinja::Environment,
32    url_builder: UrlBuilder,
33    vite_manifest: ViteManifest,
34    translator: Arc<Translator>,
35) {
36    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
37
38    minijinja_contrib::add_to_environment(env);
39    env.add_test("empty", self::tester_empty);
40    env.add_filter("to_params", filter_to_params);
41    env.add_filter("simplify_url", filter_simplify_url);
42    env.add_filter("add_slashes", filter_add_slashes);
43    env.add_function("add_params_to_url", function_add_params_to_url);
44    env.add_function("counter", || Ok(Value::from_object(Counter::default())));
45    env.add_global(
46        "include_asset",
47        Value::from_object(IncludeAsset {
48            url_builder: url_builder.clone(),
49            vite_manifest,
50        }),
51    );
52    env.add_global(
53        "translator",
54        Value::from_object(TranslatorFunc { translator }),
55    );
56    env.add_filter("prefix_url", move |url: &str| -> String {
57        if !url.starts_with('/') {
58            // Let's assume it's not an internal URL and return it as-is
59            return url.to_owned();
60        }
61
62        let Some(prefix) = url_builder.prefix() else {
63            // If there is no prefix to add, return the URL as-is
64            return url.to_owned();
65        };
66
67        format!("{prefix}{url}")
68    });
69}
70
71fn tester_empty(seq: Value) -> bool {
72    seq.len() == Some(0)
73}
74
/// Escapes backslashes, double quotes and single quotes by prefixing each
/// occurrence with a backslash. All other characters are copied unchanged.
fn filter_add_slashes(value: &str) -> String {
    // Escaped output is at least as long as the input
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
81
82fn filter_to_params(params: &Value, kwargs: Kwargs) -> Result<String, Error> {
83    let params = serde_urlencoded::to_string(params).map_err(|e| {
84        Error::new(
85            ErrorKind::InvalidOperation,
86            "Could not serialize parameters",
87        )
88        .with_source(e)
89    })?;
90
91    let prefix = kwargs.get("prefix").unwrap_or("");
92    kwargs.assert_all_used()?;
93
94    if params.is_empty() {
95        Ok(String::new())
96    } else {
97        Ok(format!("{prefix}{params}"))
98    }
99}
100
101/// Filter which simplifies a URL to its domain name for HTTP(S) URLs
102fn filter_simplify_url(url: &str, kwargs: Kwargs) -> Result<String, minijinja::Error> {
103    // Do nothing if the URL is not valid
104    let Ok(mut url) = Url::from_str(url) else {
105        return Ok(url.to_owned());
106    };
107
108    // Always at least remove the query parameters and fragment
109    url.set_query(None);
110    url.set_fragment(None);
111
112    // Do nothing else for non-HTTPS URLs
113    if url.scheme() != "https" {
114        return Ok(url.to_string());
115    }
116
117    let keep_path = kwargs.get::<Option<bool>>("keep_path")?.unwrap_or_default();
118    kwargs.assert_all_used()?;
119
120    // Only return the domain name
121    let Some(domain) = url.domain() else {
122        return Ok(url.to_string());
123    };
124
125    if keep_path {
126        Ok(format!(
127            "{domain}{path}",
128            domain = domain,
129            path = url.path(),
130        ))
131    } else {
132        Ok(domain.to_owned())
133    }
134}
135
136enum ParamsWhere {
137    Fragment,
138    Query,
139}
140
141fn function_add_params_to_url(
142    uri: ViaDeserialize<Url>,
143    mode: &str,
144    params: ViaDeserialize<HashMap<String, Value>>,
145) -> Result<String, Error> {
146    use ParamsWhere::{Fragment, Query};
147
148    let mode = match mode {
149        "fragment" => Fragment,
150        "query" => Query,
151        _ => {
152            return Err(Error::new(
153                ErrorKind::InvalidOperation,
154                "Invalid `mode` parameter",
155            ));
156        }
157    };
158
159    // First, get the `uri`, `mode` and `params` parameters
160    // Get the relevant part of the URI and parse for existing parameters
161    let existing = match mode {
162        Fragment => uri.fragment(),
163        Query => uri.query(),
164    };
165    let existing: HashMap<String, Value> = existing
166        .map(serde_urlencoded::from_str)
167        .transpose()
168        .map_err(|e| {
169            Error::new(
170                ErrorKind::InvalidOperation,
171                "Could not parse existing `uri` parameters",
172            )
173            .with_source(e)
174        })?
175        .unwrap_or_default();
176
177    // Merge the exising and the additional parameters together
178    let params: HashMap<&String, &Value> = params.iter().chain(existing.iter()).collect();
179
180    // Transform them back to urlencoded
181    let params = serde_urlencoded::to_string(params).map_err(|e| {
182        Error::new(
183            ErrorKind::InvalidOperation,
184            "Could not serialize back parameters",
185        )
186        .with_source(e)
187    })?;
188
189    let uri = {
190        let mut uri = uri;
191        match mode {
192            Fragment => uri.set_fragment(Some(&params)),
193            Query => uri.set_query(Some(&params)),
194        }
195        uri
196    };
197
198    Ok(uri.to_string())
199}
200
201struct TranslatorFunc {
202    translator: Arc<Translator>,
203}
204
205impl std::fmt::Debug for TranslatorFunc {
206    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207        f.debug_struct("TranslatorFunc")
208            .field("translator", &"..")
209            .finish()
210    }
211}
212
213impl std::fmt::Display for TranslatorFunc {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.write_str("translator")
216    }
217}
218
219impl Object for TranslatorFunc {
220    fn call(self: &Arc<Self>, _state: &State, args: &[Value]) -> Result<Value, Error> {
221        let (lang,): (&str,) = from_args(args)?;
222
223        let lang: DataLocale = lang.parse().map_err(|e| {
224            Error::new(ErrorKind::InvalidOperation, "Invalid language").with_source(e)
225        })?;
226
227        Ok(Value::from_object(TranslateFunc {
228            lang,
229            translator: Arc::clone(&self.translator),
230        }))
231    }
232}
233
234struct TranslateFunc {
235    translator: Arc<Translator>,
236    lang: DataLocale,
237}
238
239impl std::fmt::Debug for TranslateFunc {
240    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241        f.debug_struct("Translate")
242            .field("translator", &"..")
243            .field("lang", &self.lang)
244            .finish()
245    }
246}
247
248impl std::fmt::Display for TranslateFunc {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.write_str("translate")
251    }
252}
253
254impl Object for TranslateFunc {
255    fn call(self: &Arc<Self>, state: &State, args: &[Value]) -> Result<Value, Error> {
256        let (key, kwargs): (&str, Kwargs) = from_args(args)?;
257
258        let (message, _locale) = if let Some(count) = kwargs.get("count")? {
259            self.translator
260                .plural_with_fallback(self.lang.clone(), key, count)
261                .ok_or(Error::new(
262                    ErrorKind::InvalidOperation,
263                    "Missing translation",
264                ))?
265        } else {
266            self.translator
267                .message_with_fallback(self.lang.clone(), key)
268                .ok_or(Error::new(
269                    ErrorKind::InvalidOperation,
270                    "Missing translation",
271                ))?
272        };
273
274        let res: Result<ArgumentList, Error> = kwargs
275            .args()
276            .map(|name| {
277                let value: Value = kwargs.get(name)?;
278                let value = serde_json::to_value(value).map_err(|e| {
279                    Error::new(ErrorKind::InvalidOperation, "Could not serialize argument")
280                        .with_source(e)
281                })?;
282
283                Ok::<_, Error>(Argument::named(name.to_owned(), value))
284            })
285            .collect();
286        let list = res?;
287
288        let formatted = message.format_(&list).map_err(|e| {
289            Error::new(ErrorKind::InvalidOperation, "Could not format message").with_source(e)
290        })?;
291
292        let mut buf = String::with_capacity(formatted.len());
293        let mut output = make_string_output(&mut buf);
294        for part in formatted.parts() {
295            match part {
296                FormattedMessagePart::Text(text) => {
297                    // Literal text, just write it
298                    output.write_str(text)?;
299                }
300                FormattedMessagePart::Placeholder(placeholder) => {
301                    // Placeholder, escape it
302                    escape_formatter(&mut output, state, &placeholder.as_str().into())?;
303                }
304            }
305        }
306
307        Ok(Value::from_safe_string(buf))
308    }
309
310    fn call_method(
311        self: &Arc<Self>,
312        _state: &State,
313        name: &str,
314        args: &[Value],
315    ) -> Result<Value, Error> {
316        match name {
317            "relative_date" => {
318                let (date,): (String,) = from_args(args)?;
319                let date: chrono::DateTime<chrono::Utc> = date.parse().map_err(|e| {
320                    Error::new(
321                        ErrorKind::InvalidOperation,
322                        "Invalid date while calling function `relative_date`",
323                    )
324                    .with_source(e)
325                })?;
326
327                // TODO: grab the clock somewhere
328                #[allow(clippy::disallowed_methods)]
329                let now = chrono::Utc::now();
330
331                let diff = (date - now).num_days();
332
333                Ok(Value::from(
334                    self.translator
335                        .relative_date(&self.lang, diff)
336                        .map_err(|_e| {
337                            Error::new(
338                                ErrorKind::InvalidOperation,
339                                "Failed to format relative date",
340                            )
341                        })?,
342                ))
343            }
344
345            "short_time" => {
346                let (date,): (String,) = from_args(args)?;
347                let date: chrono::DateTime<chrono::Utc> = date.parse().map_err(|e| {
348                    Error::new(
349                        ErrorKind::InvalidOperation,
350                        "Invalid date while calling function `time`",
351                    )
352                    .with_source(e)
353                })?;
354
355                // TODO: we should use the user's timezone here
356                let time = date.time();
357
358                Ok(Value::from(
359                    self.translator
360                        .short_time(&self.lang, &TimeAdapter(time))
361                        .map_err(|_e| {
362                            Error::new(ErrorKind::InvalidOperation, "Failed to format time")
363                        })?,
364                ))
365            }
366
367            _ => Err(Error::new(
368                ErrorKind::InvalidOperation,
369                "Invalid method on include_asset",
370            )),
371        }
372    }
373}
374
375/// An adapter to make a [`Timelike`] implement [`IsoTimeInput`]
376///
377/// [`Timelike`]: chrono::Timelike
378/// [`IsoTimeInput`]: mas_i18n::icu_datetime::input::IsoTimeInput
379struct TimeAdapter<T>(T);
380
381impl<T: chrono::Timelike> mas_i18n::icu_datetime::input::IsoTimeInput for TimeAdapter<T> {
382    fn hour(&self) -> Option<mas_i18n::icu_calendar::types::IsoHour> {
383        let hour: usize = chrono::Timelike::hour(&self.0).try_into().ok()?;
384        hour.try_into().ok()
385    }
386
387    fn minute(&self) -> Option<mas_i18n::icu_calendar::types::IsoMinute> {
388        let minute: usize = chrono::Timelike::minute(&self.0).try_into().ok()?;
389        minute.try_into().ok()
390    }
391
392    fn second(&self) -> Option<mas_i18n::icu_calendar::types::IsoSecond> {
393        let second: usize = chrono::Timelike::second(&self.0).try_into().ok()?;
394        second.try_into().ok()
395    }
396
397    fn nanosecond(&self) -> Option<mas_i18n::icu_calendar::types::NanoSecond> {
398        let nanosecond: usize = chrono::Timelike::nanosecond(&self.0).try_into().ok()?;
399        nanosecond.try_into().ok()
400    }
401}
402
403struct IncludeAsset {
404    url_builder: UrlBuilder,
405    vite_manifest: ViteManifest,
406}
407
408impl std::fmt::Debug for IncludeAsset {
409    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
410        f.debug_struct("IncludeAsset")
411            .field("url_builder", &self.url_builder.assets_base())
412            .field("vite_manifest", &"..")
413            .finish()
414    }
415}
416
417impl std::fmt::Display for IncludeAsset {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        f.write_str("include_asset")
420    }
421}
422
423impl Object for IncludeAsset {
424    fn call(self: &Arc<Self>, _state: &State, args: &[Value]) -> Result<Value, Error> {
425        let (path,): (&str,) = from_args(args)?;
426
427        let path: &Utf8Path = path.into();
428
429        let (main, imported) = self.vite_manifest.find_assets(path).map_err(|_e| {
430            Error::new(
431                ErrorKind::InvalidOperation,
432                "Invalid assets manifest while calling function `include_asset`",
433            )
434        })?;
435
436        let assets = std::iter::once(main)
437            .chain(imported.iter().filter(|a| a.is_stylesheet()).copied())
438            .filter_map(|asset| asset.include_tag(self.url_builder.assets_base().into()));
439
440        let preloads = imported
441            .iter()
442            .filter(|a| a.is_script())
443            .map(|asset| asset.preload_tag(self.url_builder.assets_base().into()));
444
445        let tags: Vec<String> = preloads.chain(assets).collect();
446
447        Ok(Value::from_safe_string(tags.join("\n")))
448    }
449}
450
/// A mutable counter object, created (starting at zero) by the `counter()`
/// template function.
#[derive(Debug, Default)]
struct Counter {
    count: AtomicUsize,
}

impl std::fmt::Display for Counter {
    /// Displays the current value of the counter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.count.load(std::sync::atomic::Ordering::Relaxed);
        write!(f, "{current}")
    }
}
465
466impl Object for Counter {
467    fn call_method(
468        self: &Arc<Self>,
469        _state: &State,
470        name: &str,
471        args: &[Value],
472    ) -> Result<Value, Error> {
473        // None of the methods take any arguments
474        from_args::<()>(args)?;
475
476        match name {
477            "reset" => {
478                self.count.store(0, std::sync::atomic::Ordering::Relaxed);
479                Ok(Value::UNDEFINED)
480            }
481            "next" => {
482                let old = self
483                    .count
484                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
485                Ok(Value::from(old))
486            }
487            "peek" => Ok(Value::from(
488                self.count.load(std::sync::atomic::Ordering::Relaxed),
489            )),
490            _ => Err(Error::new(
491                ErrorKind::InvalidOperation,
492                "Invalid method on counter",
493            )),
494        }
495    }
496}