mas_jose/jwt/signed.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use base64ct::{Base64UrlUnpadded, Encoding};
8use rand::thread_rng;
9use serde::{Serialize, de::DeserializeOwned};
10use signature::{RandomizedSigner, SignatureEncoding, Verifier, rand_core::CryptoRngCore};
11use thiserror::Error;
12
13use super::{header::JsonWebSignatureHeader, raw::RawJwt};
14use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
15
/// A JSON Web Token protected by a JWS signature.
///
/// Keeps the raw token next to its decoded parts so that the signature can be
/// checked later against the exact bytes that were signed.
#[derive(Clone, PartialEq, Eq)]
pub struct Jwt<'a, T> {
    /// The `header.payload.signature` token this value was decoded from, or
    /// the token built when signing.
    raw: RawJwt<'a>,
    /// The decoded JOSE header.
    header: JsonWebSignatureHeader,
    /// The deserialized payload (claims).
    payload: T,
    /// The signature, base64url-decoded to raw bytes.
    signature: Vec<u8>,
}
23
24impl<T> std::fmt::Display for Jwt<'_, T> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "{}", self.raw)
27    }
28}
29
30impl<T> std::fmt::Debug for Jwt<'_, T>
31where
32    T: std::fmt::Debug,
33{
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("Jwt")
36            .field("raw", &"...")
37            .field("header", &self.header)
38            .field("payload", &self.payload)
39            .field("signature", &"...")
40            .finish()
41    }
42}
43
/// Errors that can occur while decoding a [`Jwt`] from its serialized form.
#[derive(Debug, Error)]
pub enum JwtDecodeError {
    /// Parsing the raw token structure failed (error produced by
    /// `RawJwt::try_from`).
    #[error(transparent)]
    RawDecode {
        #[from]
        inner: super::raw::DecodeError,
    },

    /// The base64url (unpadded) decoder rejected the header segment up front.
    #[error("failed to decode JWT header")]
    DecodeHeader {
        #[source]
        inner: base64ct::Error,
    },

    /// The header could not be deserialized from JSON.
    ///
    /// The header is streamed from the base64 decoder into `serde_json`.
    /// Decoding problems hit mid-stream may therefore surface here, wrapped as
    /// a JSON error.
    #[error("failed to deserialize JWT header")]
    DeserializeHeader {
        #[source]
        inner: serde_json::Error,
    },

    /// The base64url (unpadded) decoder rejected the payload segment up front.
    #[error("failed to decode JWT payload")]
    DecodePayload {
        #[source]
        inner: base64ct::Error,
    },

    /// The payload could not be deserialized from JSON into the target type.
    ///
    /// As with the header, mid-stream base64 problems may surface here.
    #[error("failed to deserialize JWT payload")]
    DeserializePayload {
        #[source]
        inner: serde_json::Error,
    },

    /// The signature segment is not valid unpadded base64url.
    #[error("failed to decode JWT signature")]
    DecodeSignature {
        #[source]
        inner: base64ct::Error,
    },
}
82
83impl JwtDecodeError {
84    fn decode_header(inner: base64ct::Error) -> Self {
85        Self::DecodeHeader { inner }
86    }
87
88    fn deserialize_header(inner: serde_json::Error) -> Self {
89        Self::DeserializeHeader { inner }
90    }
91
92    fn decode_payload(inner: base64ct::Error) -> Self {
93        Self::DecodePayload { inner }
94    }
95
96    fn deserialize_payload(inner: serde_json::Error) -> Self {
97        Self::DeserializePayload { inner }
98    }
99
100    fn decode_signature(inner: base64ct::Error) -> Self {
101        Self::DecodeSignature { inner }
102    }
103}
104
105impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
106where
107    T: DeserializeOwned,
108{
109    type Error = JwtDecodeError;
110    fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
111        let header_reader =
112            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
113                .map_err(JwtDecodeError::decode_header)?;
114        let header =
115            serde_json::from_reader(header_reader).map_err(JwtDecodeError::deserialize_header)?;
116
117        let payload_reader =
118            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.payload().as_bytes())
119                .map_err(JwtDecodeError::decode_payload)?;
120        let payload =
121            serde_json::from_reader(payload_reader).map_err(JwtDecodeError::deserialize_payload)?;
122
123        let signature = Base64UrlUnpadded::decode_vec(raw.signature())
124            .map_err(JwtDecodeError::decode_signature)?;
125
126        Ok(Self {
127            raw,
128            header,
129            payload,
130            signature,
131        })
132    }
133}
134
135impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
136where
137    T: DeserializeOwned,
138{
139    type Error = JwtDecodeError;
140    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
141        let raw = RawJwt::try_from(value)?;
142        Self::try_from(raw)
143    }
144}
145
146impl<T> TryFrom<String> for Jwt<'static, T>
147where
148    T: DeserializeOwned,
149{
150    type Error = JwtDecodeError;
151    fn try_from(value: String) -> Result<Self, Self::Error> {
152        let raw = RawJwt::try_from(value)?;
153        Self::try_from(raw)
154    }
155}
156
/// Errors returned by [`Jwt::verify`].
#[derive(Debug, Error)]
pub enum JwtVerificationError {
    /// The stored signature bytes could not be converted into the signature
    /// type `S` expected by the verifier. The underlying conversion error is
    /// discarded.
    #[error("failed to parse signature")]
    ParseSignature,

    /// The key rejected the signature over the token's signed part.
    #[error("signature verification failed")]
    Verify {
        #[source]
        inner: signature::Error,
    },
}
168
169impl JwtVerificationError {
170    #[allow(clippy::needless_pass_by_value)]
171    fn parse_signature<E>(_inner: E) -> Self {
172        Self::ParseSignature
173    }
174
175    fn verify(inner: signature::Error) -> Self {
176        Self::Verify { inner }
177    }
178}
179
/// Error returned when no key could be used to verify a token's signature.
///
/// Returned by [`Jwt::verify_with_shared_secret`] and [`Jwt::verify_with_jwks`].
/// It carries no detail about which key or which step failed.
#[derive(Debug, Error, Default)]
#[error("none of the keys worked")]
pub struct NoKeyWorked {
    // Private unit field: prevents construction via a struct literal outside
    // this module. Instances are created through the derived `Default`.
    _inner: (),
}
185
186impl<'a, T> Jwt<'a, T> {
187    /// Get the JWT header
188    pub fn header(&self) -> &JsonWebSignatureHeader {
189        &self.header
190    }
191
192    /// Get the JWT payload
193    pub fn payload(&self) -> &T {
194        &self.payload
195    }
196
197    pub fn into_owned(self) -> Jwt<'static, T> {
198        Jwt {
199            raw: self.raw.into_owned(),
200            header: self.header,
201            payload: self.payload,
202            signature: self.signature,
203        }
204    }
205
206    /// Verify the signature of this JWT using the given key.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if the signature is invalid.
211    pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
212    where
213        K: Verifier<S>,
214        S: SignatureEncoding,
215    {
216        let signature =
217            S::try_from(&self.signature).map_err(JwtVerificationError::parse_signature)?;
218
219        key.verify(self.raw.signed_part().as_bytes(), &signature)
220            .map_err(JwtVerificationError::verify)
221    }
222
223    /// Verify the signature of this JWT using the given symmetric key.
224    ///
225    /// # Errors
226    ///
227    /// Returns an error if the signature is invalid or if the algorithm is not
228    /// supported.
229    pub fn verify_with_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
230        let verifier = crate::jwa::SymmetricKey::new_for_alg(secret, self.header().alg())
231            .map_err(|_| NoKeyWorked::default())?;
232
233        self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
234
235        Ok(())
236    }
237
238    /// Verify the signature of this JWT using the given JWKS.
239    ///
240    /// # Errors
241    ///
242    /// Returns an error if the signature is invalid, if no key matches the
243    /// constraints, or if the algorithm is not supported.
244    pub fn verify_with_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
245        let constraints = ConstraintSet::from(self.header());
246        let candidates = constraints.filter(&**jwks);
247
248        for candidate in candidates {
249            let Ok(key) = crate::jwa::AsymmetricVerifyingKey::from_jwk_and_alg(
250                candidate.params(),
251                self.header().alg(),
252            ) else {
253                continue;
254            };
255
256            if self.verify(&key).is_ok() {
257                return Ok(());
258            }
259        }
260
261        Err(NoKeyWorked::default())
262    }
263
264    /// Get the raw JWT string as a borrowed [`str`]
265    pub fn as_str(&'a self) -> &'a str {
266        &self.raw
267    }
268
269    /// Get the raw JWT string as an owned [`String`]
270    pub fn into_string(self) -> String {
271        self.raw.into()
272    }
273
274    /// Split the JWT into its parts (header and payload).
275    pub fn into_parts(self) -> (JsonWebSignatureHeader, T) {
276        (self.header, self.payload)
277    }
278}
279
/// Errors that can occur while signing a [`Jwt`].
#[derive(Debug, Error)]
pub enum JwtSignatureError {
    /// The header could not be serialized to JSON.
    #[error("failed to serialize header")]
    EncodeHeader {
        #[source]
        inner: serde_json::Error,
    },

    /// The payload could not be serialized to JSON.
    #[error("failed to serialize payload")]
    EncodePayload {
        #[source]
        inner: serde_json::Error,
    },

    /// The signing key failed to produce a signature.
    #[error("failed to sign")]
    Signature {
        #[from]
        inner: signature::Error,
    },
}
300
301impl JwtSignatureError {
302    fn encode_header(inner: serde_json::Error) -> Self {
303        Self::EncodeHeader { inner }
304    }
305
306    fn encode_payload(inner: serde_json::Error) -> Self {
307        Self::EncodePayload { inner }
308    }
309}
310
311impl<T> Jwt<'static, T> {
312    /// Sign the given payload with the given key.
313    ///
314    /// # Errors
315    ///
316    /// Returns an error if the payload could not be serialized or if the key
317    /// could not sign the payload.
318    pub fn sign<K, S>(
319        header: JsonWebSignatureHeader,
320        payload: T,
321        key: &K,
322    ) -> Result<Self, JwtSignatureError>
323    where
324        K: RandomizedSigner<S>,
325        S: SignatureEncoding,
326        T: Serialize,
327    {
328        #[allow(clippy::disallowed_methods)]
329        Self::sign_with_rng(&mut thread_rng(), header, payload, key)
330    }
331
332    /// Sign the given payload with the given key using the given RNG.
333    ///
334    /// # Errors
335    ///
336    /// Returns an error if the payload could not be serialized or if the key
337    /// could not sign the payload.
338    pub fn sign_with_rng<R, K, S>(
339        rng: &mut R,
340        header: JsonWebSignatureHeader,
341        payload: T,
342        key: &K,
343    ) -> Result<Self, JwtSignatureError>
344    where
345        R: CryptoRngCore,
346        K: RandomizedSigner<S>,
347        S: SignatureEncoding,
348        T: Serialize,
349    {
350        let header_ = serde_json::to_vec(&header).map_err(JwtSignatureError::encode_header)?;
351        let header_ = Base64UrlUnpadded::encode_string(&header_);
352
353        let payload_ = serde_json::to_vec(&payload).map_err(JwtSignatureError::encode_payload)?;
354        let payload_ = Base64UrlUnpadded::encode_string(&payload_);
355
356        let mut inner = format!("{header_}.{payload_}");
357
358        let first_dot = header_.len();
359        let second_dot = inner.len();
360
361        let signature = key.try_sign_with_rng(rng, inner.as_bytes())?.to_vec();
362        let signature_ = Base64UrlUnpadded::encode_string(&signature);
363        inner.reserve_exact(1 + signature_.len());
364        inner.push('.');
365        inner.push_str(&signature_);
366
367        let raw = RawJwt::new(inner, first_dot, second_dot);
368
369        Ok(Self {
370            raw,
371            header,
372            payload,
373            signature,
374        })
375    }
376}
377
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use mas_iana::jose::JsonWebSignatureAlg;
    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_jwt_decode() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let jwt: Jwt<'_, serde_json::Value> = Jwt::try_from(jwt).unwrap();
        assert_eq!(jwt.raw.header(), "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9");
        assert_eq!(
            jwt.raw.payload(),
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        );
        assert_eq!(
            jwt.raw.signature(),
            "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
        );
        assert_eq!(
            jwt.raw.signed_part(),
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        );

        // The payload must actually have been deserialized, not just split.
        assert_eq!(jwt.payload()["sub"], "1234567890");
        assert_eq!(jwt.payload()["name"], "John Doe");
    }

    #[test]
    fn test_jwt_sign_and_verify() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed = Jwt::sign::<_, ecdsa::Signature<_>>(header, payload, &key).unwrap();
        signed
            .verify::<_, ecdsa::Signature<_>>(key.verifying_key())
            .unwrap();
    }

    /// Verification must fail when the token was signed by a different key.
    #[test]
    fn test_jwt_verify_rejects_other_key() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let other_key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed = Jwt::sign::<_, ecdsa::Signature<_>>(header, payload, &key).unwrap();

        let result = signed.verify::<_, ecdsa::Signature<_>>(other_key.verifying_key());
        assert!(matches!(result, Err(JwtVerificationError::Verify { .. })));
    }

    /// A signed token serialized to a string can be decoded and verified again.
    #[test]
    fn test_jwt_sign_round_trips_through_string() {
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let signed = Jwt::sign::<_, ecdsa::Signature<_>>(header, payload, &key).unwrap();
        let encoded = signed.into_string();

        let decoded: Jwt<'_, serde_json::Value> = Jwt::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.payload(), &serde_json::json!({"hello": "world"}));
        decoded
            .verify::<_, ecdsa::Signature<_>>(key.verifying_key())
            .unwrap();
    }
}