1use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13 Clock,
14 user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::Users,
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod session;
36mod terms;
37
38#[cfg(test)]
39mod tests;
40
41pub use self::{
42 email::PgUserEmailRepository, password::PgUserPasswordRepository,
43 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
44 session::PgBrowserSessionRepository, terms::PgUserTermsRepository,
45};
46
/// An implementation of [`UserRepository`] backed by a PostgreSQL connection
pub struct PgUserRepository<'c> {
    /// Borrowed connection on which every query of this repository is executed
    conn: &'c mut PgConnection,
}
51
52impl<'c> PgUserRepository<'c> {
53 pub fn new(conn: &'c mut PgConnection) -> Self {
55 Self { conn }
56 }
57}
58
mod priv_ {
    // NOTE(review): the `#[enum_def]` attribute below presumably generates the
    // `UserLookupIden` enum without any documentation, which would trip the
    // `missing_docs` lint; that looks like the reason for this allow — confirm
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    // Shape of a row read from the `users` table. It is the target type of the
    // `sqlx::query_as!` lookups and of the dynamically built listing query,
    // and is converted into the `User` data model through `From`.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // Primary identifier of the user, as stored in the `user_id` column
        pub(super) user_id: Uuid,
        // Value of the `username` column
        pub(super) username: String,
        // Value of the `created_at` column
        pub(super) created_at: DateTime<Utc>,
        // Value of the `locked_at` column; `None` when the column is NULL
        pub(super) locked_at: Option<DateTime<Utc>>,
        // Value of the `deactivated_at` column; `None` when the column is NULL
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        // Value of the `can_request_admin` column
        pub(super) can_request_admin: bool,
    }
}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83 fn from(value: UserLookup) -> Self {
84 let id = value.user_id.into();
85 Self {
86 id,
87 username: value.username,
88 sub: id.to_string(),
89 created_at: value.created_at,
90 locked_at: value.locked_at,
91 deactivated_at: value.deactivated_at,
92 can_request_admin: value.can_request_admin,
93 }
94 }
95}
96
97impl Filter for UserFilter<'_> {
98 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99 sea_query::Condition::all()
100 .add_option(self.state().map(|state| {
101 match state {
102 mas_storage::user::UserState::Deactivated => {
103 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
104 }
105 mas_storage::user::UserState::Locked => {
106 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
107 }
108 mas_storage::user::UserState::Active => {
109 Expr::col((Users::Table, Users::LockedAt))
110 .is_null()
111 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
112 }
113 }
114 }))
115 .add_option(self.can_request_admin().map(|can_request_admin| {
116 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
117 }))
118 }
119}
120
121#[async_trait]
122impl UserRepository for PgUserRepository<'_> {
123 type Error = DatabaseError;
124
125 #[tracing::instrument(
126 name = "db.user.lookup",
127 skip_all,
128 fields(
129 db.query.text,
130 user.id = %id,
131 ),
132 err,
133 )]
134 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
135 let res = sqlx::query_as!(
136 UserLookup,
137 r#"
138 SELECT user_id
139 , username
140 , created_at
141 , locked_at
142 , deactivated_at
143 , can_request_admin
144 FROM users
145 WHERE user_id = $1
146 "#,
147 Uuid::from(id),
148 )
149 .traced()
150 .fetch_optional(&mut *self.conn)
151 .await?;
152
153 let Some(res) = res else { return Ok(None) };
154
155 Ok(Some(res.into()))
156 }
157
158 #[tracing::instrument(
159 name = "db.user.find_by_username",
160 skip_all,
161 fields(
162 db.query.text,
163 user.username = username,
164 ),
165 err,
166 )]
167 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
168 let res = sqlx::query_as!(
169 UserLookup,
170 r#"
171 SELECT user_id
172 , username
173 , created_at
174 , locked_at
175 , deactivated_at
176 , can_request_admin
177 FROM users
178 WHERE username = $1
179 "#,
180 username,
181 )
182 .traced()
183 .fetch_optional(&mut *self.conn)
184 .await?;
185
186 let Some(res) = res else { return Ok(None) };
187
188 Ok(Some(res.into()))
189 }
190
191 #[tracing::instrument(
192 name = "db.user.add",
193 skip_all,
194 fields(
195 db.query.text,
196 user.username = username,
197 user.id,
198 ),
199 err,
200 )]
201 async fn add(
202 &mut self,
203 rng: &mut (dyn RngCore + Send),
204 clock: &dyn Clock,
205 username: String,
206 ) -> Result<User, Self::Error> {
207 let created_at = clock.now();
208 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
209 tracing::Span::current().record("user.id", tracing::field::display(id));
210
211 let res = sqlx::query!(
212 r#"
213 INSERT INTO users (user_id, username, created_at)
214 VALUES ($1, $2, $3)
215 ON CONFLICT (username) DO NOTHING
216 "#,
217 Uuid::from(id),
218 username,
219 created_at,
220 )
221 .traced()
222 .execute(&mut *self.conn)
223 .await?;
224
225 DatabaseError::ensure_affected_rows(&res, 1)?;
228
229 Ok(User {
230 id,
231 username,
232 sub: id.to_string(),
233 created_at,
234 locked_at: None,
235 deactivated_at: None,
236 can_request_admin: false,
237 })
238 }
239
240 #[tracing::instrument(
241 name = "db.user.exists",
242 skip_all,
243 fields(
244 db.query.text,
245 user.username = username,
246 ),
247 err,
248 )]
249 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
250 let exists = sqlx::query_scalar!(
251 r#"
252 SELECT EXISTS(
253 SELECT 1 FROM users WHERE username = $1
254 ) AS "exists!"
255 "#,
256 username
257 )
258 .traced()
259 .fetch_one(&mut *self.conn)
260 .await?;
261
262 Ok(exists)
263 }
264
265 #[tracing::instrument(
266 name = "db.user.lock",
267 skip_all,
268 fields(
269 db.query.text,
270 %user.id,
271 ),
272 err,
273 )]
274 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
275 if user.locked_at.is_some() {
276 return Ok(user);
277 }
278
279 let locked_at = clock.now();
280 let res = sqlx::query!(
281 r#"
282 UPDATE users
283 SET locked_at = $1
284 WHERE user_id = $2
285 "#,
286 locked_at,
287 Uuid::from(user.id),
288 )
289 .traced()
290 .execute(&mut *self.conn)
291 .await?;
292
293 DatabaseError::ensure_affected_rows(&res, 1)?;
294
295 user.locked_at = Some(locked_at);
296
297 Ok(user)
298 }
299
300 #[tracing::instrument(
301 name = "db.user.unlock",
302 skip_all,
303 fields(
304 db.query.text,
305 %user.id,
306 ),
307 err,
308 )]
309 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
310 if user.locked_at.is_none() {
311 return Ok(user);
312 }
313
314 let res = sqlx::query!(
315 r#"
316 UPDATE users
317 SET locked_at = NULL
318 WHERE user_id = $1
319 "#,
320 Uuid::from(user.id),
321 )
322 .traced()
323 .execute(&mut *self.conn)
324 .await?;
325
326 DatabaseError::ensure_affected_rows(&res, 1)?;
327
328 user.locked_at = None;
329
330 Ok(user)
331 }
332
333 #[tracing::instrument(
334 name = "db.user.deactivate",
335 skip_all,
336 fields(
337 db.query.text,
338 %user.id,
339 ),
340 err,
341 )]
342 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
343 if user.deactivated_at.is_some() {
344 return Ok(user);
345 }
346
347 let deactivated_at = clock.now();
348 let res = sqlx::query!(
349 r#"
350 UPDATE users
351 SET deactivated_at = $2
352 WHERE user_id = $1
353 AND deactivated_at IS NULL
354 "#,
355 Uuid::from(user.id),
356 deactivated_at,
357 )
358 .traced()
359 .execute(&mut *self.conn)
360 .await?;
361
362 DatabaseError::ensure_affected_rows(&res, 1)?;
363
364 user.deactivated_at = Some(user.created_at);
365
366 Ok(user)
367 }
368
369 #[tracing::instrument(
370 name = "db.user.set_can_request_admin",
371 skip_all,
372 fields(
373 db.query.text,
374 %user.id,
375 user.can_request_admin = can_request_admin,
376 ),
377 err,
378 )]
379 async fn set_can_request_admin(
380 &mut self,
381 mut user: User,
382 can_request_admin: bool,
383 ) -> Result<User, Self::Error> {
384 let res = sqlx::query!(
385 r#"
386 UPDATE users
387 SET can_request_admin = $2
388 WHERE user_id = $1
389 "#,
390 Uuid::from(user.id),
391 can_request_admin,
392 )
393 .traced()
394 .execute(&mut *self.conn)
395 .await?;
396
397 DatabaseError::ensure_affected_rows(&res, 1)?;
398
399 user.can_request_admin = can_request_admin;
400
401 Ok(user)
402 }
403
404 #[tracing::instrument(
405 name = "db.user.list",
406 skip_all,
407 fields(
408 db.query.text,
409 ),
410 err,
411 )]
412 async fn list(
413 &mut self,
414 filter: UserFilter<'_>,
415 pagination: mas_storage::Pagination,
416 ) -> Result<mas_storage::Page<User>, Self::Error> {
417 let (sql, arguments) = Query::select()
418 .expr_as(
419 Expr::col((Users::Table, Users::UserId)),
420 UserLookupIden::UserId,
421 )
422 .expr_as(
423 Expr::col((Users::Table, Users::Username)),
424 UserLookupIden::Username,
425 )
426 .expr_as(
427 Expr::col((Users::Table, Users::CreatedAt)),
428 UserLookupIden::CreatedAt,
429 )
430 .expr_as(
431 Expr::col((Users::Table, Users::LockedAt)),
432 UserLookupIden::LockedAt,
433 )
434 .expr_as(
435 Expr::col((Users::Table, Users::DeactivatedAt)),
436 UserLookupIden::DeactivatedAt,
437 )
438 .expr_as(
439 Expr::col((Users::Table, Users::CanRequestAdmin)),
440 UserLookupIden::CanRequestAdmin,
441 )
442 .from(Users::Table)
443 .apply_filter(filter)
444 .generate_pagination((Users::Table, Users::UserId), pagination)
445 .build_sqlx(PostgresQueryBuilder);
446
447 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
448 .traced()
449 .fetch_all(&mut *self.conn)
450 .await?;
451
452 let page = pagination.process(edges).map(User::from);
453
454 Ok(page)
455 }
456
457 #[tracing::instrument(
458 name = "db.user.count",
459 skip_all,
460 fields(
461 db.query.text,
462 ),
463 err,
464 )]
465 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
466 let (sql, arguments) = Query::select()
467 .expr(Expr::col((Users::Table, Users::UserId)).count())
468 .from(Users::Table)
469 .apply_filter(filter)
470 .build_sqlx(PostgresQueryBuilder);
471
472 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
473 .traced()
474 .fetch_one(&mut *self.conn)
475 .await?;
476
477 count
478 .try_into()
479 .map_err(DatabaseError::to_invalid_operation)
480 }
481
482 #[tracing::instrument(
483 name = "db.user.acquire_lock_for_sync",
484 skip_all,
485 fields(
486 db.query.text,
487 user.id = %user.id,
488 ),
489 err,
490 )]
491 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
492 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
500
501 sqlx::query!(
504 r#"
505 SELECT pg_advisory_xact_lock($1)
506 "#,
507 lock_id,
508 )
509 .traced()
510 .execute(&mut *self.conn)
511 .await?;
512
513 Ok(())
514 }
515}