1use std::{collections::HashMap, hash::Hash};
8
9use serde::{Deserialize, Serialize};
10
11pub trait FormField: Copy + Hash + PartialEq + Eq + Serialize + for<'de> Deserialize<'de> {
13 fn keep(&self) -> bool;
16}
17
18#[derive(Debug, Serialize)]
20#[serde(rename_all = "snake_case", tag = "kind")]
21pub enum FieldError {
22 Required,
24
25 Unspecified,
27
28 Invalid,
30
31 PasswordMismatch,
33
34 Exists,
36
37 Policy {
39 code: Option<&'static str>,
41
42 message: String,
44 },
45}
46
47#[derive(Debug, Serialize)]
49#[serde(rename_all = "snake_case", tag = "kind")]
50pub enum FormError {
51 InvalidCredentials,
53
54 PasswordMismatch,
56
57 Internal,
59
60 RateLimitExceeded,
62
63 Policy {
65 code: Option<&'static str>,
67
68 message: String,
70 },
71
72 Captcha,
74}
75
76#[derive(Debug, Default, Serialize)]
77struct FieldState {
78 value: Option<String>,
79 errors: Vec<FieldError>,
80}
81
82#[derive(Debug, Serialize)]
84pub struct FormState<K: Hash + Eq> {
85 fields: HashMap<K, FieldState>,
86 errors: Vec<FormError>,
87
88 #[serde(skip)]
89 has_errors: bool,
90}
91
92impl<K: Hash + Eq> Default for FormState<K> {
93 fn default() -> Self {
94 FormState {
95 fields: HashMap::default(),
96 errors: Vec::default(),
97 has_errors: false,
98 }
99 }
100}
101
102#[derive(Deserialize, PartialEq, Eq, Hash)]
103#[serde(untagged)]
104enum KeyOrOther<K> {
105 Key(K),
106 Other(String),
107}
108
109impl<K> KeyOrOther<K> {
110 fn key(self) -> Option<K> {
111 match self {
112 Self::Key(key) => Some(key),
113 Self::Other(_) => None,
114 }
115 }
116}
117
118impl<K: FormField> FormState<K> {
119 pub fn from_form<F: Serialize>(form: &F) -> Self {
126 let form = serde_json::to_value(form).unwrap();
127 let fields: HashMap<KeyOrOther<K>, Option<String>> = serde_json::from_value(form).unwrap();
128
129 let fields = fields
130 .into_iter()
131 .filter_map(|(key, value)| {
132 let key = key.key()?;
133 let value = key.keep().then_some(value).flatten();
134 let field = FieldState {
135 value,
136 errors: Vec::new(),
137 };
138 Some((key, field))
139 })
140 .collect();
141
142 FormState {
143 fields,
144 errors: Vec::new(),
145 has_errors: false,
146 }
147 }
148
149 pub fn add_error_on_field(&mut self, field: K, error: FieldError) {
151 self.fields.entry(field).or_default().errors.push(error);
152 self.has_errors = true;
153 }
154
155 #[must_use]
157 pub fn with_error_on_field(mut self, field: K, error: FieldError) -> Self {
158 self.add_error_on_field(field, error);
159 self
160 }
161
162 pub fn add_error_on_form(&mut self, error: FormError) {
164 self.errors.push(error);
165 self.has_errors = true;
166 }
167
168 #[must_use]
170 pub fn with_error_on_form(mut self, error: FormError) -> Self {
171 self.add_error_on_form(error);
172 self
173 }
174
175 pub fn set_value(&mut self, field: K, value: Option<String>) {
177 self.fields.entry(field).or_default().value = value;
178 }
179
180 pub fn has_value(&self, field: K) -> bool {
182 self.fields.get(&field).is_some_and(|f| f.value.is_some())
183 }
184
185 #[must_use]
187 pub fn is_valid(&self) -> bool {
188 !self.has_errors
189 }
190}
191
192pub trait ToFormState: Serialize {
194 type Field: FormField;
196
197 fn to_form_state(&self) -> FormState<Self::Field> {
203 FormState::from_form(&self)
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// A form with one kept field (`foo`) and one discarded field (`bar`)
    #[derive(Serialize)]
    struct TestForm {
        foo: String,
        bar: String,
    }

    #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
    #[serde(rename_all = "snake_case")]
    enum TestFormField {
        Foo,
        Bar,
    }

    impl FormField for TestFormField {
        fn keep(&self) -> bool {
            match self {
                Self::Foo => true,
                Self::Bar => false,
            }
        }
    }

    impl ToFormState for TestForm {
        type Field = TestFormField;
    }

    /// A form carrying a field (`extra`) that is not a `TestFormField`
    #[derive(Serialize)]
    struct FormWithExtraField {
        foo: String,
        extra: String,
    }

    impl ToFormState for FormWithExtraField {
        type Field = TestFormField;
    }

    #[test]
    fn form_state_serialization() {
        let form = TestForm {
            foo: "john".to_owned(),
            bar: "hunter2".to_owned(),
        };

        let state = form.to_form_state();
        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [],
                "fields": {
                    "foo": {
                        "errors": [],
                        "value": "john",
                    },
                    "bar": {
                        "errors": [],
                        "value": null
                    },
                }
            })
        );

        let form = TestForm {
            foo: String::new(),
            bar: String::new(),
        };
        let state = form
            .to_form_state()
            .with_error_on_field(TestFormField::Foo, FieldError::Required)
            .with_error_on_field(TestFormField::Bar, FieldError::Required)
            .with_error_on_form(FormError::InvalidCredentials);

        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [{"kind": "invalid_credentials"}],
                "fields": {
                    "foo": {
                        "errors": [{"kind": "required"}],
                        "value": "",
                    },
                    "bar": {
                        "errors": [{"kind": "required"}],
                        "value": null
                    },
                }
            })
        );
    }

    #[test]
    fn form_state_validity_and_values() {
        let form = TestForm {
            foo: "john".to_owned(),
            bar: "hunter2".to_owned(),
        };
        let mut state = form.to_form_state();
        assert!(state.is_valid());
        assert!(state.has_value(TestFormField::Foo));
        // `Bar` is not kept, so its submitted value was discarded
        assert!(!state.has_value(TestFormField::Bar));

        state.set_value(TestFormField::Bar, Some("visible".to_owned()));
        assert!(state.has_value(TestFormField::Bar));
        // Setting a value must not affect validity
        assert!(state.is_valid());

        state.set_value(TestFormField::Foo, None);
        assert!(!state.has_value(TestFormField::Foo));

        state.add_error_on_form(FormError::Internal);
        assert!(!state.is_valid());

        let form = TestForm {
            foo: String::new(),
            bar: String::new(),
        };
        let state = form
            .to_form_state()
            .with_error_on_field(TestFormField::Foo, FieldError::Invalid);
        assert!(!state.is_valid());
    }

    #[test]
    fn unknown_fields_are_ignored() {
        let form = FormWithExtraField {
            foo: "john".to_owned(),
            extra: "ignored".to_owned(),
        };
        let mut state = form.to_form_state();
        // `Bar` was not part of the serialized form at all
        assert!(!state.has_value(TestFormField::Bar));

        // Adding an error on an absent field creates it without a value
        state.add_error_on_field(TestFormField::Bar, FieldError::Required);
        assert!(!state.has_value(TestFormField::Bar));
        assert!(!state.is_valid());

        let state = serde_json::to_value(state).unwrap();
        assert_eq!(
            state,
            serde_json::json!({
                "errors": [],
                "fields": {
                    "foo": {
                        "errors": [],
                        "value": "john",
                    },
                    "bar": {
                        "errors": [{"kind": "required"}],
                        "value": null
                    },
                }
            })
        );
    }
}