oauth2_types/
response_type.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! [Response types] in the OpenID Connect specification.
8//!
9//! [Response types]: https://openid.net/specs/openid-connect-core-1_0.html#Authentication
10
11#![allow(clippy::module_name_repetitions)]
12
13use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
14
15use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
16use serde_with::{DeserializeFromStr, SerializeDisplay};
17use thiserror::Error;
18
19/// An error encountered when trying to parse an invalid [`ResponseType`].
20#[derive(Debug, Error, Clone, PartialEq, Eq)]
21#[error("invalid response type")]
22pub struct InvalidResponseType;
23
24/// The accepted tokens in a [`ResponseType`].
25///
26/// `none` is not in this enum because it is represented by an empty
27/// [`ResponseType`].
28///
29/// This type also accepts unknown tokens that can be constructed via it's
30/// `FromStr` implementation or used via its `Display` implementation.
31#[derive(
32    Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
33)]
34#[non_exhaustive]
35pub enum ResponseTypeToken {
36    /// `code`
37    Code,
38
39    /// `id_token`
40    IdToken,
41
42    /// `token`
43    Token,
44
45    /// Unknown token.
46    Unknown(String),
47}
48
49impl core::fmt::Display for ResponseTypeToken {
50    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51        match self {
52            ResponseTypeToken::Code => f.write_str("code"),
53            ResponseTypeToken::IdToken => f.write_str("id_token"),
54            ResponseTypeToken::Token => f.write_str("token"),
55            ResponseTypeToken::Unknown(s) => f.write_str(s),
56        }
57    }
58}
59
60impl core::str::FromStr for ResponseTypeToken {
61    type Err = core::convert::Infallible;
62
63    fn from_str(s: &str) -> Result<Self, Self::Err> {
64        match s {
65            "code" => Ok(Self::Code),
66            "id_token" => Ok(Self::IdToken),
67            "token" => Ok(Self::Token),
68            s => Ok(Self::Unknown(s.to_owned())),
69        }
70    }
71}
72
73/// An [OAuth 2.0 `response_type` value] that the client can use
74/// at the [authorization endpoint].
75///
76/// It is recommended to construct this type from an
77/// [`OAuthAuthorizationEndpointResponseType`].
78///
79/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9
80/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
81#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr, PartialOrd, Ord)]
82pub struct ResponseType(BTreeSet<ResponseTypeToken>);
83
84impl std::ops::Deref for ResponseType {
85    type Target = BTreeSet<ResponseTypeToken>;
86
87    fn deref(&self) -> &Self::Target {
88        &self.0
89    }
90}
91
92impl ResponseType {
93    /// Whether this response type requests a code.
94    #[must_use]
95    pub fn has_code(&self) -> bool {
96        self.0.contains(&ResponseTypeToken::Code)
97    }
98
99    /// Whether this response type requests an ID token.
100    #[must_use]
101    pub fn has_id_token(&self) -> bool {
102        self.0.contains(&ResponseTypeToken::IdToken)
103    }
104
105    /// Whether this response type requests a token.
106    #[must_use]
107    pub fn has_token(&self) -> bool {
108        self.0.contains(&ResponseTypeToken::Token)
109    }
110}
111
112impl FromStr for ResponseType {
113    type Err = InvalidResponseType;
114
115    fn from_str(s: &str) -> Result<Self, Self::Err> {
116        let s = s.trim();
117
118        if s.is_empty() {
119            Err(InvalidResponseType)
120        } else if s == "none" {
121            Ok(Self(BTreeSet::new()))
122        } else {
123            s.split_ascii_whitespace()
124                .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
125                .collect::<Result<_, _>>()
126        }
127    }
128}
129
130impl fmt::Display for ResponseType {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        let mut iter = self.iter();
133
134        // First item shouldn't have a leading space
135        if let Some(first) = iter.next() {
136            first.fmt(f)?;
137        } else {
138            // If the whole iterator is empty, write 'none' instead
139            write!(f, "none")?;
140            return Ok(());
141        }
142
143        // Write the other items with a leading space
144        for item in iter {
145            write!(f, " {item}")?;
146        }
147
148        Ok(())
149    }
150}
151
152impl FromIterator<ResponseTypeToken> for ResponseType {
153    fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
154        Self(BTreeSet::from_iter(iter))
155    }
156}
157
158impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
159    fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
160        match response_type {
161            OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
162            OAuthAuthorizationEndpointResponseType::CodeIdToken => {
163                Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
164            }
165            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
166                [
167                    ResponseTypeToken::Code,
168                    ResponseTypeToken::IdToken,
169                    ResponseTypeToken::Token,
170                ]
171                .into(),
172            ),
173            OAuthAuthorizationEndpointResponseType::CodeToken => {
174                Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
175            }
176            OAuthAuthorizationEndpointResponseType::IdToken => {
177                Self([ResponseTypeToken::IdToken].into())
178            }
179            OAuthAuthorizationEndpointResponseType::IdTokenToken => {
180                Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
181            }
182            OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
183            OAuthAuthorizationEndpointResponseType::Token => {
184                Self([ResponseTypeToken::Token].into())
185            }
186        }
187    }
188}
189
190impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
191    type Error = InvalidResponseType;
192
193    fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
194        if response_type
195            .iter()
196            .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
197        {
198            return Err(InvalidResponseType);
199        }
200
201        let tokens = response_type.iter().collect::<Vec<_>>();
202        let res = match *tokens {
203            [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
204            [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
205            [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
206            [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
207                OAuthAuthorizationEndpointResponseType::CodeIdToken
208            }
209            [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
210                OAuthAuthorizationEndpointResponseType::CodeToken
211            }
212            [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
213                OAuthAuthorizationEndpointResponseType::IdTokenToken
214            }
215            [
216                ResponseTypeToken::Code,
217                ResponseTypeToken::IdToken,
218                ResponseTypeToken::Token,
219            ] => OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
220            _ => OAuthAuthorizationEndpointResponseType::None,
221        };
222
223        Ok(res)
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` (a JSON string literal) as a [`ResponseType`] and
    /// checks two things:
    ///
    /// - it contains exactly `tokens`, in iteration order;
    /// - converting it to an [`OAuthAuthorizationEndpointResponseType`]
    ///   yields `expected`.
    #[track_caller]
    fn assert_response_type(
        json: &str,
        tokens: &[ResponseTypeToken],
        expected: Result<OAuthAuthorizationEndpointResponseType, InvalidResponseType>,
    ) {
        let res_type = serde_json::from_str::<ResponseType>(json).unwrap();
        assert_eq!(
            res_type.iter().collect::<Vec<_>>(),
            tokens.iter().collect::<Vec<_>>(),
            "unexpected tokens when parsing {json}"
        );
        assert_eq!(
            OAuthAuthorizationEndpointResponseType::try_from(res_type),
            expected,
            "unexpected conversion result for {json}"
        );
    }

    #[test]
    fn deserialize_response_type_token() {
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
            ResponseTypeToken::Code
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
            ResponseTypeToken::IdToken
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
            ResponseTypeToken::Token
        );
        assert_eq!(
            serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
            ResponseTypeToken::Unknown("something_unsupported".to_owned())
        );
    }

    #[test]
    fn serialize_response_type_token() {
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
            "\"code\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
            "\"id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
            "\"token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseTypeToken::Unknown(
                "something_unsupported".to_owned()
            ))
            .unwrap(),
            "\"something_unsupported\""
        );
    }

    #[test]
    fn deserialize_response_type() {
        use OAuthAuthorizationEndpointResponseType as Iana;
        use ResponseTypeToken::{Code, IdToken, Token, Unknown};

        // Empty and whitespace-only values are rejected outright
        serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
        serde_json::from_str::<ResponseType>("\"   \"").unwrap_err();

        assert_response_type("\"none\"", &[], Ok(Iana::None));
        assert_response_type("\"code\"", &[Code], Ok(Iana::Code));
        // Surrounding whitespace is ignored
        assert_response_type("\"  code  \"", &[Code], Ok(Iana::Code));
        assert_response_type("\"id_token\"", &[IdToken], Ok(Iana::IdToken));
        assert_response_type("\"token\"", &[Token], Ok(Iana::Token));
        assert_response_type("\"code id_token\"", &[Code, IdToken], Ok(Iana::CodeIdToken));
        assert_response_type("\"code token\"", &[Code, Token], Ok(Iana::CodeToken));
        assert_response_type("\"id_token token\"", &[IdToken, Token], Ok(Iana::IdTokenToken));
        assert_response_type(
            "\"code id_token token\"",
            &[Code, IdToken, Token],
            Ok(Iana::CodeIdTokenToken),
        );
        // Repeated separators are treated like a single one
        assert_response_type("\"code   id_token\"", &[Code, IdToken], Ok(Iana::CodeIdToken));

        // Unknown tokens are kept, but cannot be converted to the IANA enum
        let unsupported = Unknown("something_unsupported".to_owned());
        assert_response_type(
            "\"something_unsupported\"",
            &[unsupported.clone()],
            Err(InvalidResponseType),
        );
        assert_response_type(
            "\"code id_token token something_unsupported\"",
            &[Code, IdToken, Token, unsupported],
            Err(InvalidResponseType),
        );

        // Order doesn't matter, and duplicated tokens are collapsed
        assert_response_type(
            "\"token code id_token\"",
            &[Code, IdToken, Token],
            Ok(Iana::CodeIdTokenToken),
        );
        assert_response_type(
            "\"id_token token id_token code\"",
            &[Code, IdToken, Token],
            Ok(Iana::CodeIdTokenToken),
        );
    }

    #[test]
    fn serialize_response_type() {
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::None
            ))
            .unwrap(),
            "\"none\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::Code
            ))
            .unwrap(),
            "\"code\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::IdToken
            ))
            .unwrap(),
            "\"id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeIdToken
            ))
            .unwrap(),
            "\"code id_token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeToken
            ))
            .unwrap(),
            "\"code token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::IdTokenToken
            ))
            .unwrap(),
            "\"id_token token\""
        );
        assert_eq!(
            serde_json::to_string(&ResponseType::from(
                OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
            ))
            .unwrap(),
            "\"code id_token token\""
        );

        assert_eq!(
            serde_json::to_string(
                &[
                    ResponseTypeToken::Unknown("something_unsupported".to_owned()),
                    ResponseTypeToken::Code
                ]
                .into_iter()
                .collect::<ResponseType>()
            )
            .unwrap(),
            "\"code something_unsupported\""
        );

        // Order doesn't matter.
        let res = [
            ResponseTypeToken::IdToken,
            ResponseTypeToken::Token,
            ResponseTypeToken::Code,
        ]
        .into_iter()
        .collect::<ResponseType>();
        assert_eq!(
            serde_json::to_string(&res).unwrap(),
            "\"code id_token token\""
        );

        let res = [
            ResponseTypeToken::Code,
            ResponseTypeToken::Token,
            ResponseTypeToken::IdToken,
        ]
        .into_iter()
        .collect::<ResponseType>();
        assert_eq!(
            serde_json::to_string(&res).unwrap(),
            "\"code id_token token\""
        );
    }

    #[test]
    fn response_type_display_round_trips() {
        // Values already in canonical order must survive parse + display unchanged
        for value in [
            "none",
            "code",
            "id_token token",
            "code id_token token",
            "code something_unsupported",
        ] {
            let parsed = value.parse::<ResponseType>().unwrap();
            assert_eq!(parsed.to_string(), value);
        }
    }
}