1use std::num::NonZeroU32;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
13};
14use mas_iana::oauth::PkceCodeChallengeMethod;
15use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
16use oauth2_types::{requests::ResponseMode, scope::Scope};
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use url::Url;
21use uuid::Uuid;
22
23use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
24
/// An implementation of [`OAuth2AuthorizationGrantRepository`] backed by a
/// PostgreSQL connection.
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
    /// Connection on which every query of this repository is executed.
    conn: &'c mut PgConnection,
}
30
31impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
32 pub fn new(conn: &'c mut PgConnection) -> Self {
35 Self { conn }
36 }
37}
38
/// A raw row of the `oauth2_authorization_grants` table.
///
/// This is the shape selected by the lookup queries of the repository. It is
/// decoded into an [`AuthorizationGrant`] through its `TryFrom`
/// implementation, which validates the combinations of columns.
// The boolean fields mirror boolean columns of the table one-to-one, so
// folding them into an enum would only obscure the row mapping.
#[allow(clippy::struct_excessive_bools)]
struct GrantLookup {
    // Grant ID; a ULID stored as a UUID.
    oauth2_authorization_grant_id: Uuid,
    created_at: DateTime<Utc>,
    // Lifecycle timestamps. Together with `oauth2_session_id` they determine
    // the stage of the grant; only some combinations are accepted on decode.
    cancelled_at: Option<DateTime<Utc>>,
    fulfilled_at: Option<DateTime<Utc>>,
    exchanged_at: Option<DateTime<Utc>>,
    // Serialized `Scope`, parsed back on decode.
    scope: String,
    state: Option<String>,
    nonce: Option<String>,
    // Serialized `Url`, parsed back on decode.
    redirect_uri: String,
    // Serialized `ResponseMode`, parsed back on decode.
    response_mode: String,
    // Must be positive to decode into a `NonZeroU32`.
    max_age: Option<i32>,
    // Set exactly when `authorization_code` is present.
    response_type_code: bool,
    response_type_id_token: bool,
    authorization_code: Option<String>,
    // PKCE columns: both NULL, or both set with the method being "plain" or
    // "S256".
    code_challenge: Option<String>,
    code_challenge_method: Option<String>,
    requires_consent: bool,
    login_hint: Option<String>,
    oauth2_client_id: Uuid,
    // Set when the grant is fulfilled.
    oauth2_session_id: Option<Uuid>,
}
62
63impl TryFrom<GrantLookup> for AuthorizationGrant {
64 type Error = DatabaseInconsistencyError;
65
66 #[allow(clippy::too_many_lines)]
67 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
68 let id = value.oauth2_authorization_grant_id.into();
69 let scope: Scope = value.scope.parse().map_err(|e| {
70 DatabaseInconsistencyError::on("oauth2_authorization_grants")
71 .column("scope")
72 .row(id)
73 .source(e)
74 })?;
75
76 let stage = match (
77 value.fulfilled_at,
78 value.exchanged_at,
79 value.cancelled_at,
80 value.oauth2_session_id,
81 ) {
82 (None, None, None, None) => AuthorizationGrantStage::Pending,
83 (Some(fulfilled_at), None, None, Some(session_id)) => {
84 AuthorizationGrantStage::Fulfilled {
85 session_id: session_id.into(),
86 fulfilled_at,
87 }
88 }
89 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
90 AuthorizationGrantStage::Exchanged {
91 session_id: session_id.into(),
92 fulfilled_at,
93 exchanged_at,
94 }
95 }
96 (None, None, Some(cancelled_at), None) => {
97 AuthorizationGrantStage::Cancelled { cancelled_at }
98 }
99 _ => {
100 return Err(
101 DatabaseInconsistencyError::on("oauth2_authorization_grants")
102 .column("stage")
103 .row(id),
104 );
105 }
106 };
107
108 let pkce = match (value.code_challenge, value.code_challenge_method) {
109 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
110 Some(Pkce {
111 challenge_method: PkceCodeChallengeMethod::Plain,
112 challenge,
113 })
114 }
115 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
116 challenge_method: PkceCodeChallengeMethod::S256,
117 challenge,
118 }),
119 (None, None) => None,
120 _ => {
121 return Err(
122 DatabaseInconsistencyError::on("oauth2_authorization_grants")
123 .column("code_challenge_method")
124 .row(id),
125 );
126 }
127 };
128
129 let code: Option<AuthorizationCode> =
130 match (value.response_type_code, value.authorization_code, pkce) {
131 (false, None, None) => None,
132 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
133 _ => {
134 return Err(
135 DatabaseInconsistencyError::on("oauth2_authorization_grants")
136 .column("authorization_code")
137 .row(id),
138 );
139 }
140 };
141
142 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
143 DatabaseInconsistencyError::on("oauth2_authorization_grants")
144 .column("redirect_uri")
145 .row(id)
146 .source(e)
147 })?;
148
149 let response_mode = value.response_mode.parse().map_err(|e| {
150 DatabaseInconsistencyError::on("oauth2_authorization_grants")
151 .column("response_mode")
152 .row(id)
153 .source(e)
154 })?;
155
156 let max_age = value
157 .max_age
158 .map(u32::try_from)
159 .transpose()
160 .map_err(|e| {
161 DatabaseInconsistencyError::on("oauth2_authorization_grants")
162 .column("max_age")
163 .row(id)
164 .source(e)
165 })?
166 .map(NonZeroU32::try_from)
167 .transpose()
168 .map_err(|e| {
169 DatabaseInconsistencyError::on("oauth2_authorization_grants")
170 .column("max_age")
171 .row(id)
172 .source(e)
173 })?;
174
175 Ok(AuthorizationGrant {
176 id,
177 stage,
178 client_id: value.oauth2_client_id.into(),
179 code,
180 scope,
181 state: value.state,
182 nonce: value.nonce,
183 max_age,
184 response_mode,
185 redirect_uri,
186 created_at: value.created_at,
187 response_type_id_token: value.response_type_id_token,
188 requires_consent: value.requires_consent,
189 login_hint: value.login_hint,
190 })
191 }
192}
193
194#[async_trait]
195impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
196 type Error = DatabaseError;
197
198 #[tracing::instrument(
199 name = "db.oauth2_authorization_grant.add",
200 skip_all,
201 fields(
202 db.query.text,
203 grant.id,
204 grant.scope = %scope,
205 %client.id,
206 ),
207 err,
208 )]
209 async fn add(
210 &mut self,
211 rng: &mut (dyn RngCore + Send),
212 clock: &dyn Clock,
213 client: &Client,
214 redirect_uri: Url,
215 scope: Scope,
216 code: Option<AuthorizationCode>,
217 state: Option<String>,
218 nonce: Option<String>,
219 max_age: Option<NonZeroU32>,
220 response_mode: ResponseMode,
221 response_type_id_token: bool,
222 requires_consent: bool,
223 login_hint: Option<String>,
224 ) -> Result<AuthorizationGrant, Self::Error> {
225 let code_challenge = code
226 .as_ref()
227 .and_then(|c| c.pkce.as_ref())
228 .map(|p| &p.challenge);
229 let code_challenge_method = code
230 .as_ref()
231 .and_then(|c| c.pkce.as_ref())
232 .map(|p| p.challenge_method.to_string());
233 let max_age_i32 = max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX));
235 let code_str = code.as_ref().map(|c| &c.code);
236
237 let created_at = clock.now();
238 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
239 tracing::Span::current().record("grant.id", tracing::field::display(id));
240
241 sqlx::query!(
242 r#"
243 INSERT INTO oauth2_authorization_grants (
244 oauth2_authorization_grant_id,
245 oauth2_client_id,
246 redirect_uri,
247 scope,
248 state,
249 nonce,
250 max_age,
251 response_mode,
252 code_challenge,
253 code_challenge_method,
254 response_type_code,
255 response_type_id_token,
256 authorization_code,
257 requires_consent,
258 login_hint,
259 created_at
260 )
261 VALUES
262 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
263 "#,
264 Uuid::from(id),
265 Uuid::from(client.id),
266 redirect_uri.to_string(),
267 scope.to_string(),
268 state,
269 nonce,
270 max_age_i32,
271 response_mode.to_string(),
272 code_challenge,
273 code_challenge_method,
274 code.is_some(),
275 response_type_id_token,
276 code_str,
277 requires_consent,
278 login_hint,
279 created_at,
280 )
281 .traced()
282 .execute(&mut *self.conn)
283 .await?;
284
285 Ok(AuthorizationGrant {
286 id,
287 stage: AuthorizationGrantStage::Pending,
288 code,
289 redirect_uri,
290 client_id: client.id,
291 scope,
292 state,
293 nonce,
294 max_age,
295 response_mode,
296 created_at,
297 response_type_id_token,
298 requires_consent,
299 login_hint,
300 })
301 }
302
303 #[tracing::instrument(
304 name = "db.oauth2_authorization_grant.lookup",
305 skip_all,
306 fields(
307 db.query.text,
308 grant.id = %id,
309 ),
310 err,
311 )]
312 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
313 let res = sqlx::query_as!(
314 GrantLookup,
315 r#"
316 SELECT oauth2_authorization_grant_id
317 , created_at
318 , cancelled_at
319 , fulfilled_at
320 , exchanged_at
321 , scope
322 , state
323 , redirect_uri
324 , response_mode
325 , nonce
326 , max_age
327 , oauth2_client_id
328 , authorization_code
329 , response_type_code
330 , response_type_id_token
331 , code_challenge
332 , code_challenge_method
333 , requires_consent
334 , login_hint
335 , oauth2_session_id
336 FROM
337 oauth2_authorization_grants
338
339 WHERE oauth2_authorization_grant_id = $1
340 "#,
341 Uuid::from(id),
342 )
343 .traced()
344 .fetch_optional(&mut *self.conn)
345 .await?;
346
347 let Some(res) = res else { return Ok(None) };
348
349 Ok(Some(res.try_into()?))
350 }
351
352 #[tracing::instrument(
353 name = "db.oauth2_authorization_grant.find_by_code",
354 skip_all,
355 fields(
356 db.query.text,
357 ),
358 err,
359 )]
360 async fn find_by_code(
361 &mut self,
362 code: &str,
363 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
364 let res = sqlx::query_as!(
365 GrantLookup,
366 r#"
367 SELECT oauth2_authorization_grant_id
368 , created_at
369 , cancelled_at
370 , fulfilled_at
371 , exchanged_at
372 , scope
373 , state
374 , redirect_uri
375 , response_mode
376 , nonce
377 , max_age
378 , oauth2_client_id
379 , authorization_code
380 , response_type_code
381 , response_type_id_token
382 , code_challenge
383 , code_challenge_method
384 , requires_consent
385 , login_hint
386 , oauth2_session_id
387 FROM
388 oauth2_authorization_grants
389
390 WHERE authorization_code = $1
391 "#,
392 code,
393 )
394 .traced()
395 .fetch_optional(&mut *self.conn)
396 .await?;
397
398 let Some(res) = res else { return Ok(None) };
399
400 Ok(Some(res.try_into()?))
401 }
402
403 #[tracing::instrument(
404 name = "db.oauth2_authorization_grant.fulfill",
405 skip_all,
406 fields(
407 db.query.text,
408 %grant.id,
409 client.id = %grant.client_id,
410 %session.id,
411 ),
412 err,
413 )]
414 async fn fulfill(
415 &mut self,
416 clock: &dyn Clock,
417 session: &Session,
418 grant: AuthorizationGrant,
419 ) -> Result<AuthorizationGrant, Self::Error> {
420 let fulfilled_at = clock.now();
421 let res = sqlx::query!(
422 r#"
423 UPDATE oauth2_authorization_grants
424 SET fulfilled_at = $2
425 , oauth2_session_id = $3
426 WHERE oauth2_authorization_grant_id = $1
427 "#,
428 Uuid::from(grant.id),
429 fulfilled_at,
430 Uuid::from(session.id),
431 )
432 .traced()
433 .execute(&mut *self.conn)
434 .await?;
435
436 DatabaseError::ensure_affected_rows(&res, 1)?;
437
438 let grant = grant
440 .fulfill(fulfilled_at, session)
441 .map_err(DatabaseError::to_invalid_operation)?;
442
443 Ok(grant)
444 }
445
446 #[tracing::instrument(
447 name = "db.oauth2_authorization_grant.exchange",
448 skip_all,
449 fields(
450 db.query.text,
451 %grant.id,
452 client.id = %grant.client_id,
453 ),
454 err,
455 )]
456 async fn exchange(
457 &mut self,
458 clock: &dyn Clock,
459 grant: AuthorizationGrant,
460 ) -> Result<AuthorizationGrant, Self::Error> {
461 let exchanged_at = clock.now();
462 let res = sqlx::query!(
463 r#"
464 UPDATE oauth2_authorization_grants
465 SET exchanged_at = $2
466 WHERE oauth2_authorization_grant_id = $1
467 "#,
468 Uuid::from(grant.id),
469 exchanged_at,
470 )
471 .traced()
472 .execute(&mut *self.conn)
473 .await?;
474
475 DatabaseError::ensure_affected_rows(&res, 1)?;
476
477 let grant = grant
478 .exchange(exchanged_at)
479 .map_err(DatabaseError::to_invalid_operation)?;
480
481 Ok(grant)
482 }
483
484 #[tracing::instrument(
485 name = "db.oauth2_authorization_grant.give_consent",
486 skip_all,
487 fields(
488 db.query.text,
489 %grant.id,
490 client.id = %grant.client_id,
491 ),
492 err,
493 )]
494 async fn give_consent(
495 &mut self,
496 mut grant: AuthorizationGrant,
497 ) -> Result<AuthorizationGrant, Self::Error> {
498 sqlx::query!(
499 r#"
500 UPDATE oauth2_authorization_grants AS og
501 SET
502 requires_consent = 'f'
503 WHERE
504 og.oauth2_authorization_grant_id = $1
505 "#,
506 Uuid::from(grant.id),
507 )
508 .traced()
509 .execute(&mut *self.conn)
510 .await?;
511
512 grant.requires_consent = false;
513
514 Ok(grant)
515 }
516}