// mas_templates/forms.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use std::{collections::HashMap, hash::Hash};
8
9use serde::{Deserialize, Serialize};
10
11/// A trait which should be used for form field enums
12pub trait FormField: Copy + Hash + PartialEq + Eq + Serialize + for<'de> Deserialize<'de> {
13    /// Return false for fields where values should not be kept (e.g. password
14    /// fields)
15    fn keep(&self) -> bool;
16}
17
18/// An error on a form field
19#[derive(Debug, Serialize)]
20#[serde(rename_all = "snake_case", tag = "kind")]
21pub enum FieldError {
22    /// A required field is missing
23    Required,
24
25    /// An unspecified error on the field
26    Unspecified,
27
28    /// Invalid value for this field
29    Invalid,
30
31    /// The password confirmation doesn't match the password
32    PasswordMismatch,
33
34    /// That value already exists
35    Exists,
36
37    /// Denied by the policy
38    Policy {
39        /// Well-known policy code
40        code: Option<&'static str>,
41
42        /// Message for this policy violation
43        message: String,
44    },
45}
46
47/// An error on the whole form
48#[derive(Debug, Serialize)]
49#[serde(rename_all = "snake_case", tag = "kind")]
50pub enum FormError {
51    /// The given credentials are not valid
52    InvalidCredentials,
53
54    /// Password fields don't match
55    PasswordMismatch,
56
57    /// There was an internal error
58    Internal,
59
60    /// Rate limit exceeded
61    RateLimitExceeded,
62
63    /// Denied by the policy
64    Policy {
65        /// Well-known policy code
66        code: Option<&'static str>,
67
68        /// Message for this policy violation
69        message: String,
70    },
71
72    /// Failed to validate CAPTCHA
73    Captcha,
74}
75
76#[derive(Debug, Default, Serialize)]
77struct FieldState {
78    value: Option<String>,
79    errors: Vec<FieldError>,
80}
81
82/// The state of a form and its fields
83#[derive(Debug, Serialize)]
84pub struct FormState<K: Hash + Eq> {
85    fields: HashMap<K, FieldState>,
86    errors: Vec<FormError>,
87
88    #[serde(skip)]
89    has_errors: bool,
90}
91
92impl<K: Hash + Eq> Default for FormState<K> {
93    fn default() -> Self {
94        FormState {
95            fields: HashMap::default(),
96            errors: Vec::default(),
97            has_errors: false,
98        }
99    }
100}
101
102#[derive(Deserialize, PartialEq, Eq, Hash)]
103#[serde(untagged)]
104enum KeyOrOther<K> {
105    Key(K),
106    Other(String),
107}
108
109impl<K> KeyOrOther<K> {
110    fn key(self) -> Option<K> {
111        match self {
112            Self::Key(key) => Some(key),
113            Self::Other(_) => None,
114        }
115    }
116}
117
118impl<K: FormField> FormState<K> {
119    /// Generate a [`FormState`] out of a form
120    ///
121    /// # Panics
122    ///
123    /// If the form fails to serialize, or the form field keys fail to
124    /// deserialize
125    pub fn from_form<F: Serialize>(form: &F) -> Self {
126        let form = serde_json::to_value(form).unwrap();
127        let fields: HashMap<KeyOrOther<K>, Option<String>> = serde_json::from_value(form).unwrap();
128
129        let fields = fields
130            .into_iter()
131            .filter_map(|(key, value)| {
132                let key = key.key()?;
133                let value = key.keep().then_some(value).flatten();
134                let field = FieldState {
135                    value,
136                    errors: Vec::new(),
137                };
138                Some((key, field))
139            })
140            .collect();
141
142        FormState {
143            fields,
144            errors: Vec::new(),
145            has_errors: false,
146        }
147    }
148
149    /// Add an error on a form field
150    pub fn add_error_on_field(&mut self, field: K, error: FieldError) {
151        self.fields.entry(field).or_default().errors.push(error);
152        self.has_errors = true;
153    }
154
155    /// Add an error on a form field
156    #[must_use]
157    pub fn with_error_on_field(mut self, field: K, error: FieldError) -> Self {
158        self.add_error_on_field(field, error);
159        self
160    }
161
162    /// Add an error on the form
163    pub fn add_error_on_form(&mut self, error: FormError) {
164        self.errors.push(error);
165        self.has_errors = true;
166    }
167
168    /// Add an error on the form
169    #[must_use]
170    pub fn with_error_on_form(mut self, error: FormError) -> Self {
171        self.add_error_on_form(error);
172        self
173    }
174
175    /// Set a value on the form
176    pub fn set_value(&mut self, field: K, value: Option<String>) {
177        self.fields.entry(field).or_default().value = value;
178    }
179
180    /// Checks if a field contains a value
181    pub fn has_value(&self, field: K) -> bool {
182        self.fields.get(&field).is_some_and(|f| f.value.is_some())
183    }
184
185    /// Returns `true` if the form has no error attached to it
186    #[must_use]
187    pub fn is_valid(&self) -> bool {
188        !self.has_errors
189    }
190}
191
192/// Utility trait to help creating [`FormState`] out of a form
193pub trait ToFormState: Serialize {
194    /// The enum used for field names
195    type Field: FormField;
196
197    /// Generate a [`FormState`] out of [`Self`]
198    ///
199    /// # Panics
200    ///
201    /// If the form fails to serialize or [`Self::Field`] fails to deserialize
202    fn to_form_state(&self) -> FormState<Self::Field> {
203        FormState::from_form(&self)
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Serialize)]
    struct TestForm {
        foo: String,
        bar: String,
    }

    /// A form with a key (`extra`) that is not part of [`TestFormField`]
    #[derive(Serialize)]
    struct TestFormWithExtra {
        foo: String,
        extra: String,
    }

    #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
    #[serde(rename_all = "snake_case")]
    enum TestFormField {
        Foo,
        Bar,
    }

    impl FormField for TestFormField {
        fn keep(&self) -> bool {
            match self {
                Self::Foo => true,
                Self::Bar => false,
            }
        }
    }

    impl ToFormState for TestForm {
        type Field = TestFormField;
    }

    #[test]
    fn form_state_serialization() {
        let form = TestForm {
            foo: "john".to_owned(),
            bar: "hunter2".to_owned(),
        };

        let state = form.to_form_state();
        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [],
                "fields": {
                    "foo": {
                        "errors": [],
                        "value": "john",
                    },
                    "bar": {
                        "errors": [],
                        "value": null
                    },
                }
            })
        );

        let form = TestForm {
            foo: String::new(),
            bar: String::new(),
        };
        let state = form
            .to_form_state()
            .with_error_on_field(TestFormField::Foo, FieldError::Required)
            .with_error_on_field(TestFormField::Bar, FieldError::Required)
            .with_error_on_form(FormError::InvalidCredentials);

        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [{"kind": "invalid_credentials"}],
                "fields": {
                    "foo": {
                        "errors": [{"kind": "required"}],
                        "value": "",
                    },
                    "bar": {
                        "errors": [{"kind": "required"}],
                        "value": null
                    },
                }
            })
        );
    }

    #[test]
    fn unknown_keys_are_ignored() {
        let form = TestFormWithExtra {
            foo: "john".to_owned(),
            extra: "ignored".to_owned(),
        };

        let state: FormState<TestFormField> = FormState::from_form(&form);
        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [],
                "fields": {
                    "foo": {
                        "errors": [],
                        "value": "john",
                    },
                }
            })
        );
    }

    #[test]
    fn validity_and_values() {
        let form = TestForm {
            foo: "john".to_owned(),
            bar: "hunter2".to_owned(),
        };

        let mut state = form.to_form_state();
        assert!(state.is_valid());
        assert!(state.has_value(TestFormField::Foo));
        // `Bar` is not kept, so its submitted value was dropped
        assert!(!state.has_value(TestFormField::Bar));

        state.set_value(TestFormField::Bar, Some("value".to_owned()));
        assert!(state.has_value(TestFormField::Bar));

        state.add_error_on_field(TestFormField::Foo, FieldError::Invalid);
        assert!(!state.is_valid());

        let state = FormState::<TestFormField>::default().with_error_on_form(FormError::Internal);
        assert!(!state.is_valid());
    }

    #[test]
    fn policy_error_serialization() {
        let state = FormState::<TestFormField>::default()
            .with_error_on_field(
                TestFormField::Foo,
                FieldError::Policy {
                    code: Some("username-too-short"),
                    message: "Too short".to_owned(),
                },
            )
            .with_error_on_form(FormError::Policy {
                code: None,
                message: "Denied".to_owned(),
            });

        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [{"kind": "policy", "code": null, "message": "Denied"}],
                "fields": {
                    "foo": {
                        "errors": [{
                            "kind": "policy",
                            "code": "username-too-short",
                            "message": "Too short",
                        }],
                        "value": null,
                    },
                }
            })
        );
    }
}