1#![allow(clippy::module_name_repetitions)]
12
13use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
14
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17
18#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
20#[error("Invalid scope format")]
21pub struct InvalidScope;
22
/// A single scope token, such as `openid`.
///
/// Tokens created with [`ScopeToken::from_static`] borrow a `'static` string,
/// while tokens parsed at runtime own their text. Equality, ordering and
/// hashing only look at the text, so both kinds compare equal when their text
/// matches.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Wraps a `'static` string as a token without validating it.
    ///
    /// No check is performed so that this can remain a `const fn`; the caller
    /// is responsible for passing a well-formed token.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        ScopeToken(Cow::Borrowed(token))
    }

    /// Returns the text of the token.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// The `openid` scope token.
pub const OPENID: ScopeToken = ScopeToken::from_static("openid");

/// The `profile` scope token.
pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");

/// The `email` scope token.
pub const EMAIL: ScopeToken = ScopeToken::from_static("email");

/// The `address` scope token.
pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");

/// The `phone` scope token.
pub const PHONE: ScopeToken = ScopeToken::from_static("phone");

/// The `offline_access` scope token.
pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
73
/// Returns whether `c` is an `NQCHAR` from the RFC 6749 grammar:
/// `NQCHAR = %x21 / %x23-5B / %x5D-7E`.
///
/// That covers every printable ASCII character except space, `"` (0x22) and
/// `\` (0x5C). ABNF ranges are inclusive at both ends, so `[` (0x5B) and
/// `~` (0x7E) are accepted; inclusive `..=` patterns express exactly that.
fn nqchar(c: char) -> bool {
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
81
82impl FromStr for ScopeToken {
83 type Err = InvalidScope;
84
85 fn from_str(s: &str) -> Result<Self, Self::Err> {
86 if !s.is_empty() && s.chars().all(nqchar) {
91 Ok(ScopeToken(Cow::Owned(s.into())))
92 } else {
93 Err(InvalidScope)
94 }
95 }
96}
97
98impl Deref for ScopeToken {
99 type Target = str;
100
101 fn deref(&self) -> &Self::Target {
102 &self.0
103 }
104}
105
106impl std::fmt::Display for ScopeToken {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 self.0.fmt(f)
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct Scope(BTreeSet<ScopeToken>);
115
116impl Deref for Scope {
117 type Target = BTreeSet<ScopeToken>;
118
119 fn deref(&self) -> &Self::Target {
120 &self.0
121 }
122}
123
124impl FromStr for Scope {
125 type Err = InvalidScope;
126
127 fn from_str(s: &str) -> Result<Self, Self::Err> {
128 let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
133 s.split(' ').map(ScopeToken::from_str).collect();
134
135 Ok(Self(scopes?))
136 }
137}
138
139impl Scope {
140 #[must_use]
142 pub fn is_empty(&self) -> bool {
143 self.0.is_empty()
145 }
146
147 #[must_use]
149 pub fn len(&self) -> usize {
150 self.0.len()
151 }
152
153 #[must_use]
155 pub fn contains(&self, token: &str) -> bool {
156 ScopeToken::from_str(token)
157 .map(|token| self.0.contains(&token))
158 .unwrap_or(false)
159 }
160
161 pub fn insert(&mut self, value: ScopeToken) -> bool {
165 self.0.insert(value)
166 }
167}
168
169impl std::fmt::Display for Scope {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 for (index, token) in self.0.iter().enumerate() {
172 if index == 0 {
173 write!(f, "{token}")?;
174 } else {
175 write!(f, " {token}")?;
176 }
177 }
178
179 Ok(())
180 }
181}
182
183impl Serialize for Scope {
184 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185 where
186 S: serde::Serializer,
187 {
188 self.to_string().serialize(serializer)
189 }
190}
191
192impl<'de> Deserialize<'de> for Scope {
193 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
194 where
195 D: serde::Deserializer<'de>,
196 {
197 let scope: String = Deserialize::deserialize(deserializer)?;
199 Scope::from_str(&scope).map_err(serde::de::Error::custom)
200 }
201}
202
203impl FromIterator<ScopeToken> for Scope {
204 fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
205 Self(BTreeSet::from_iter(iter))
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));

        assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
        // A token must contain at least one character.
        assert_eq!(ScopeToken::from_str(""), Err(InvalidScope));
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        assert!(scope.contains("openid"));
        assert!(scope.contains("profile"));
        assert!(scope.contains("address"));
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        assert!(Scope::from_str("invalid\\scope").is_err());
        // Tokens are separated by exactly one space: consecutive, leading or
        // trailing spaces produce an empty token, which is rejected. The
        // double-space literal must really contain two consecutive spaces;
        // with single spaces it is a valid three-token scope.
        assert!(Scope::from_str("no  double  space").is_err());
        assert!(Scope::from_str(" no leading space").is_err());
        assert!(Scope::from_str("no trailing space ").is_err());

        let scope = Scope::from_str("openid").unwrap();
        assert_eq!(scope.len(), 1);
        assert!(scope.contains("openid"));
        assert!(!scope.contains("profile"));
        assert!(!scope.contains("address"));

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        assert!(Scope::from_str("http://example.com").is_ok());
        assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:*").is_ok());
    }
}