mas_storage_pg/user/
terms.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::User;
9use mas_storage::{Clock, user::UserTermsRepository};
10use rand::RngCore;
11use sqlx::PgConnection;
12use ulid::Ulid;
13use url::Url;
14use uuid::Uuid;
15
16use crate::{DatabaseError, tracing::ExecuteExt};
17
18/// An implementation of [`UserTermsRepository`] for a PostgreSQL connection
19pub struct PgUserTermsRepository<'c> {
20    conn: &'c mut PgConnection,
21}
22
23impl<'c> PgUserTermsRepository<'c> {
24    /// Create a new [`PgUserTermsRepository`] from an active PostgreSQL
25    /// connection
26    pub fn new(conn: &'c mut PgConnection) -> Self {
27        Self { conn }
28    }
29}
30
31#[async_trait]
32impl UserTermsRepository for PgUserTermsRepository<'_> {
33    type Error = DatabaseError;
34
35    #[tracing::instrument(
36        name = "db.user_terms.accept_terms",
37        skip_all,
38        fields(
39            db.query.text,
40            %user.id,
41            user_terms.id,
42            %user_terms.url = terms_url.as_str(),
43        ),
44        err,
45    )]
46    async fn accept_terms(
47        &mut self,
48        rng: &mut (dyn RngCore + Send),
49        clock: &dyn Clock,
50        user: &User,
51        terms_url: Url,
52    ) -> Result<(), Self::Error> {
53        let created_at = clock.now();
54        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
55        tracing::Span::current().record("user_terms.id", tracing::field::display(id));
56
57        sqlx::query!(
58            r#"
59            INSERT INTO user_terms (user_terms_id, user_id, terms_url, created_at)
60            VALUES ($1, $2, $3, $4)
61            ON CONFLICT (user_id, terms_url) DO NOTHING
62            "#,
63            Uuid::from(id),
64            Uuid::from(user.id),
65            terms_url.as_str(),
66            created_at,
67        )
68        .traced()
69        .execute(&mut *self.conn)
70        .await?;
71
72        Ok(())
73    }
74}