mas_jose/
claims.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{collections::HashMap, convert::Infallible, marker::PhantomData, ops::Deref};
8
9use base64ct::{Base64UrlUnpadded, Encoding};
10use mas_iana::jose::JsonWebSignatureAlg;
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use sha2::{Digest, Sha256, Sha384, Sha512};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum ClaimError {
17    #[error("missing claim {0:?}")]
18    MissingClaim(&'static str),
19
20    #[error("invalid claim {0:?}")]
21    InvalidClaim(&'static str),
22
23    #[error("could not validate claim {claim:?}")]
24    ValidationError {
25        claim: &'static str,
26        #[source]
27        source: Box<dyn std::error::Error + Send + Sync + 'static>,
28    },
29}
30
/// A check run on a claim value once it has been deserialized.
pub trait Validator<T> {
    /// The associated error type returned by this validator.
    type Error;

    /// Validate a claim value.
    ///
    /// # Errors
    ///
    /// Returns an error if the value is invalid.
    fn validate(&self, value: &T) -> Result<(), Self::Error>;
}

/// The unit validator accepts every value; it is the default validator of a
/// [`Claim`].
impl<T> Validator<T> for () {
    type Error = Infallible;

    fn validate(&self, _: &T) -> Result<(), Infallible> {
        Result::Ok(())
    }
}
50
51pub struct Claim<T, V = ()> {
52    claim: &'static str,
53    t: PhantomData<T>,
54    v: PhantomData<V>,
55}
56
57impl<T, V> Claim<T, V>
58where
59    V: Validator<T>,
60{
61    #[must_use]
62    pub const fn new(claim: &'static str) -> Self {
63        Self {
64            claim,
65            t: PhantomData,
66            v: PhantomData,
67        }
68    }
69
70    /// Insert a claim into the given claims map.
71    ///
72    /// # Errors
73    ///
74    /// Returns an error if the value failed to serialize.
75    pub fn insert<I>(
76        &self,
77        claims: &mut HashMap<String, serde_json::Value>,
78        value: I,
79    ) -> Result<(), ClaimError>
80    where
81        I: Into<T>,
82        T: Serialize,
83    {
84        let value = value.into();
85        let value: serde_json::Value =
86            serde_json::to_value(&value).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
87        claims.insert(self.claim.to_owned(), value);
88
89        Ok(())
90    }
91
92    /// Extract a claim from the given claims map.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the value failed to deserialize, if its value is
97    /// invalid or if the claim is missing.
98    pub fn extract_required(
99        &self,
100        claims: &mut HashMap<String, serde_json::Value>,
101    ) -> Result<T, ClaimError>
102    where
103        T: DeserializeOwned,
104        V: Default,
105        V::Error: std::error::Error + Send + Sync + 'static,
106    {
107        let validator = V::default();
108        self.extract_required_with_options(claims, validator)
109    }
110
111    /// Extract a claim from the given claims map, with the given options.
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the value failed to deserialize, if its value is
116    /// invalid or if the claim is missing.
117    pub fn extract_required_with_options<I>(
118        &self,
119        claims: &mut HashMap<String, serde_json::Value>,
120        validator: I,
121    ) -> Result<T, ClaimError>
122    where
123        T: DeserializeOwned,
124        I: Into<V>,
125        V::Error: std::error::Error + Send + Sync + 'static,
126    {
127        let validator: V = validator.into();
128        let claim = claims
129            .remove(self.claim)
130            .ok_or(ClaimError::MissingClaim(self.claim))?;
131
132        let res =
133            serde_json::from_value(claim).map_err(|_| ClaimError::InvalidClaim(self.claim))?;
134        validator
135            .validate(&res)
136            .map_err(|source| ClaimError::ValidationError {
137                claim: self.claim,
138                source: Box::new(source),
139            })?;
140        Ok(res)
141    }
142
143    /// Extract a claim from the given claims map, if it exists.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the value failed to deserialize or if its value is
148    /// invalid.
149    pub fn extract_optional(
150        &self,
151        claims: &mut HashMap<String, serde_json::Value>,
152    ) -> Result<Option<T>, ClaimError>
153    where
154        T: DeserializeOwned,
155        V: Default,
156        V::Error: std::error::Error + Send + Sync + 'static,
157    {
158        let validator = V::default();
159        self.extract_optional_with_options(claims, validator)
160    }
161
162    /// Extract a claim from the given claims map, if it exists, with the given
163    /// options.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if the value failed to deserialize or if its value is
168    /// invalid.
169    pub fn extract_optional_with_options<I>(
170        &self,
171        claims: &mut HashMap<String, serde_json::Value>,
172        validator: I,
173    ) -> Result<Option<T>, ClaimError>
174    where
175        T: DeserializeOwned,
176        I: Into<V>,
177        V::Error: std::error::Error + Send + Sync + 'static,
178    {
179        match self.extract_required_with_options(claims, validator) {
180            Ok(v) => Ok(Some(v)),
181            Err(ClaimError::MissingClaim(_)) => Ok(None),
182            Err(e) => Err(e),
183        }
184    }
185
186    /// Assert that the claim is absent.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if the claim is present.
191    pub fn assert_absent(
192        &self,
193        claims: &HashMap<String, serde_json::Value>,
194    ) -> Result<(), ClaimError> {
195        if claims.contains_key(self.claim) {
196            Err(ClaimError::InvalidClaim(self.claim))
197        } else {
198            Ok(())
199        }
200    }
201}
202
203#[derive(Debug, Clone)]
204pub struct TimeOptions {
205    when: chrono::DateTime<chrono::Utc>,
206    leeway: chrono::Duration,
207}
208
209impl TimeOptions {
210    #[must_use]
211    pub fn new(when: chrono::DateTime<chrono::Utc>) -> Self {
212        Self {
213            when,
214            leeway: chrono::Duration::microseconds(5 * 60 * 1000 * 1000),
215        }
216    }
217
218    #[must_use]
219    pub fn leeway(mut self, leeway: chrono::Duration) -> Self {
220        self.leeway = leeway;
221        self
222    }
223}
224
225#[derive(Debug, Clone, Copy, Error)]
226#[error("Current time is too far away")]
227pub struct TimeTooFarError;
228
229#[derive(Debug, Clone)]
230pub struct TimeNotAfter(TimeOptions);
231
232impl Validator<Timestamp> for TimeNotAfter {
233    type Error = TimeTooFarError;
234    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
235        if self.0.when <= value.0 + self.0.leeway {
236            Ok(())
237        } else {
238            Err(TimeTooFarError)
239        }
240    }
241}
242
243impl From<TimeOptions> for TimeNotAfter {
244    fn from(opt: TimeOptions) -> Self {
245        Self(opt)
246    }
247}
248
249impl From<&TimeOptions> for TimeNotAfter {
250    fn from(opt: &TimeOptions) -> Self {
251        opt.clone().into()
252    }
253}
254
255#[derive(Debug, Clone)]
256pub struct TimeNotBefore(TimeOptions);
257
258impl Validator<Timestamp> for TimeNotBefore {
259    type Error = TimeTooFarError;
260
261    fn validate(&self, value: &Timestamp) -> Result<(), Self::Error> {
262        if self.0.when >= value.0 - self.0.leeway {
263            Ok(())
264        } else {
265            Err(TimeTooFarError)
266        }
267    }
268}
269
270impl From<TimeOptions> for TimeNotBefore {
271    fn from(opt: TimeOptions) -> Self {
272        Self(opt)
273    }
274}
275
276impl From<&TimeOptions> for TimeNotBefore {
277    fn from(opt: &TimeOptions) -> Self {
278        opt.clone().into()
279    }
280}
281
282/// Hash the given token with the given algorithm for an ID Token claim.
283///
284/// According to the [OpenID Connect Core 1.0 specification].
285///
286/// # Errors
287///
288/// Returns an error if the algorithm is not supported.
289///
290/// [OpenID Connect Core 1.0 specification]: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
291pub fn hash_token(alg: &JsonWebSignatureAlg, token: &str) -> Result<String, TokenHashError> {
292    let bits = match alg {
293        JsonWebSignatureAlg::Hs256
294        | JsonWebSignatureAlg::Rs256
295        | JsonWebSignatureAlg::Es256
296        | JsonWebSignatureAlg::Ps256
297        | JsonWebSignatureAlg::Es256K => {
298            let mut hasher = Sha256::new();
299            hasher.update(token);
300            let hash: [u8; 32] = hasher.finalize().into();
301            // Left-most half
302            hash[..16].to_owned()
303        }
304        JsonWebSignatureAlg::Hs384
305        | JsonWebSignatureAlg::Rs384
306        | JsonWebSignatureAlg::Es384
307        | JsonWebSignatureAlg::Ps384 => {
308            let mut hasher = Sha384::new();
309            hasher.update(token);
310            let hash: [u8; 48] = hasher.finalize().into();
311            // Left-most half
312            hash[..24].to_owned()
313        }
314        JsonWebSignatureAlg::Hs512
315        | JsonWebSignatureAlg::Rs512
316        | JsonWebSignatureAlg::Es512
317        | JsonWebSignatureAlg::Ps512 => {
318            let mut hasher = Sha512::new();
319            hasher.update(token);
320            let hash: [u8; 64] = hasher.finalize().into();
321            // Left-most half
322            hash[..32].to_owned()
323        }
324        _ => return Err(TokenHashError::UnsupportedAlgorithm),
325    };
326
327    Ok(Base64UrlUnpadded::encode_string(&bits))
328}
329
330#[derive(Debug, Clone, Copy, Error)]
331pub enum TokenHashError {
332    #[error("Hashes don't match")]
333    HashMismatch,
334
335    #[error("Unsupported algorithm for hashing")]
336    UnsupportedAlgorithm,
337}
338
339#[derive(Debug, Clone)]
340pub struct TokenHash<'a> {
341    alg: &'a JsonWebSignatureAlg,
342    token: &'a str,
343}
344
345impl<'a> TokenHash<'a> {
346    /// Creates a new `TokenHash` validator for the given algorithm and token.
347    #[must_use]
348    pub fn new(alg: &'a JsonWebSignatureAlg, token: &'a str) -> Self {
349        Self { alg, token }
350    }
351}
352
353impl Validator<String> for TokenHash<'_> {
354    type Error = TokenHashError;
355    fn validate(&self, value: &String) -> Result<(), Self::Error> {
356        if hash_token(self.alg, self.token)? == *value {
357            Ok(())
358        } else {
359            Err(TokenHashError::HashMismatch)
360        }
361    }
362}
363
364#[derive(Debug, Clone, Copy, Error)]
365#[error("Values don't match")]
366pub struct EqualityError;
367
/// Validator checking that a claim value equals a reference value.
///
/// `T` may be unsized (e.g. `str`), so a `String` claim can be compared with a
/// borrowed string slice.
#[derive(Debug)]
pub struct Equality<'a, T: ?Sized> {
    value: &'a T,
}

// `Clone` is implemented by hand because `#[derive(Clone)]` would add a
// `T: Clone` bound. That bound excludes unsized targets such as `str` (used by
// `Equality<str>`), even though the struct only holds a shared reference.
impl<T: ?Sized> Clone for Equality<'_, T> {
    fn clone(&self) -> Self {
        Self { value: self.value }
    }
}

impl<'a, T: ?Sized> Equality<'a, T> {
    /// Creates a new `Equality` validator for the given value.
    #[must_use]
    pub fn new(value: &'a T) -> Self {
        Self { value }
    }
}
380
381impl<T1, T2> Validator<T1> for Equality<'_, T2>
382where
383    T2: PartialEq<T1> + ?Sized,
384{
385    type Error = EqualityError;
386    fn validate(&self, value: &T1) -> Result<(), Self::Error> {
387        if *self.value == *value {
388            Ok(())
389        } else {
390            Err(EqualityError)
391        }
392    }
393}
394
395impl<'a, T: ?Sized> From<&'a T> for Equality<'a, T> {
396    fn from(value: &'a T) -> Self {
397        Self::new(value)
398    }
399}
400
/// Validator checking that a multi-valued (`OneOrMany`) claim contains a
/// reference value.
#[derive(Debug)]
pub struct Contains<'a, T> {
    value: &'a T,
}

// `Clone` is implemented by hand because `#[derive(Clone)]` would add a
// `T: Clone` bound, even though the struct only holds a shared reference.
impl<T> Clone for Contains<'_, T> {
    fn clone(&self) -> Self {
        Self { value: self.value }
    }
}

impl<'a, T> Contains<'a, T> {
    /// Creates a new `Contains` validator for the given value.
    #[must_use]
    pub fn new(value: &'a T) -> Self {
        Self { value }
    }
}
413
414#[derive(Debug, Clone, Copy, Error)]
415#[error("OneOrMany doesn't contain value")]
416pub struct ContainsError;
417
418impl<T> Validator<OneOrMany<T>> for Contains<'_, T>
419where
420    T: PartialEq,
421{
422    type Error = ContainsError;
423    fn validate(&self, value: &OneOrMany<T>) -> Result<(), Self::Error> {
424        if value.contains(self.value) {
425            Ok(())
426        } else {
427            Err(ContainsError)
428        }
429    }
430}
431
432impl<'a, T> From<&'a T> for Contains<'a, T> {
433    fn from(value: &'a T) -> Self {
434        Self::new(value)
435    }
436}
437
438#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
439#[serde(transparent)]
440pub struct Timestamp(#[serde(with = "chrono::serde::ts_seconds")] chrono::DateTime<chrono::Utc>);
441
442impl Deref for Timestamp {
443    type Target = chrono::DateTime<chrono::Utc>;
444
445    fn deref(&self) -> &Self::Target {
446        &self.0
447    }
448}
449
450impl From<chrono::DateTime<chrono::Utc>> for Timestamp {
451    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
452        Timestamp(value)
453    }
454}
455
456#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
457#[serde(
458    transparent,
459    bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>")
460)]
461pub struct OneOrMany<T>(
462    // serde_as seems to not work properly with #[serde(transparent)]
463    // We have use plain old #[serde(with = ...)] with serde_with's utilities, which is a bit
464    // verbose but works
465    #[serde(
466        with = "serde_with::As::<serde_with::OneOrMany<serde_with::Same, serde_with::formats::PreferOne>>"
467    )]
468    Vec<T>,
469);
470
471impl<T> Deref for OneOrMany<T> {
472    type Target = Vec<T>;
473
474    fn deref(&self) -> &Self::Target {
475        &self.0
476    }
477}
478
479impl<T> From<Vec<T>> for OneOrMany<T> {
480    fn from(value: Vec<T>) -> Self {
481        Self(value)
482    }
483}
484
485impl<T> From<T> for OneOrMany<T> {
486    fn from(value: T) -> Self {
487        Self(vec![value])
488    }
489}
490
491/// Claims defined in RFC7519 sec. 4.1
492/// <https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1>
493mod rfc7519 {
494    use super::{Claim, Contains, Equality, OneOrMany, TimeNotAfter, TimeNotBefore, Timestamp};
495
496    pub const ISS: Claim<String, Equality<str>> = Claim::new("iss");
497    pub const SUB: Claim<String> = Claim::new("sub");
498    pub const AUD: Claim<OneOrMany<String>, Contains<String>> = Claim::new("aud");
499    pub const NBF: Claim<Timestamp, TimeNotBefore> = Claim::new("nbf");
500    pub const EXP: Claim<Timestamp, TimeNotAfter> = Claim::new("exp");
501    pub const IAT: Claim<Timestamp, TimeNotBefore> = Claim::new("iat");
502    pub const JTI: Claim<String> = Claim::new("jti");
503}
504
505/// Claims defined in OIDC.Core sec. 2 and sec. 5.1
506/// <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>
507/// <https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims>
508mod oidc_core {
509    use url::Url;
510
511    use super::{Claim, Equality, Timestamp, TokenHash};
512
513    pub const AUTH_TIME: Claim<Timestamp> = Claim::new("auth_time");
514    pub const NONCE: Claim<String, Equality<str>> = Claim::new("nonce");
515    pub const AT_HASH: Claim<String, TokenHash> = Claim::new("at_hash");
516    pub const C_HASH: Claim<String, TokenHash> = Claim::new("c_hash");
517
518    pub const NAME: Claim<String> = Claim::new("name");
519    pub const GIVEN_NAME: Claim<String> = Claim::new("given_name");
520    pub const FAMILY_NAME: Claim<String> = Claim::new("family_name");
521    pub const MIDDLE_NAME: Claim<String> = Claim::new("middle_name");
522    pub const NICKNAME: Claim<String> = Claim::new("nickname");
523    pub const PREFERRED_USERNAME: Claim<String> = Claim::new("preferred_username");
524    pub const PROFILE: Claim<Url> = Claim::new("profile");
525    pub const PICTURE: Claim<Url> = Claim::new("picture");
526    pub const WEBSITE: Claim<Url> = Claim::new("website");
527    // TODO: email type?
528    pub const EMAIL: Claim<String> = Claim::new("email");
529    pub const EMAIL_VERIFIED: Claim<bool> = Claim::new("email_verified");
530    pub const GENDER: Claim<String> = Claim::new("gender");
531    // TODO: date type
532    pub const BIRTHDATE: Claim<String> = Claim::new("birthdate");
533    // TODO: timezone type
534    pub const ZONEINFO: Claim<String> = Claim::new("zoneinfo");
535    // TODO: locale type
536    pub const LOCALE: Claim<String> = Claim::new("locale");
537    // TODO: phone number type
538    pub const PHONE_NUMBER: Claim<String> = Claim::new("phone_number");
539    pub const PHONE_NUMBER_VERIFIED: Claim<bool> = Claim::new("phone_number_verified");
540    // TODO: pub const ADDRESS: Claim<Timestamp> = Claim::new("address");
541    pub const UPDATED_AT: Claim<Timestamp> = Claim::new("updated_at");
542}
543
544/// Claims defined in OpenID.FrontChannel
545/// <https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents>
546mod oidc_frontchannel {
547    use super::Claim;
548
549    pub const SID: Claim<String> = Claim::new("sid");
550}
551
552pub use self::{oidc_core::*, oidc_frontchannel::*, rfc7519::*};
553
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    /// The instant encoded by the `1_516_239_022` timestamps in the fixtures.
    fn reference_time() -> chrono::DateTime<chrono::Utc> {
        chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap()
    }

    /// Turns a JSON object into a claims map.
    fn claims_map(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        serde_json::from_value(value).unwrap()
    }

    /// Asserts that a time claim was either accepted, or rejected by its
    /// validator (as opposed to being missing or malformed).
    #[track_caller]
    fn expect_time_outcome(
        result: Result<Timestamp, ClaimError>,
        name: &'static str,
        accepted: bool,
    ) {
        if accepted {
            assert!(result.is_ok(), "{name} should be accepted");
        } else {
            assert!(
                matches!(result, Err(ClaimError::ValidationError { claim, .. }) if claim == name),
                "{name} should be rejected by its validator"
            );
        }
    }

    /// Extracts `iat`, `nbf` and `exp`, in that order, from a copy of `claims`
    /// and checks whether each one is accepted.
    #[track_caller]
    fn check_time_claims(
        claims: &HashMap<String, serde_json::Value>,
        options: &TimeOptions,
        [iat_ok, nbf_ok, exp_ok]: [bool; 3],
    ) {
        let mut claims = claims.clone();
        expect_time_outcome(
            IAT.extract_required_with_options(&mut claims, options),
            "iat",
            iat_ok,
        );
        expect_time_outcome(
            NBF.extract_required_with_options(&mut claims, options),
            "nbf",
            nbf_ok,
        );
        expect_time_outcome(
            EXP.extract_required_with_options(&mut claims, options),
            "exp",
            exp_ok,
        );
    }

    #[test]
    fn timestamp_serde() {
        let json = serde_json::Value::Number(1_516_239_022.into());
        let timestamp = Timestamp(reference_time());

        let decoded: Timestamp = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(timestamp, decoded);

        let encoded = serde_json::to_value(&timestamp).unwrap();
        assert_eq!(json, encoded);
    }

    #[test]
    fn one_or_many_serde() {
        let one = OneOrMany(vec!["one".to_owned()]);
        let many = OneOrMany(vec!["one".to_owned(), "two".to_owned()]);

        let decode = |value: serde_json::Value| -> OneOrMany<String> {
            serde_json::from_value(value).unwrap()
        };

        // Both the bare and the array form are accepted...
        assert_eq!(decode(serde_json::json!("one")), one);
        assert_eq!(decode(serde_json::json!(["one"])), one);
        assert_eq!(decode(serde_json::json!(["one", "two"])), many);

        // ...but a single element is always written in its bare form.
        assert_eq!(serde_json::to_value(&one).unwrap(), serde_json::json!("one"));
        assert_eq!(
            serde_json::to_value(&many).unwrap(),
            serde_json::json!(["one", "two"])
        );
    }

    #[test]
    fn extract_claims() {
        let now = reference_time();
        let expiration = now + chrono::Duration::microseconds(5 * 60 * 1000 * 1000);
        let time_options = TimeOptions::new(now).leeway(chrono::Duration::zero());

        let mut claims = claims_map(serde_json::json!({
            "iss": "https://foo.com",
            "sub": "johndoe",
            "aud": ["abcd-efgh"],
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
            "jti": "1122-3344-5566-7788",
        }));

        let iss = ISS
            .extract_required_with_options(&mut claims, "https://foo.com")
            .unwrap();
        let sub = SUB.extract_optional(&mut claims).unwrap();
        let aud = AUD
            .extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned())
            .unwrap();
        let nbf = NBF
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let exp = EXP
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let iat = IAT
            .extract_optional_with_options(&mut claims, &time_options)
            .unwrap();
        let jti = JTI.extract_optional(&mut claims).unwrap();

        assert_eq!(iss, "https://foo.com");
        assert_eq!(sub.as_deref(), Some("johndoe"));
        assert_eq!(aud.as_deref(), Some(&vec!["abcd-efgh".to_owned()]));
        assert_eq!(iat.as_deref(), Some(&now));
        assert_eq!(nbf.as_deref(), Some(&now));
        assert_eq!(exp.as_deref(), Some(&expiration));
        assert_eq!(jti.as_deref(), Some("1122-3344-5566-7788"));

        // Every extracted claim has been removed from the map.
        assert!(claims.is_empty());
    }

    #[test]
    fn time_validation() {
        let now = reference_time();
        let claims = claims_map(serde_json::json!({
            "iat": 1_516_239_022,
            "nbf": 1_516_239_022,
            "exp": 1_516_239_322,
        }));
        let no_leeway = chrono::Duration::zero();

        // "iat" and "nbf" are exactly `now`, so everything passes even without
        // any leeway.
        check_time_claims(
            &claims,
            &TimeOptions::new(now).leeway(no_leeway),
            [true, true, true],
        );

        // Move the clock back one minute: there is now a time variance between
        // the two parties.
        let now = now - chrono::Duration::microseconds(60 * 1000 * 1000);

        // With no variance allowed, "iat" and "nbf" are rejected; "exp" is still
        // fine.
        check_time_claims(
            &claims,
            &TimeOptions::new(now).leeway(no_leeway),
            [false, false, true],
        );

        // A two-minute leeway absorbs the one-minute variance, so everything
        // passes.
        check_time_claims(
            &claims,
            &TimeOptions::new(now)
                .leeway(chrono::Duration::microseconds(2 * 60 * 1000 * 1000)),
            [true, true, true],
        );

        // Move forward seven minutes: the claims expired one minute ago.
        let now = now + chrono::Duration::microseconds((1 + 6) * 60 * 1000 * 1000);

        // With no variance allowed, only "exp" is rejected.
        check_time_claims(
            &claims,
            &TimeOptions::new(now).leeway(no_leeway),
            [true, true, false],
        );

        // A two-minute leeway makes the expired claims acceptable again.
        check_time_claims(
            &claims,
            &TimeOptions::new(now).leeway(chrono::Duration::try_minutes(2).unwrap()),
            [true, true, true],
        );
    }

    #[test]
    fn invalid_claims() {
        let time_options = TimeOptions::new(reference_time()).leeway(chrono::Duration::zero());

        // Every claim has the wrong JSON type.
        let mut claims = claims_map(serde_json::json!({
            "iss": 123,
            "sub": 456,
            "aud": 789,
            "iat": "123",
            "nbf": "456",
            "exp": "789",
            "jti": 123,
        }));

        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::InvalidClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::InvalidClaim("aud"))
        ));
        assert!(matches!(
            NBF.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("nbf"))
        ));
        assert!(matches!(
            EXP.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("exp"))
        ));
        assert!(matches!(
            IAT.extract_required_with_options(&mut claims, &time_options),
            Err(ClaimError::InvalidClaim("iat"))
        ));
        assert!(matches!(
            JTI.extract_required(&mut claims),
            Err(ClaimError::InvalidClaim("jti"))
        ));
    }

    #[test]
    fn missing_claims() {
        let mut claims: HashMap<String, serde_json::Value> = HashMap::new();

        // Required extraction reports the missing claim...
        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://foo.com"),
            Err(ClaimError::MissingClaim("iss"))
        ));
        assert!(matches!(
            SUB.extract_required(&mut claims),
            Err(ClaimError::MissingClaim("sub"))
        ));
        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Err(ClaimError::MissingClaim("aud"))
        ));

        // ...while optional extraction turns it into `None`.
        assert!(matches!(
            ISS.extract_optional_with_options(&mut claims, "https://foo.com"),
            Ok(None)
        ));
        assert!(matches!(SUB.extract_optional(&mut claims), Ok(None)));
        assert!(matches!(
            AUD.extract_optional_with_options(&mut claims, &"abcd-efgh".to_owned()),
            Ok(None)
        ));
    }

    #[test]
    fn string_eq_validation() {
        let mut claims = claims_map(serde_json::json!({
            "iss": "https://foo.com",
        }));

        // Extract from a copy first, so "iss" is still present below.
        ISS.extract_required_with_options(&mut claims.clone(), "https://foo.com")
            .unwrap();

        assert!(matches!(
            ISS.extract_required_with_options(&mut claims, "https://bar.com"),
            Err(ClaimError::ValidationError { claim: "iss", .. }),
        ));
    }

    #[test]
    fn contains_validation() {
        let mut claims = claims_map(serde_json::json!({
            "aud": "abcd-efgh",
        }));

        // Extract from a copy first, so "aud" is still present below.
        AUD.extract_required_with_options(&mut claims.clone(), &"abcd-efgh".to_owned())
            .unwrap();

        assert!(matches!(
            AUD.extract_required_with_options(&mut claims, &"wxyz".to_owned()),
            Err(ClaimError::ValidationError { claim: "aud", .. }),
        ));
    }
}