1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22 DatabaseError,
23 filter::{Filter, StatementExt},
24 iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
25 pagination::QueryBuilderExt,
26 tracing::ExecuteExt,
27};
28
29pub struct PgUpstreamOAuthLinkRepository<'c> {
32 conn: &'c mut PgConnection,
33}
34
35impl<'c> PgUpstreamOAuthLinkRepository<'c> {
36 pub fn new(conn: &'c mut PgConnection) -> Self {
39 Self { conn }
40 }
41}
42
43#[derive(sqlx::FromRow)]
44#[enum_def]
45struct LinkLookup {
46 upstream_oauth_link_id: Uuid,
47 upstream_oauth_provider_id: Uuid,
48 user_id: Option<Uuid>,
49 subject: String,
50 human_account_name: Option<String>,
51 created_at: DateTime<Utc>,
52}
53
54impl From<LinkLookup> for UpstreamOAuthLink {
55 fn from(value: LinkLookup) -> Self {
56 UpstreamOAuthLink {
57 id: Ulid::from(value.upstream_oauth_link_id),
58 provider_id: Ulid::from(value.upstream_oauth_provider_id),
59 user_id: value.user_id.map(Ulid::from),
60 subject: value.subject,
61 human_account_name: value.human_account_name,
62 created_at: value.created_at,
63 }
64 }
65}
66
67impl Filter for UpstreamOAuthLinkFilter<'_> {
68 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
69 sea_query::Condition::all()
70 .add_option(self.user().map(|user| {
71 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
72 .eq(Uuid::from(user.id))
73 }))
74 .add_option(self.provider().map(|provider| {
75 Expr::col((
76 UpstreamOAuthLinks::Table,
77 UpstreamOAuthLinks::UpstreamOAuthProviderId,
78 ))
79 .eq(Uuid::from(provider.id))
80 }))
81 .add_option(self.provider_enabled().map(|enabled| {
82 Expr::col((
83 UpstreamOAuthLinks::Table,
84 UpstreamOAuthLinks::UpstreamOAuthProviderId,
85 ))
86 .eq(Expr::any(
87 Query::select()
88 .expr(Expr::col((
89 UpstreamOAuthProviders::Table,
90 UpstreamOAuthProviders::UpstreamOAuthProviderId,
91 )))
92 .from(UpstreamOAuthProviders::Table)
93 .and_where(
94 Expr::col((
95 UpstreamOAuthProviders::Table,
96 UpstreamOAuthProviders::DisabledAt,
97 ))
98 .is_null()
99 .eq(enabled),
100 )
101 .take(),
102 ))
103 }))
104 .add_option(self.subject().map(|subject| {
105 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
106 }))
107 }
108}
109
110#[async_trait]
111impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
112 type Error = DatabaseError;
113
114 #[tracing::instrument(
115 name = "db.upstream_oauth_link.lookup",
116 skip_all,
117 fields(
118 db.query.text,
119 upstream_oauth_link.id = %id,
120 ),
121 err,
122 )]
123 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
124 let res = sqlx::query_as!(
125 LinkLookup,
126 r#"
127 SELECT
128 upstream_oauth_link_id,
129 upstream_oauth_provider_id,
130 user_id,
131 subject,
132 human_account_name,
133 created_at
134 FROM upstream_oauth_links
135 WHERE upstream_oauth_link_id = $1
136 "#,
137 Uuid::from(id),
138 )
139 .traced()
140 .fetch_optional(&mut *self.conn)
141 .await?
142 .map(Into::into);
143
144 Ok(res)
145 }
146
147 #[tracing::instrument(
148 name = "db.upstream_oauth_link.find_by_subject",
149 skip_all,
150 fields(
151 db.query.text,
152 upstream_oauth_link.subject = subject,
153 %upstream_oauth_provider.id,
154 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
155 %upstream_oauth_provider.client_id,
156 ),
157 err,
158 )]
159 async fn find_by_subject(
160 &mut self,
161 upstream_oauth_provider: &UpstreamOAuthProvider,
162 subject: &str,
163 ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
164 let res = sqlx::query_as!(
165 LinkLookup,
166 r#"
167 SELECT
168 upstream_oauth_link_id,
169 upstream_oauth_provider_id,
170 user_id,
171 subject,
172 human_account_name,
173 created_at
174 FROM upstream_oauth_links
175 WHERE upstream_oauth_provider_id = $1
176 AND subject = $2
177 "#,
178 Uuid::from(upstream_oauth_provider.id),
179 subject,
180 )
181 .traced()
182 .fetch_optional(&mut *self.conn)
183 .await?
184 .map(Into::into);
185
186 Ok(res)
187 }
188
189 #[tracing::instrument(
190 name = "db.upstream_oauth_link.add",
191 skip_all,
192 fields(
193 db.query.text,
194 upstream_oauth_link.id,
195 upstream_oauth_link.subject = subject,
196 upstream_oauth_link.human_account_name = human_account_name,
197 %upstream_oauth_provider.id,
198 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
199 %upstream_oauth_provider.client_id,
200 ),
201 err,
202 )]
203 async fn add(
204 &mut self,
205 rng: &mut (dyn RngCore + Send),
206 clock: &dyn Clock,
207 upstream_oauth_provider: &UpstreamOAuthProvider,
208 subject: String,
209 human_account_name: Option<String>,
210 ) -> Result<UpstreamOAuthLink, Self::Error> {
211 let created_at = clock.now();
212 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
213 tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
214
215 sqlx::query!(
216 r#"
217 INSERT INTO upstream_oauth_links (
218 upstream_oauth_link_id,
219 upstream_oauth_provider_id,
220 user_id,
221 subject,
222 human_account_name,
223 created_at
224 ) VALUES ($1, $2, NULL, $3, $4, $5)
225 "#,
226 Uuid::from(id),
227 Uuid::from(upstream_oauth_provider.id),
228 &subject,
229 human_account_name.as_deref(),
230 created_at,
231 )
232 .traced()
233 .execute(&mut *self.conn)
234 .await?;
235
236 Ok(UpstreamOAuthLink {
237 id,
238 provider_id: upstream_oauth_provider.id,
239 user_id: None,
240 subject,
241 human_account_name,
242 created_at,
243 })
244 }
245
246 #[tracing::instrument(
247 name = "db.upstream_oauth_link.associate_to_user",
248 skip_all,
249 fields(
250 db.query.text,
251 %upstream_oauth_link.id,
252 %upstream_oauth_link.subject,
253 %user.id,
254 %user.username,
255 ),
256 err,
257 )]
258 async fn associate_to_user(
259 &mut self,
260 upstream_oauth_link: &UpstreamOAuthLink,
261 user: &User,
262 ) -> Result<(), Self::Error> {
263 sqlx::query!(
264 r#"
265 UPDATE upstream_oauth_links
266 SET user_id = $1
267 WHERE upstream_oauth_link_id = $2
268 "#,
269 Uuid::from(user.id),
270 Uuid::from(upstream_oauth_link.id),
271 )
272 .traced()
273 .execute(&mut *self.conn)
274 .await?;
275
276 Ok(())
277 }
278
279 #[tracing::instrument(
280 name = "db.upstream_oauth_link.list",
281 skip_all,
282 fields(
283 db.query.text,
284 ),
285 err,
286 )]
287 async fn list(
288 &mut self,
289 filter: UpstreamOAuthLinkFilter<'_>,
290 pagination: Pagination,
291 ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
292 let (sql, arguments) = Query::select()
293 .expr_as(
294 Expr::col((
295 UpstreamOAuthLinks::Table,
296 UpstreamOAuthLinks::UpstreamOAuthLinkId,
297 )),
298 LinkLookupIden::UpstreamOauthLinkId,
299 )
300 .expr_as(
301 Expr::col((
302 UpstreamOAuthLinks::Table,
303 UpstreamOAuthLinks::UpstreamOAuthProviderId,
304 )),
305 LinkLookupIden::UpstreamOauthProviderId,
306 )
307 .expr_as(
308 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
309 LinkLookupIden::UserId,
310 )
311 .expr_as(
312 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
313 LinkLookupIden::Subject,
314 )
315 .expr_as(
316 Expr::col((
317 UpstreamOAuthLinks::Table,
318 UpstreamOAuthLinks::HumanAccountName,
319 )),
320 LinkLookupIden::HumanAccountName,
321 )
322 .expr_as(
323 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
324 LinkLookupIden::CreatedAt,
325 )
326 .from(UpstreamOAuthLinks::Table)
327 .apply_filter(filter)
328 .generate_pagination(
329 (
330 UpstreamOAuthLinks::Table,
331 UpstreamOAuthLinks::UpstreamOAuthLinkId,
332 ),
333 pagination,
334 )
335 .build_sqlx(PostgresQueryBuilder);
336
337 let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
338 .traced()
339 .fetch_all(&mut *self.conn)
340 .await?;
341
342 let page = pagination.process(edges).map(UpstreamOAuthLink::from);
343
344 Ok(page)
345 }
346
347 #[tracing::instrument(
348 name = "db.upstream_oauth_link.count",
349 skip_all,
350 fields(
351 db.query.text,
352 ),
353 err,
354 )]
355 async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
356 let (sql, arguments) = Query::select()
357 .expr(
358 Expr::col((
359 UpstreamOAuthLinks::Table,
360 UpstreamOAuthLinks::UpstreamOAuthLinkId,
361 ))
362 .count(),
363 )
364 .from(UpstreamOAuthLinks::Table)
365 .apply_filter(filter)
366 .build_sqlx(PostgresQueryBuilder);
367
368 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
369 .traced()
370 .fetch_one(&mut *self.conn)
371 .await?;
372
373 count
374 .try_into()
375 .map_err(DatabaseError::to_invalid_operation)
376 }
377}