oauth2_types/
scope.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! Types to define an [access token's scope].
8//!
9//! [access token's scope]: https://www.rfc-editor.org/rfc/rfc6749#section-3.3
10
11#![allow(clippy::module_name_repetitions)]
12
13use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
14
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17
18/// The error type returned when a scope is invalid.
19#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
20#[error("Invalid scope format")]
21pub struct InvalidScope;
22
/// A single scope token (a "scope value"), such as `openid` or `email`.
///
/// The text is stored as a `Cow<'static, str>`: well-known tokens can be
/// declared as `const`s that borrow a string literal, while tokens produced
/// by parsing own their data.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Builds a `ScopeToken` that borrows the given `'static` string.
    ///
    /// The string is **not** validated. This is a `const fn` so it can be
    /// used to define constants, and no check is performed. Callers should
    /// only pass non-empty strings made of `NQCHAR`s. Parse untrusted input
    /// through `FromStr` instead, which does validate.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        Self(Cow::Borrowed(token))
    }

    /// Returns the token's text as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
41
42/// `openid`.
43///
44/// Must be included in OpenID Connect requests.
45pub const OPENID: ScopeToken = ScopeToken::from_static("openid");
46
47/// `profile`.
48///
49/// Requests access to the End-User's default profile Claims.
50pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");
51
52/// `email`.
53///
54/// Requests access to the `email` and `email_verified` Claims.
55pub const EMAIL: ScopeToken = ScopeToken::from_static("email");
56
57/// `address`.
58///
59/// Requests access to the `address` Claim.
60pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");
61
62/// `phone`.
63///
64/// Requests access to the `phone_number` and `phone_number_verified` Claims.
65pub const PHONE: ScopeToken = ScopeToken::from_static("phone");
66
67/// `offline_access`.
68///
69/// Requests that an OAuth 2.0 Refresh Token be issued that can be used to
70/// obtain an Access Token that grants access to the End-User's Userinfo
71/// Endpoint even when the End-User is not present (not logged in).
72pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
73
/// Returns whether `c` is an `NQCHAR`, the character class scope tokens are
/// built from.
///
/// As per RFC 6749, appendix A
/// (<https://datatracker.ietf.org/doc/html/rfc6749#appendix-A>):
///
/// ```text
/// NQCHAR     = %x21 / %x23-5B / %x5D-7E
/// ```
///
/// That is every visible ASCII character except `"` (0x22) and `\` (0x5C).
/// ABNF value ranges are inclusive, so the range patterns below use `..=`.
/// This makes `[` (0x5B) and `~` (0x7E) valid.
fn nqchar(c: char) -> bool {
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
81
82impl FromStr for ScopeToken {
83    type Err = InvalidScope;
84
85    fn from_str(s: &str) -> Result<Self, Self::Err> {
86        // As per RFC6749 appendix A.4:
87        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
88        //
89        //    scope-token = 1*NQCHAR
90        if !s.is_empty() && s.chars().all(nqchar) {
91            Ok(ScopeToken(Cow::Owned(s.into())))
92        } else {
93            Err(InvalidScope)
94        }
95    }
96}
97
98impl Deref for ScopeToken {
99    type Target = str;
100
101    fn deref(&self) -> &Self::Target {
102        &self.0
103    }
104}
105
106impl std::fmt::Display for ScopeToken {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        self.0.fmt(f)
109    }
110}
111
112/// A scope.
113#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct Scope(BTreeSet<ScopeToken>);
115
116impl Deref for Scope {
117    type Target = BTreeSet<ScopeToken>;
118
119    fn deref(&self) -> &Self::Target {
120        &self.0
121    }
122}
123
124impl FromStr for Scope {
125    type Err = InvalidScope;
126
127    fn from_str(s: &str) -> Result<Self, Self::Err> {
128        // As per RFC6749 appendix A.4:
129        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
130        //
131        //    scope       = scope-token *( SP scope-token )
132        let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
133            s.split(' ').map(ScopeToken::from_str).collect();
134
135        Ok(Self(scopes?))
136    }
137}
138
139impl Scope {
140    /// Whether this `Scope` is empty.
141    #[must_use]
142    pub fn is_empty(&self) -> bool {
143        // This should never be the case?
144        self.0.is_empty()
145    }
146
147    /// The number of tokens in the `Scope`.
148    #[must_use]
149    pub fn len(&self) -> usize {
150        self.0.len()
151    }
152
153    /// Whether this `Scope` contains the given value.
154    #[must_use]
155    pub fn contains(&self, token: &str) -> bool {
156        ScopeToken::from_str(token)
157            .map(|token| self.0.contains(&token))
158            .unwrap_or(false)
159    }
160
161    /// Inserts the given token in this `Scope`.
162    ///
163    /// Returns whether the token was newly inserted.
164    pub fn insert(&mut self, value: ScopeToken) -> bool {
165        self.0.insert(value)
166    }
167}
168
169impl std::fmt::Display for Scope {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        for (index, token) in self.0.iter().enumerate() {
172            if index == 0 {
173                write!(f, "{token}")?;
174            } else {
175                write!(f, " {token}")?;
176            }
177        }
178
179        Ok(())
180    }
181}
182
183impl Serialize for Scope {
184    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185    where
186        S: serde::Serializer,
187    {
188        self.to_string().serialize(serializer)
189    }
190}
191
192impl<'de> Deserialize<'de> for Scope {
193    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
194    where
195        D: serde::Deserializer<'de>,
196    {
197        // FIXME: seems like there is an unnecessary clone here?
198        let scope: String = Deserialize::deserialize(deserializer)?;
199        Scope::from_str(&scope).map_err(serde::de::Error::custom)
200    }
201}
202
203impl FromIterator<ScopeToken> for Scope {
204    fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
205        Self(BTreeSet::from_iter(iter))
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        let accepted = ScopeToken::from_str("openid");
        assert_eq!(accepted, Ok(OPENID));

        let rejected = ScopeToken::from_str("invalid\\scope");
        assert_eq!(rejected, Err(InvalidScope));
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        for present in ["openid", "profile", "address"] {
            assert!(scope.contains(present));
        }
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        for malformed in [
            "invalid\\scope",
            "no  double space",
            " no leading space",
            "no trailing space ",
        ] {
            assert!(Scope::from_str(malformed).is_err());
        }

        let single = Scope::from_str("openid").unwrap();
        assert_eq!(single.len(), 1);
        assert!(single.contains("openid"));
        for absent in ["profile", "address"] {
            assert!(!single.contains(absent));
        }

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        for url_like in ["http://example.com", "urn:matrix:org.matrix.msc2967.client:*"] {
            assert!(Scope::from_str(url_like).is_ok());
        }
    }
}