mas_axum_utils/
language_detection.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::cmp::Reverse;
8
9use headers::{Error, Header};
10use http::{HeaderName, HeaderValue, header::ACCEPT_LANGUAGE};
11use icu_locid::Locale;
12
13#[derive(PartialEq, Eq, Debug)]
14struct AcceptLanguagePart {
15    // None means *
16    locale: Option<Locale>,
17
18    // Quality is between 0 and 1 with 3 decimal places
19    // Which we map from 0 to 1000, e.g. 0.5 becomes 500
20    quality: u16,
21}
22
23impl PartialOrd for AcceptLanguagePart {
24    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
25        Some(self.cmp(other))
26    }
27}
28
29impl Ord for AcceptLanguagePart {
30    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
31        // When comparing two AcceptLanguage structs, we only consider the
32        // quality, in reverse.
33        Reverse(self.quality).cmp(&Reverse(other.quality))
34    }
35}
36
37/// A header that represents the `Accept-Language` header.
38#[derive(PartialEq, Eq, Debug)]
39pub struct AcceptLanguage {
40    parts: Vec<AcceptLanguagePart>,
41}
42
43impl AcceptLanguage {
44    pub fn iter(&self) -> impl Iterator<Item = &Locale> {
45        // This should stop when we hit the first None, aka the first *
46        self.parts.iter().map_while(|item| item.locale.as_ref())
47    }
48}
49
/// Strips leading and trailing ASCII whitespace from a byte slice.
///
/// "Whitespace" follows [`u8::is_ascii_whitespace`]: space, tab, line feed,
/// form feed and carriage return. Interior bytes are left untouched, and the
/// result borrows from the input.
const fn trim_bytes(bytes: &[u8]) -> &[u8] {
    let mut trimmed = bytes;

    // Peel whitespace off the front...
    while let Some((head, tail)) = trimmed.split_first() {
        if !head.is_ascii_whitespace() {
            break;
        }
        trimmed = tail;
    }

    // ...then off the back.
    while let Some((last, init)) = trimmed.split_last() {
        if !last.is_ascii_whitespace() {
            break;
        }
        trimmed = init;
    }

    trimmed
}
71
72impl Header for AcceptLanguage {
73    fn name() -> &'static HeaderName {
74        &ACCEPT_LANGUAGE
75    }
76
77    fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
78    where
79        Self: Sized,
80        I: Iterator<Item = &'i HeaderValue>,
81    {
82        let mut parts = Vec::new();
83        for value in values {
84            for part in value.as_bytes().split(|b| *b == b',') {
85                let mut it = part.split(|b| *b == b';');
86                let locale = it.next().ok_or(Error::invalid())?;
87                let locale = trim_bytes(locale);
88
89                let locale = match locale {
90                    b"*" => None,
91                    locale => {
92                        let locale =
93                            Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?;
94                        Some(locale)
95                    }
96                };
97
98                let quality = if let Some(quality) = it.next() {
99                    let quality = trim_bytes(quality);
100                    let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?;
101                    let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?;
102                    let quality = quality.parse::<f64>().map_err(|_e| Error::invalid())?;
103                    // Bound the quality between 0 and 1
104                    let quality = quality.clamp(0_f64, 1_f64);
105
106                    // Make sure the iterator is empty
107                    if it.next().is_some() {
108                        return Err(Error::invalid());
109                    }
110
111                    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
112                    {
113                        f64::round(quality * 1000_f64) as u16
114                    }
115                } else {
116                    1000
117                };
118
119                parts.push(AcceptLanguagePart { locale, quality });
120            }
121        }
122
123        parts.sort();
124
125        Ok(AcceptLanguage { parts })
126    }
127
128    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
129        let mut value = String::new();
130        let mut first = true;
131        for part in &self.parts {
132            if first {
133                first = false;
134            } else {
135                value.push_str(", ");
136            }
137
138            if let Some(locale) = &part.locale {
139                value.push_str(&locale.to_string());
140            } else {
141                value.push('*');
142            }
143
144            if part.quality != 1000 {
145                value.push_str(";q=");
146                value.push_str(&(f64::from(part.quality) / 1000_f64).to_string());
147            }
148        }
149
150        // We know this is safe because we only use ASCII characters
151        values.extend(Some(HeaderValue::from_str(&value).unwrap()));
152    }
153}
154
#[cfg(test)]
mod tests {
    use headers::HeaderMapExt;
    use http::{HeaderMap, HeaderValue, header::ACCEPT_LANGUAGE};
    use icu_locid::{Locale, locale};

    use super::*;

    /// Shorthand for building an expected [`AcceptLanguagePart`].
    fn part(locale: Option<Locale>, quality: u16) -> AcceptLanguagePart {
        AcceptLanguagePart { locale, quality }
    }

    /// Decodes one raw `Accept-Language` value through the typed-header API,
    /// panicking if it does not decode.
    fn decode_header(raw: &str) -> AcceptLanguage {
        let headers =
            HeaderMap::from_iter([(ACCEPT_LANGUAGE, HeaderValue::from_str(raw).unwrap())]);
        headers
            .typed_get()
            .expect("the Accept-Language header should decode")
    }

    #[test]
    fn test_decode() {
        let decoded = decode_header("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5");

        let expected = AcceptLanguage {
            parts: vec![
                part(Some(locale!("fr-CH")), 1000),
                part(Some(locale!("fr")), 900),
                part(Some(locale!("en")), 800),
                part(Some(locale!("de")), 700),
                part(None, 500),
            ],
        };
        assert_eq!(decoded, expected);
    }

    #[test]
    /// An unordered header is decoded into entries ordered by quality, with
    /// equal-quality entries kept in their original order.
    fn test_decode_order() {
        let decoded = decode_header("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9");

        let expected = AcceptLanguage {
            parts: vec![
                part(Some(locale!("fr-CH")), 1000),
                part(Some(locale!("fr")), 900),
                part(Some(locale!("de")), 900),
                part(Some(locale!("en")), 800),
                part(None, 500),
            ],
        };
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_encode() {
        let accept_language = AcceptLanguage {
            parts: vec![
                part(Some(locale!("fr-CH")), 1000),
                part(Some(locale!("fr")), 900),
                part(Some(locale!("de")), 900),
                part(Some(locale!("en")), 800),
                part(None, 500),
            ],
        };

        let mut headers = HeaderMap::new();
        headers.typed_insert(accept_language);
        let encoded = headers
            .get(ACCEPT_LANGUAGE)
            .expect("typed_insert should set the header");
        assert_eq!(
            encoded.to_str().unwrap(),
            "fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5"
        );
    }
}