1use base64ct::{Base64UrlUnpadded, Encoding};
8use rand::thread_rng;
9use serde::{Serialize, de::DeserializeOwned};
10use signature::{RandomizedSigner, SignatureEncoding, Verifier, rand_core::CryptoRngCore};
11use thiserror::Error;
12
13use super::{header::JsonWebSignatureHeader, raw::RawJwt};
14use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
15
16#[derive(Clone, PartialEq, Eq)]
17pub struct Jwt<'a, T> {
18 raw: RawJwt<'a>,
19 header: JsonWebSignatureHeader,
20 payload: T,
21 signature: Vec<u8>,
22}
23
24impl<T> std::fmt::Display for Jwt<'_, T> {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(f, "{}", self.raw)
27 }
28}
29
30impl<T> std::fmt::Debug for Jwt<'_, T>
31where
32 T: std::fmt::Debug,
33{
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("Jwt")
36 .field("raw", &"...")
37 .field("header", &self.header)
38 .field("payload", &self.payload)
39 .field("signature", &"...")
40 .finish()
41 }
42}
43
44#[derive(Debug, Error)]
45pub enum JwtDecodeError {
46 #[error(transparent)]
47 RawDecode {
48 #[from]
49 inner: super::raw::DecodeError,
50 },
51
52 #[error("failed to decode JWT header")]
53 DecodeHeader {
54 #[source]
55 inner: base64ct::Error,
56 },
57
58 #[error("failed to deserialize JWT header")]
59 DeserializeHeader {
60 #[source]
61 inner: serde_json::Error,
62 },
63
64 #[error("failed to decode JWT payload")]
65 DecodePayload {
66 #[source]
67 inner: base64ct::Error,
68 },
69
70 #[error("failed to deserialize JWT payload")]
71 DeserializePayload {
72 #[source]
73 inner: serde_json::Error,
74 },
75
76 #[error("failed to decode JWT signature")]
77 DecodeSignature {
78 #[source]
79 inner: base64ct::Error,
80 },
81}
82
83impl JwtDecodeError {
84 fn decode_header(inner: base64ct::Error) -> Self {
85 Self::DecodeHeader { inner }
86 }
87
88 fn deserialize_header(inner: serde_json::Error) -> Self {
89 Self::DeserializeHeader { inner }
90 }
91
92 fn decode_payload(inner: base64ct::Error) -> Self {
93 Self::DecodePayload { inner }
94 }
95
96 fn deserialize_payload(inner: serde_json::Error) -> Self {
97 Self::DeserializePayload { inner }
98 }
99
100 fn decode_signature(inner: base64ct::Error) -> Self {
101 Self::DecodeSignature { inner }
102 }
103}
104
105impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
106where
107 T: DeserializeOwned,
108{
109 type Error = JwtDecodeError;
110 fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
111 let header_reader =
112 base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
113 .map_err(JwtDecodeError::decode_header)?;
114 let header =
115 serde_json::from_reader(header_reader).map_err(JwtDecodeError::deserialize_header)?;
116
117 let payload_reader =
118 base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.payload().as_bytes())
119 .map_err(JwtDecodeError::decode_payload)?;
120 let payload =
121 serde_json::from_reader(payload_reader).map_err(JwtDecodeError::deserialize_payload)?;
122
123 let signature = Base64UrlUnpadded::decode_vec(raw.signature())
124 .map_err(JwtDecodeError::decode_signature)?;
125
126 Ok(Self {
127 raw,
128 header,
129 payload,
130 signature,
131 })
132 }
133}
134
135impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
136where
137 T: DeserializeOwned,
138{
139 type Error = JwtDecodeError;
140 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
141 let raw = RawJwt::try_from(value)?;
142 Self::try_from(raw)
143 }
144}
145
146impl<T> TryFrom<String> for Jwt<'static, T>
147where
148 T: DeserializeOwned,
149{
150 type Error = JwtDecodeError;
151 fn try_from(value: String) -> Result<Self, Self::Error> {
152 let raw = RawJwt::try_from(value)?;
153 Self::try_from(raw)
154 }
155}
156
157#[derive(Debug, Error)]
158pub enum JwtVerificationError {
159 #[error("failed to parse signature")]
160 ParseSignature,
161
162 #[error("signature verification failed")]
163 Verify {
164 #[source]
165 inner: signature::Error,
166 },
167}
168
169impl JwtVerificationError {
170 #[allow(clippy::needless_pass_by_value)]
171 fn parse_signature<E>(_inner: E) -> Self {
172 Self::ParseSignature
173 }
174
175 fn verify(inner: signature::Error) -> Self {
176 Self::Verify { inner }
177 }
178}
179
180#[derive(Debug, Error, Default)]
181#[error("none of the keys worked")]
182pub struct NoKeyWorked {
183 _inner: (),
184}
185
186impl<'a, T> Jwt<'a, T> {
187 pub fn header(&self) -> &JsonWebSignatureHeader {
189 &self.header
190 }
191
192 pub fn payload(&self) -> &T {
194 &self.payload
195 }
196
197 pub fn into_owned(self) -> Jwt<'static, T> {
198 Jwt {
199 raw: self.raw.into_owned(),
200 header: self.header,
201 payload: self.payload,
202 signature: self.signature,
203 }
204 }
205
206 pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
212 where
213 K: Verifier<S>,
214 S: SignatureEncoding,
215 {
216 let signature =
217 S::try_from(&self.signature).map_err(JwtVerificationError::parse_signature)?;
218
219 key.verify(self.raw.signed_part().as_bytes(), &signature)
220 .map_err(JwtVerificationError::verify)
221 }
222
223 pub fn verify_with_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
230 let verifier = crate::jwa::SymmetricKey::new_for_alg(secret, self.header().alg())
231 .map_err(|_| NoKeyWorked::default())?;
232
233 self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
234
235 Ok(())
236 }
237
238 pub fn verify_with_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
245 let constraints = ConstraintSet::from(self.header());
246 let candidates = constraints.filter(&**jwks);
247
248 for candidate in candidates {
249 let Ok(key) = crate::jwa::AsymmetricVerifyingKey::from_jwk_and_alg(
250 candidate.params(),
251 self.header().alg(),
252 ) else {
253 continue;
254 };
255
256 if self.verify(&key).is_ok() {
257 return Ok(());
258 }
259 }
260
261 Err(NoKeyWorked::default())
262 }
263
264 pub fn as_str(&'a self) -> &'a str {
266 &self.raw
267 }
268
269 pub fn into_string(self) -> String {
271 self.raw.into()
272 }
273
274 pub fn into_parts(self) -> (JsonWebSignatureHeader, T) {
276 (self.header, self.payload)
277 }
278}
279
280#[derive(Debug, Error)]
281pub enum JwtSignatureError {
282 #[error("failed to serialize header")]
283 EncodeHeader {
284 #[source]
285 inner: serde_json::Error,
286 },
287
288 #[error("failed to serialize payload")]
289 EncodePayload {
290 #[source]
291 inner: serde_json::Error,
292 },
293
294 #[error("failed to sign")]
295 Signature {
296 #[from]
297 inner: signature::Error,
298 },
299}
300
301impl JwtSignatureError {
302 fn encode_header(inner: serde_json::Error) -> Self {
303 Self::EncodeHeader { inner }
304 }
305
306 fn encode_payload(inner: serde_json::Error) -> Self {
307 Self::EncodePayload { inner }
308 }
309}
310
311impl<T> Jwt<'static, T> {
312 pub fn sign<K, S>(
319 header: JsonWebSignatureHeader,
320 payload: T,
321 key: &K,
322 ) -> Result<Self, JwtSignatureError>
323 where
324 K: RandomizedSigner<S>,
325 S: SignatureEncoding,
326 T: Serialize,
327 {
328 #[allow(clippy::disallowed_methods)]
329 Self::sign_with_rng(&mut thread_rng(), header, payload, key)
330 }
331
332 pub fn sign_with_rng<R, K, S>(
339 rng: &mut R,
340 header: JsonWebSignatureHeader,
341 payload: T,
342 key: &K,
343 ) -> Result<Self, JwtSignatureError>
344 where
345 R: CryptoRngCore,
346 K: RandomizedSigner<S>,
347 S: SignatureEncoding,
348 T: Serialize,
349 {
350 let header_ = serde_json::to_vec(&header).map_err(JwtSignatureError::encode_header)?;
351 let header_ = Base64UrlUnpadded::encode_string(&header_);
352
353 let payload_ = serde_json::to_vec(&payload).map_err(JwtSignatureError::encode_payload)?;
354 let payload_ = Base64UrlUnpadded::encode_string(&payload_);
355
356 let mut inner = format!("{header_}.{payload_}");
357
358 let first_dot = header_.len();
359 let second_dot = inner.len();
360
361 let signature = key.try_sign_with_rng(rng, inner.as_bytes())?.to_vec();
362 let signature_ = Base64UrlUnpadded::encode_string(&signature);
363 inner.reserve_exact(1 + signature_.len());
364 inner.push('.');
365 inner.push_str(&signature_);
366
367 let raw = RawJwt::new(inner, first_dot, second_dot);
368
369 Ok(Self {
370 raw,
371 header,
372 payload,
373 signature,
374 })
375 }
376}
377
#[cfg(test)]
mod tests {
    // `thread_rng` is fine in tests.
    #![allow(clippy::disallowed_methods)]
    use mas_iana::jose::JsonWebSignatureAlg;
    use rand::thread_rng;

    use super::*;

    /// Decoding a token exposes each of its segments unchanged.
    #[test]
    fn test_jwt_decode() {
        const HEADER: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
        const PAYLOAD: &str =
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ";
        const SIGNATURE: &str = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

        let token = [HEADER, PAYLOAD, SIGNATURE].join(".");
        let signed_part = [HEADER, PAYLOAD].join(".");

        let jwt: Jwt<'_, serde_json::Value> = Jwt::try_from(token.as_str()).unwrap();

        assert_eq!(jwt.raw.header(), HEADER);
        assert_eq!(jwt.raw.payload(), PAYLOAD);
        assert_eq!(jwt.raw.signature(), SIGNATURE);
        assert_eq!(jwt.raw.signed_part(), signed_part.as_str());
    }

    /// A token signed with a fresh ES256 key verifies with the matching
    /// verifying key.
    #[test]
    fn test_jwt_sign_and_verify() {
        type Es256Signature = ecdsa::Signature<p256::NistP256>;

        let mut rng = thread_rng();
        let key = ecdsa::SigningKey::<p256::NistP256>::random(&mut rng);

        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let payload = serde_json::json!({"hello": "world"});

        let signed = Jwt::sign::<_, Es256Signature>(header, payload, &key).unwrap();
        let outcome = signed.verify::<_, Es256Signature>(key.verifying_key());
        outcome.unwrap();
    }
}