1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{Clock, UpstreamOAuthLink, UpstreamOAuthProvider, User};
11use mas_storage::{
12 Page, Pagination,
13 pagination::Node,
14 upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
15};
16use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use tracing::Instrument;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError,
27 filter::{Filter, StatementExt},
28 iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
/// PostgreSQL-backed implementation of [`UpstreamOAuthLinkRepository`].
///
/// The repository does not own its connection: it holds a mutable borrow for
/// the lifetime `'c`, and every query is executed on that connection.
pub struct PgUpstreamOAuthLinkRepository<'c> {
    /// Connection on which all statements of this repository are executed.
    conn: &'c mut PgConnection,
}
38
39impl<'c> PgUpstreamOAuthLinkRepository<'c> {
40 pub fn new(conn: &'c mut PgConnection) -> Self {
43 Self { conn }
44 }
45}
46
/// Raw row of the `upstream_oauth_links` table, as selected by the queries in
/// this module.
///
/// `#[enum_def]` generates the `LinkLookupIden` enum, which `list` uses to
/// alias the selected columns to these field names.
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
    /// Primary identifier of the link; also used as the pagination cursor.
    upstream_oauth_link_id: Uuid,
    /// Identifier of the upstream provider the link belongs to.
    upstream_oauth_provider_id: Uuid,
    /// Identifier of the associated user, `None` while the link is not
    /// associated to any user.
    user_id: Option<Uuid>,
    /// Subject of the account on the upstream provider.
    subject: String,
    /// Optional human-readable name of the upstream account.
    human_account_name: Option<String>,
    /// When the link was created.
    created_at: DateTime<Utc>,
}
57
58impl Node<Ulid> for LinkLookup {
59 fn cursor(&self) -> Ulid {
60 self.upstream_oauth_link_id.into()
61 }
62}
63
64impl From<LinkLookup> for UpstreamOAuthLink {
65 fn from(value: LinkLookup) -> Self {
66 UpstreamOAuthLink {
67 id: Ulid::from(value.upstream_oauth_link_id),
68 provider_id: Ulid::from(value.upstream_oauth_provider_id),
69 user_id: value.user_id.map(Ulid::from),
70 subject: value.subject,
71 human_account_name: value.human_account_name,
72 created_at: value.created_at,
73 }
74 }
75}
76
77impl Filter for UpstreamOAuthLinkFilter<'_> {
78 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
79 sea_query::Condition::all()
80 .add_option(self.user().map(|user| {
81 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
82 .eq(Uuid::from(user.id))
83 }))
84 .add_option(self.provider().map(|provider| {
85 Expr::col((
86 UpstreamOAuthLinks::Table,
87 UpstreamOAuthLinks::UpstreamOAuthProviderId,
88 ))
89 .eq(Uuid::from(provider.id))
90 }))
91 .add_option(self.provider_enabled().map(|enabled| {
92 Expr::col((
93 UpstreamOAuthLinks::Table,
94 UpstreamOAuthLinks::UpstreamOAuthProviderId,
95 ))
96 .eq(Expr::any(
97 Query::select()
98 .expr(Expr::col((
99 UpstreamOAuthProviders::Table,
100 UpstreamOAuthProviders::UpstreamOAuthProviderId,
101 )))
102 .from(UpstreamOAuthProviders::Table)
103 .and_where(
104 Expr::col((
105 UpstreamOAuthProviders::Table,
106 UpstreamOAuthProviders::DisabledAt,
107 ))
108 .is_null()
109 .eq(enabled),
110 )
111 .take(),
112 ))
113 }))
114 .add_option(self.subject().map(|subject| {
115 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
116 }))
117 }
118}
119
120#[async_trait]
121impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
122 type Error = DatabaseError;
123
124 #[tracing::instrument(
125 name = "db.upstream_oauth_link.lookup",
126 skip_all,
127 fields(
128 db.query.text,
129 upstream_oauth_link.id = %id,
130 ),
131 err,
132 )]
133 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
134 let res = sqlx::query_as!(
135 LinkLookup,
136 r#"
137 SELECT
138 upstream_oauth_link_id,
139 upstream_oauth_provider_id,
140 user_id,
141 subject,
142 human_account_name,
143 created_at
144 FROM upstream_oauth_links
145 WHERE upstream_oauth_link_id = $1
146 "#,
147 Uuid::from(id),
148 )
149 .traced()
150 .fetch_optional(&mut *self.conn)
151 .await?
152 .map(Into::into);
153
154 Ok(res)
155 }
156
157 #[tracing::instrument(
158 name = "db.upstream_oauth_link.find_by_subject",
159 skip_all,
160 fields(
161 db.query.text,
162 upstream_oauth_link.subject = subject,
163 %upstream_oauth_provider.id,
164 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
165 %upstream_oauth_provider.client_id,
166 ),
167 err,
168 )]
169 async fn find_by_subject(
170 &mut self,
171 upstream_oauth_provider: &UpstreamOAuthProvider,
172 subject: &str,
173 ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
174 let res = sqlx::query_as!(
175 LinkLookup,
176 r#"
177 SELECT
178 upstream_oauth_link_id,
179 upstream_oauth_provider_id,
180 user_id,
181 subject,
182 human_account_name,
183 created_at
184 FROM upstream_oauth_links
185 WHERE upstream_oauth_provider_id = $1
186 AND subject = $2
187 "#,
188 Uuid::from(upstream_oauth_provider.id),
189 subject,
190 )
191 .traced()
192 .fetch_optional(&mut *self.conn)
193 .await?
194 .map(Into::into);
195
196 Ok(res)
197 }
198
199 #[tracing::instrument(
200 name = "db.upstream_oauth_link.add",
201 skip_all,
202 fields(
203 db.query.text,
204 upstream_oauth_link.id,
205 upstream_oauth_link.subject = subject,
206 upstream_oauth_link.human_account_name = human_account_name,
207 %upstream_oauth_provider.id,
208 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
209 %upstream_oauth_provider.client_id,
210 ),
211 err,
212 )]
213 async fn add(
214 &mut self,
215 rng: &mut (dyn RngCore + Send),
216 clock: &dyn Clock,
217 upstream_oauth_provider: &UpstreamOAuthProvider,
218 subject: String,
219 human_account_name: Option<String>,
220 ) -> Result<UpstreamOAuthLink, Self::Error> {
221 let created_at = clock.now();
222 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
223 tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
224
225 sqlx::query!(
226 r#"
227 INSERT INTO upstream_oauth_links (
228 upstream_oauth_link_id,
229 upstream_oauth_provider_id,
230 user_id,
231 subject,
232 human_account_name,
233 created_at
234 ) VALUES ($1, $2, NULL, $3, $4, $5)
235 "#,
236 Uuid::from(id),
237 Uuid::from(upstream_oauth_provider.id),
238 &subject,
239 human_account_name.as_deref(),
240 created_at,
241 )
242 .traced()
243 .execute(&mut *self.conn)
244 .await?;
245
246 Ok(UpstreamOAuthLink {
247 id,
248 provider_id: upstream_oauth_provider.id,
249 user_id: None,
250 subject,
251 human_account_name,
252 created_at,
253 })
254 }
255
256 #[tracing::instrument(
257 name = "db.upstream_oauth_link.associate_to_user",
258 skip_all,
259 fields(
260 db.query.text,
261 %upstream_oauth_link.id,
262 %upstream_oauth_link.subject,
263 %user.id,
264 %user.username,
265 ),
266 err,
267 )]
268 async fn associate_to_user(
269 &mut self,
270 upstream_oauth_link: &UpstreamOAuthLink,
271 user: &User,
272 ) -> Result<(), Self::Error> {
273 sqlx::query!(
274 r#"
275 UPDATE upstream_oauth_links
276 SET user_id = $1
277 WHERE upstream_oauth_link_id = $2
278 "#,
279 Uuid::from(user.id),
280 Uuid::from(upstream_oauth_link.id),
281 )
282 .traced()
283 .execute(&mut *self.conn)
284 .await?;
285
286 Ok(())
287 }
288
289 #[tracing::instrument(
290 name = "db.upstream_oauth_link.list",
291 skip_all,
292 fields(
293 db.query.text,
294 ),
295 err,
296 )]
297 async fn list(
298 &mut self,
299 filter: UpstreamOAuthLinkFilter<'_>,
300 pagination: Pagination,
301 ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
302 let (sql, arguments) = Query::select()
303 .expr_as(
304 Expr::col((
305 UpstreamOAuthLinks::Table,
306 UpstreamOAuthLinks::UpstreamOAuthLinkId,
307 )),
308 LinkLookupIden::UpstreamOauthLinkId,
309 )
310 .expr_as(
311 Expr::col((
312 UpstreamOAuthLinks::Table,
313 UpstreamOAuthLinks::UpstreamOAuthProviderId,
314 )),
315 LinkLookupIden::UpstreamOauthProviderId,
316 )
317 .expr_as(
318 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
319 LinkLookupIden::UserId,
320 )
321 .expr_as(
322 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
323 LinkLookupIden::Subject,
324 )
325 .expr_as(
326 Expr::col((
327 UpstreamOAuthLinks::Table,
328 UpstreamOAuthLinks::HumanAccountName,
329 )),
330 LinkLookupIden::HumanAccountName,
331 )
332 .expr_as(
333 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
334 LinkLookupIden::CreatedAt,
335 )
336 .from(UpstreamOAuthLinks::Table)
337 .apply_filter(filter)
338 .generate_pagination(
339 (
340 UpstreamOAuthLinks::Table,
341 UpstreamOAuthLinks::UpstreamOAuthLinkId,
342 ),
343 pagination,
344 )
345 .build_sqlx(PostgresQueryBuilder);
346
347 let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
348 .traced()
349 .fetch_all(&mut *self.conn)
350 .await?;
351
352 let page = pagination.process(edges).map(UpstreamOAuthLink::from);
353
354 Ok(page)
355 }
356
357 #[tracing::instrument(
358 name = "db.upstream_oauth_link.count",
359 skip_all,
360 fields(
361 db.query.text,
362 ),
363 err,
364 )]
365 async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
366 let (sql, arguments) = Query::select()
367 .expr(
368 Expr::col((
369 UpstreamOAuthLinks::Table,
370 UpstreamOAuthLinks::UpstreamOAuthLinkId,
371 ))
372 .count(),
373 )
374 .from(UpstreamOAuthLinks::Table)
375 .apply_filter(filter)
376 .build_sqlx(PostgresQueryBuilder);
377
378 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
379 .traced()
380 .fetch_one(&mut *self.conn)
381 .await?;
382
383 count
384 .try_into()
385 .map_err(DatabaseError::to_invalid_operation)
386 }
387
388 #[tracing::instrument(
389 name = "db.upstream_oauth_link.remove",
390 skip_all,
391 fields(
392 db.query.text,
393 upstream_oauth_link.id,
394 upstream_oauth_link.provider_id,
395 %upstream_oauth_link.subject,
396 ),
397 err,
398 )]
399 async fn remove(
400 &mut self,
401 clock: &dyn Clock,
402 upstream_oauth_link: UpstreamOAuthLink,
403 ) -> Result<(), Self::Error> {
404 let span = tracing::info_span!(
407 "db.upstream_oauth_link.remove.unlink",
408 { DB_QUERY_TEXT } = tracing::field::Empty
409 );
410 sqlx::query!(
411 r#"
412 UPDATE upstream_oauth_authorization_sessions SET
413 upstream_oauth_link_id = NULL,
414 unlinked_at = $2
415 WHERE upstream_oauth_link_id = $1
416 "#,
417 Uuid::from(upstream_oauth_link.id),
418 clock.now()
419 )
420 .record(&span)
421 .execute(&mut *self.conn)
422 .instrument(span)
423 .await?;
424
425 let span = tracing::info_span!(
427 "db.upstream_oauth_link.remove.delete",
428 { DB_QUERY_TEXT } = tracing::field::Empty
429 );
430 let res = sqlx::query!(
431 r#"
432 DELETE FROM upstream_oauth_links
433 WHERE upstream_oauth_link_id = $1
434 "#,
435 Uuid::from(upstream_oauth_link.id),
436 )
437 .record(&span)
438 .execute(&mut *self.conn)
439 .instrument(span)
440 .await?;
441
442 DatabaseError::ensure_affected_rows(&res, 1)?;
443
444 Ok(())
445 }
446
447 #[tracing::instrument(
448 name = "db.upstream_oauth_link.cleanup_orphaned",
449 skip_all,
450 fields(
451 db.query.text,
452 since = since.map(tracing::field::display),
453 until = %until,
454 limit = limit,
455 ),
456 err,
457 )]
458 async fn cleanup_orphaned(
459 &mut self,
460 since: Option<Ulid>,
461 until: Ulid,
462 limit: usize,
463 ) -> Result<(usize, Option<Ulid>), Self::Error> {
464 let res = sqlx::query_scalar!(
468 r#"
469 WITH
470 to_delete AS (
471 SELECT upstream_oauth_link_id
472 FROM upstream_oauth_links
473 WHERE user_id IS NULL
474 AND ($1::uuid IS NULL OR upstream_oauth_link_id > $1)
475 AND upstream_oauth_link_id <= $2
476 ORDER BY upstream_oauth_link_id
477 LIMIT $3
478 ),
479 deleted_sessions AS (
480 DELETE FROM upstream_oauth_authorization_sessions
481 USING to_delete
482 WHERE upstream_oauth_authorization_sessions.upstream_oauth_link_id = to_delete.upstream_oauth_link_id
483 )
484 DELETE FROM upstream_oauth_links
485 USING to_delete
486 WHERE upstream_oauth_links.upstream_oauth_link_id = to_delete.upstream_oauth_link_id
487 RETURNING upstream_oauth_links.upstream_oauth_link_id
488 "#,
489 since.map(Uuid::from),
490 Uuid::from(until),
491 i64::try_from(limit).unwrap_or(i64::MAX)
492 )
493 .traced()
494 .fetch_all(&mut *self.conn)
495 .await?;
496
497 let count = res.len();
498 let max_id = res.into_iter().max();
499
500 Ok((count, max_id.map(Ulid::from)))
501 }
502}