mas_oidc_client/requests/
authorization_code.rs1use std::{collections::HashSet, num::NonZeroU32};
12
13use base64ct::{Base64UrlUnpadded, Encoding};
14use chrono::{DateTime, Utc};
15use language_tags::LanguageTag;
16use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
17use mas_jose::claims::{self, TokenHash};
18use oauth2_types::{
19 pkce,
20 prelude::CodeChallengeMethodExt,
21 requests::{
22 AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
23 Display, Prompt, ResponseMode,
24 },
25 scope::{OPENID, Scope},
26};
27use rand::{
28 Rng,
29 distributions::{Alphanumeric, DistString},
30};
31use serde::Serialize;
32use url::Url;
33
34use super::jose::JwtVerificationData;
35use crate::{
36 error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError},
37 requests::{jose::verify_id_token, token::request_access_token},
38 types::{IdToken, client_credentials::ClientCredentials},
39};
40
41#[derive(Debug, Clone)]
43pub struct AuthorizationRequestData {
44 pub client_id: String,
46
47 pub scope: Scope,
52
53 pub redirect_uri: Url,
57
58 pub code_challenge_methods_supported: Option<Vec<PkceCodeChallengeMethod>>,
63
64 pub display: Option<Display>,
67
68 pub prompt: Option<Vec<Prompt>>,
73
74 pub max_age: Option<NonZeroU32>,
77
78 pub ui_locales: Option<Vec<LanguageTag>>,
80
81 pub id_token_hint: Option<String>,
85
86 pub login_hint: Option<String>,
89
90 pub acr_values: Option<HashSet<String>>,
92
93 pub response_mode: Option<ResponseMode>,
95}
96
97impl AuthorizationRequestData {
98 #[must_use]
101 pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self {
102 Self {
103 client_id,
104 scope,
105 redirect_uri,
106 code_challenge_methods_supported: None,
107 display: None,
108 prompt: None,
109 max_age: None,
110 ui_locales: None,
111 id_token_hint: None,
112 login_hint: None,
113 acr_values: None,
114 response_mode: None,
115 }
116 }
117
118 #[must_use]
121 pub fn with_code_challenge_methods_supported(
122 mut self,
123 code_challenge_methods_supported: Vec<PkceCodeChallengeMethod>,
124 ) -> Self {
125 self.code_challenge_methods_supported = Some(code_challenge_methods_supported);
126 self
127 }
128
129 #[must_use]
131 pub fn with_display(mut self, display: Display) -> Self {
132 self.display = Some(display);
133 self
134 }
135
136 #[must_use]
138 pub fn with_prompt(mut self, prompt: Vec<Prompt>) -> Self {
139 self.prompt = Some(prompt);
140 self
141 }
142
143 #[must_use]
145 pub fn with_max_age(mut self, max_age: NonZeroU32) -> Self {
146 self.max_age = Some(max_age);
147 self
148 }
149
150 #[must_use]
152 pub fn with_ui_locales(mut self, ui_locales: Vec<LanguageTag>) -> Self {
153 self.ui_locales = Some(ui_locales);
154 self
155 }
156
157 #[must_use]
159 pub fn with_id_token_hint(mut self, id_token_hint: String) -> Self {
160 self.id_token_hint = Some(id_token_hint);
161 self
162 }
163
164 #[must_use]
166 pub fn with_login_hint(mut self, login_hint: String) -> Self {
167 self.login_hint = Some(login_hint);
168 self
169 }
170
171 #[must_use]
173 pub fn with_acr_values(mut self, acr_values: HashSet<String>) -> Self {
174 self.acr_values = Some(acr_values);
175 self
176 }
177
178 #[must_use]
180 pub fn with_response_mode(mut self, response_mode: ResponseMode) -> Self {
181 self.response_mode = Some(response_mode);
182 self
183 }
184}
185
186#[derive(Debug, Clone, PartialEq, Eq)]
189pub struct AuthorizationValidationData {
190 pub state: String,
192
193 pub nonce: String,
195
196 pub redirect_uri: Url,
198
199 pub code_challenge_verifier: Option<String>,
201}
202
203#[derive(Clone, Serialize)]
204struct FullAuthorizationRequest {
205 #[serde(flatten)]
206 inner: AuthorizationRequest,
207
208 #[serde(flatten, skip_serializing_if = "Option::is_none")]
209 pkce: Option<pkce::AuthorizationRequest>,
210}
211
212fn build_authorization_request(
214 authorization_data: AuthorizationRequestData,
215 rng: &mut impl Rng,
216) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
217 let AuthorizationRequestData {
218 client_id,
219 mut scope,
220 redirect_uri,
221 code_challenge_methods_supported,
222 display,
223 prompt,
224 max_age,
225 ui_locales,
226 id_token_hint,
227 login_hint,
228 acr_values,
229 response_mode,
230 } = authorization_data;
231
232 let state = Alphanumeric.sample_string(rng, 16);
234 let nonce = Alphanumeric.sample_string(rng, 16);
235
236 let (pkce, code_challenge_verifier) = if code_challenge_methods_supported
238 .iter()
239 .any(|methods| methods.contains(&PkceCodeChallengeMethod::S256))
240 {
241 let mut verifier = [0u8; 32];
242 rng.fill(&mut verifier);
243
244 let method = PkceCodeChallengeMethod::S256;
245 let verifier = Base64UrlUnpadded::encode_string(&verifier);
246 let code_challenge = method.compute_challenge(&verifier)?.into();
247
248 let pkce = pkce::AuthorizationRequest {
249 code_challenge_method: method,
250 code_challenge,
251 };
252
253 (Some(pkce), Some(verifier))
254 } else {
255 (None, None)
256 };
257
258 scope.insert(OPENID);
259
260 let auth_request = FullAuthorizationRequest {
261 inner: AuthorizationRequest {
262 response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
263 client_id,
264 redirect_uri: Some(redirect_uri.clone()),
265 scope,
266 state: Some(state.clone()),
267 response_mode,
268 nonce: Some(nonce.clone()),
269 display,
270 prompt,
271 max_age,
272 ui_locales,
273 id_token_hint,
274 login_hint,
275 acr_values,
276 request: None,
277 request_uri: None,
278 registration: None,
279 },
280 pkce,
281 };
282
283 let auth_data = AuthorizationValidationData {
284 state,
285 nonce,
286 redirect_uri,
287 code_challenge_verifier,
288 };
289
290 Ok((auth_request, auth_data))
291}
292
293pub fn build_authorization_url(
324 authorization_endpoint: Url,
325 authorization_data: AuthorizationRequestData,
326 rng: &mut impl Rng,
327) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
328 tracing::debug!(
329 scope = ?authorization_data.scope,
330 "Authorizing..."
331 );
332
333 let (authorization_request, validation_data) =
334 build_authorization_request(authorization_data, rng)?;
335
336 let authorization_query = serde_urlencoded::to_string(authorization_request)?;
337
338 let mut authorization_url = authorization_endpoint;
339
340 let mut full_query = authorization_url
342 .query()
343 .map(ToOwned::to_owned)
344 .unwrap_or_default();
345 if !full_query.is_empty() {
346 full_query.push('&');
347 }
348 full_query.push_str(&authorization_query);
349
350 authorization_url.set_query(Some(&full_query));
351
352 Ok((authorization_url, validation_data))
353}
354
355#[allow(clippy::too_many_arguments)]
393#[tracing::instrument(skip_all, fields(token_endpoint))]
394pub async fn access_token_with_authorization_code(
395 http_client: &reqwest::Client,
396 client_credentials: ClientCredentials,
397 token_endpoint: &Url,
398 code: String,
399 validation_data: AuthorizationValidationData,
400 id_token_verification_data: Option<JwtVerificationData<'_>>,
401 now: DateTime<Utc>,
402 rng: &mut impl Rng,
403) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
404 tracing::debug!("Exchanging authorization code for access token...");
405
406 let token_response = request_access_token(
407 http_client,
408 client_credentials,
409 token_endpoint,
410 AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
411 code: code.clone(),
412 redirect_uri: Some(validation_data.redirect_uri),
413 code_verifier: validation_data.code_challenge_verifier,
414 }),
415 now,
416 rng,
417 )
418 .await?;
419
420 let id_token = if let Some(verification_data) = id_token_verification_data {
421 let signing_alg = verification_data.signing_algorithm;
422
423 let id_token = token_response
424 .id_token
425 .as_deref()
426 .ok_or(IdTokenError::MissingIdToken)?;
427
428 let id_token = verify_id_token(id_token, verification_data, None, now)?;
429
430 let mut claims = id_token.payload().clone();
431
432 claims::AT_HASH
434 .extract_optional_with_options(
435 &mut claims,
436 TokenHash::new(signing_alg, &token_response.access_token),
437 )
438 .map_err(IdTokenError::from)?;
439
440 claims::C_HASH
442 .extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
443 .map_err(IdTokenError::from)?;
444
445 claims::NONCE
447 .extract_required_with_options(&mut claims, validation_data.nonce.as_str())
448 .map_err(IdTokenError::from)?;
449
450 Some(id_token.into_owned())
451 } else {
452 None
453 };
454
455 Ok((token_response, id_token))
456}