1#![allow(clippy::module_name_repetitions)]
12
13use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
14
15use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
16use serde_with::{DeserializeFromStr, SerializeDisplay};
17use thiserror::Error;
18
19#[derive(Debug, Error, Clone, PartialEq, Eq)]
21#[error("invalid response type")]
22pub struct InvalidResponseType;
23
24#[derive(
32 Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
33)]
34#[non_exhaustive]
35pub enum ResponseTypeToken {
36 Code,
38
39 IdToken,
41
42 Token,
44
45 Unknown(String),
47}
48
49impl core::fmt::Display for ResponseTypeToken {
50 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51 match self {
52 ResponseTypeToken::Code => f.write_str("code"),
53 ResponseTypeToken::IdToken => f.write_str("id_token"),
54 ResponseTypeToken::Token => f.write_str("token"),
55 ResponseTypeToken::Unknown(s) => f.write_str(s),
56 }
57 }
58}
59
60impl core::str::FromStr for ResponseTypeToken {
61 type Err = core::convert::Infallible;
62
63 fn from_str(s: &str) -> Result<Self, Self::Err> {
64 match s {
65 "code" => Ok(Self::Code),
66 "id_token" => Ok(Self::IdToken),
67 "token" => Ok(Self::Token),
68 s => Ok(Self::Unknown(s.to_owned())),
69 }
70 }
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr, PartialOrd, Ord)]
82pub struct ResponseType(BTreeSet<ResponseTypeToken>);
83
84impl std::ops::Deref for ResponseType {
85 type Target = BTreeSet<ResponseTypeToken>;
86
87 fn deref(&self) -> &Self::Target {
88 &self.0
89 }
90}
91
92impl ResponseType {
93 #[must_use]
95 pub fn has_code(&self) -> bool {
96 self.0.contains(&ResponseTypeToken::Code)
97 }
98
99 #[must_use]
101 pub fn has_id_token(&self) -> bool {
102 self.0.contains(&ResponseTypeToken::IdToken)
103 }
104
105 #[must_use]
107 pub fn has_token(&self) -> bool {
108 self.0.contains(&ResponseTypeToken::Token)
109 }
110}
111
112impl FromStr for ResponseType {
113 type Err = InvalidResponseType;
114
115 fn from_str(s: &str) -> Result<Self, Self::Err> {
116 let s = s.trim();
117
118 if s.is_empty() {
119 Err(InvalidResponseType)
120 } else if s == "none" {
121 Ok(Self(BTreeSet::new()))
122 } else {
123 s.split_ascii_whitespace()
124 .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
125 .collect::<Result<_, _>>()
126 }
127 }
128}
129
130impl fmt::Display for ResponseType {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 let mut iter = self.iter();
133
134 if let Some(first) = iter.next() {
136 first.fmt(f)?;
137 } else {
138 write!(f, "none")?;
140 return Ok(());
141 }
142
143 for item in iter {
145 write!(f, " {item}")?;
146 }
147
148 Ok(())
149 }
150}
151
152impl FromIterator<ResponseTypeToken> for ResponseType {
153 fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
154 Self(BTreeSet::from_iter(iter))
155 }
156}
157
158impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
159 fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
160 match response_type {
161 OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
162 OAuthAuthorizationEndpointResponseType::CodeIdToken => {
163 Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
164 }
165 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
166 [
167 ResponseTypeToken::Code,
168 ResponseTypeToken::IdToken,
169 ResponseTypeToken::Token,
170 ]
171 .into(),
172 ),
173 OAuthAuthorizationEndpointResponseType::CodeToken => {
174 Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
175 }
176 OAuthAuthorizationEndpointResponseType::IdToken => {
177 Self([ResponseTypeToken::IdToken].into())
178 }
179 OAuthAuthorizationEndpointResponseType::IdTokenToken => {
180 Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
181 }
182 OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
183 OAuthAuthorizationEndpointResponseType::Token => {
184 Self([ResponseTypeToken::Token].into())
185 }
186 }
187 }
188}
189
190impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
191 type Error = InvalidResponseType;
192
193 fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
194 if response_type
195 .iter()
196 .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
197 {
198 return Err(InvalidResponseType);
199 }
200
201 let tokens = response_type.iter().collect::<Vec<_>>();
202 let res = match *tokens {
203 [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
204 [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
205 [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
206 [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
207 OAuthAuthorizationEndpointResponseType::CodeIdToken
208 }
209 [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
210 OAuthAuthorizationEndpointResponseType::CodeToken
211 }
212 [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
213 OAuthAuthorizationEndpointResponseType::IdTokenToken
214 }
215 [
216 ResponseTypeToken::Code,
217 ResponseTypeToken::IdToken,
218 ResponseTypeToken::Token,
219 ] => OAuthAuthorizationEndpointResponseType::CodeIdTokenToken,
220 _ => OAuthAuthorizationEndpointResponseType::None,
221 };
222
223 Ok(res)
224 }
225}
226
227#[cfg(test)]
228mod tests {
229 use super::*;
230
231 #[test]
232 fn deserialize_response_type_token() {
233 assert_eq!(
234 serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
235 ResponseTypeToken::Code
236 );
237 assert_eq!(
238 serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
239 ResponseTypeToken::IdToken
240 );
241 assert_eq!(
242 serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
243 ResponseTypeToken::Token
244 );
245 assert_eq!(
246 serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
247 ResponseTypeToken::Unknown("something_unsupported".to_owned())
248 );
249 }
250
251 #[test]
252 fn serialize_response_type_token() {
253 assert_eq!(
254 serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
255 "\"code\""
256 );
257 assert_eq!(
258 serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
259 "\"id_token\""
260 );
261 assert_eq!(
262 serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
263 "\"token\""
264 );
265 assert_eq!(
266 serde_json::to_string(&ResponseTypeToken::Unknown(
267 "something_unsupported".to_owned()
268 ))
269 .unwrap(),
270 "\"something_unsupported\""
271 );
272 }
273
274 #[test]
275 #[allow(clippy::too_many_lines)]
276 fn deserialize_response_type() {
277 serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
278
279 let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
280 let mut iter = res_type.iter();
281 assert_eq!(iter.next(), None);
282 assert_eq!(
283 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
284 OAuthAuthorizationEndpointResponseType::None
285 );
286
287 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
288 let mut iter = res_type.iter();
289 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
290 assert_eq!(iter.next(), None);
291 assert_eq!(
292 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
293 OAuthAuthorizationEndpointResponseType::Code
294 );
295
296 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
297 let mut iter = res_type.iter();
298 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
299 assert_eq!(iter.next(), None);
300 assert_eq!(
301 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
302 OAuthAuthorizationEndpointResponseType::Code
303 );
304
305 let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
306 let mut iter = res_type.iter();
307 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
308 assert_eq!(iter.next(), None);
309 assert_eq!(
310 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
311 OAuthAuthorizationEndpointResponseType::IdToken
312 );
313
314 let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
315 let mut iter = res_type.iter();
316 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
317 assert_eq!(iter.next(), None);
318 assert_eq!(
319 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
320 OAuthAuthorizationEndpointResponseType::Token
321 );
322
323 let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
324 let mut iter = res_type.iter();
325 assert_eq!(
326 iter.next(),
327 Some(&ResponseTypeToken::Unknown(
328 "something_unsupported".to_owned()
329 ))
330 );
331 assert_eq!(iter.next(), None);
332 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
333
334 let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
335 let mut iter = res_type.iter();
336 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
337 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
338 assert_eq!(iter.next(), None);
339 assert_eq!(
340 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
341 OAuthAuthorizationEndpointResponseType::CodeIdToken
342 );
343
344 let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
345 let mut iter = res_type.iter();
346 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
347 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
348 assert_eq!(iter.next(), None);
349 assert_eq!(
350 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
351 OAuthAuthorizationEndpointResponseType::CodeToken
352 );
353
354 let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
355 let mut iter = res_type.iter();
356 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
357 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
358 assert_eq!(iter.next(), None);
359 assert_eq!(
360 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
361 OAuthAuthorizationEndpointResponseType::IdTokenToken
362 );
363
364 let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
365 let mut iter = res_type.iter();
366 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
367 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
368 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
369 assert_eq!(iter.next(), None);
370 assert_eq!(
371 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
372 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
373 );
374
375 let res_type =
376 serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
377 .unwrap();
378 let mut iter = res_type.iter();
379 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
380 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
381 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
382 assert_eq!(
383 iter.next(),
384 Some(&ResponseTypeToken::Unknown(
385 "something_unsupported".to_owned()
386 ))
387 );
388 assert_eq!(iter.next(), None);
389 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
390
391 let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
393 let mut iter = res_type.iter();
394 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
395 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
396 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
397 assert_eq!(iter.next(), None);
398 assert_eq!(
399 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
400 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
401 );
402
403 let res_type =
404 serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
405 let mut iter = res_type.iter();
406 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
407 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
408 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
409 assert_eq!(iter.next(), None);
410 assert_eq!(
411 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
412 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
413 );
414 }
415
416 #[test]
417 fn serialize_response_type() {
418 assert_eq!(
419 serde_json::to_string(&ResponseType::from(
420 OAuthAuthorizationEndpointResponseType::None
421 ))
422 .unwrap(),
423 "\"none\""
424 );
425 assert_eq!(
426 serde_json::to_string(&ResponseType::from(
427 OAuthAuthorizationEndpointResponseType::Code
428 ))
429 .unwrap(),
430 "\"code\""
431 );
432 assert_eq!(
433 serde_json::to_string(&ResponseType::from(
434 OAuthAuthorizationEndpointResponseType::IdToken
435 ))
436 .unwrap(),
437 "\"id_token\""
438 );
439 assert_eq!(
440 serde_json::to_string(&ResponseType::from(
441 OAuthAuthorizationEndpointResponseType::CodeIdToken
442 ))
443 .unwrap(),
444 "\"code id_token\""
445 );
446 assert_eq!(
447 serde_json::to_string(&ResponseType::from(
448 OAuthAuthorizationEndpointResponseType::CodeToken
449 ))
450 .unwrap(),
451 "\"code token\""
452 );
453 assert_eq!(
454 serde_json::to_string(&ResponseType::from(
455 OAuthAuthorizationEndpointResponseType::IdTokenToken
456 ))
457 .unwrap(),
458 "\"id_token token\""
459 );
460 assert_eq!(
461 serde_json::to_string(&ResponseType::from(
462 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
463 ))
464 .unwrap(),
465 "\"code id_token token\""
466 );
467
468 assert_eq!(
469 serde_json::to_string(
470 &[
471 ResponseTypeToken::Unknown("something_unsupported".to_owned()),
472 ResponseTypeToken::Code
473 ]
474 .into_iter()
475 .collect::<ResponseType>()
476 )
477 .unwrap(),
478 "\"code something_unsupported\""
479 );
480
481 let res = [
483 ResponseTypeToken::IdToken,
484 ResponseTypeToken::Token,
485 ResponseTypeToken::Code,
486 ]
487 .into_iter()
488 .collect::<ResponseType>();
489 assert_eq!(
490 serde_json::to_string(&res).unwrap(),
491 "\"code id_token token\""
492 );
493
494 let res = [
495 ResponseTypeToken::Code,
496 ResponseTypeToken::Token,
497 ResponseTypeToken::IdToken,
498 ]
499 .into_iter()
500 .collect::<ResponseType>();
501 assert_eq!(
502 serde_json::to_string(&res).unwrap(),
503 "\"code id_token token\""
504 );
505 }
506}