mas_storage_pg/
policy_data.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! A module containing the PostgreSQL implementation of the policy data
7//! storage.
8
9use async_trait::async_trait;
10use mas_data_model::PolicyData;
11use mas_storage::{Clock, policy_data::PolicyDataRepository};
12use rand::RngCore;
13use serde_json::Value;
14use sqlx::{PgConnection, types::Json};
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20/// An implementation of [`PolicyDataRepository`] for a PostgreSQL connection.
21pub struct PgPolicyDataRepository<'c> {
22    conn: &'c mut PgConnection,
23}
24
25impl<'c> PgPolicyDataRepository<'c> {
26    /// Create a new [`PgPolicyDataRepository`] from an active PostgreSQL
27    /// connection.
28    #[must_use]
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34struct PolicyDataLookup {
35    policy_data_id: Uuid,
36    created_at: chrono::DateTime<chrono::Utc>,
37    data: Json<Value>,
38}
39
40impl From<PolicyDataLookup> for PolicyData {
41    fn from(value: PolicyDataLookup) -> Self {
42        PolicyData {
43            id: value.policy_data_id.into(),
44            created_at: value.created_at,
45            data: value.data.0,
46        }
47    }
48}
49
50#[async_trait]
51impl PolicyDataRepository for PgPolicyDataRepository<'_> {
52    type Error = DatabaseError;
53
54    #[tracing::instrument(
55        name = "db.policy_data.get",
56        skip_all,
57        fields(
58            db.query.text,
59        ),
60        err,
61    )]
62    async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error> {
63        let row = sqlx::query_as!(
64            PolicyDataLookup,
65            r#"
66            SELECT policy_data_id, created_at, data
67            FROM policy_data
68            ORDER BY policy_data_id DESC
69            LIMIT 1
70            "#
71        )
72        .traced()
73        .fetch_optional(&mut *self.conn)
74        .await?;
75
76        let Some(row) = row else {
77            return Ok(None);
78        };
79
80        Ok(Some(row.into()))
81    }
82
83    #[tracing::instrument(
84        name = "db.policy_data.set",
85        skip_all,
86        fields(
87            db.query.text,
88        ),
89        err,
90    )]
91    async fn set(
92        &mut self,
93        rng: &mut (dyn RngCore + Send),
94        clock: &dyn Clock,
95        data: Value,
96    ) -> Result<PolicyData, Self::Error> {
97        let created_at = clock.now();
98        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
99
100        sqlx::query!(
101            r#"
102            INSERT INTO policy_data (policy_data_id, created_at, data)
103            VALUES ($1, $2, $3)
104            "#,
105            Uuid::from(id),
106            created_at,
107            data,
108        )
109        .traced()
110        .execute(&mut *self.conn)
111        .await?;
112
113        Ok(PolicyData {
114            id,
115            created_at,
116            data,
117        })
118    }
119
120    #[tracing::instrument(
121        name = "db.policy_data.prune",
122        skip_all,
123        fields(
124            db.query.text,
125        ),
126        err,
127    )]
128    async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error> {
129        let res = sqlx::query!(
130            r#"
131            DELETE FROM policy_data
132            WHERE policy_data_id IN (
133                SELECT policy_data_id
134                FROM policy_data
135                ORDER BY policy_data_id DESC
136                OFFSET $1
137            )
138            "#,
139            i64::try_from(keep).map_err(DatabaseError::to_invalid_operation)?
140        )
141        .traced()
142        .execute(&mut *self.conn)
143        .await?;
144
145        Ok(res
146            .rows_affected()
147            .try_into()
148            .map_err(DatabaseError::to_invalid_operation)?)
149    }
150}
151
#[cfg(test)]
mod tests {
    use mas_storage::{clock::MockClock, policy_data::PolicyDataRepository};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use serde_json::json;
    use sqlx::PgPool;

    use super::PgPolicyDataRepository;

    /// Walk the repository through its lifecycle: an empty read, two writes,
    /// then pruning down to the latest entry.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_policy_data(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut conn = pool.acquire().await.unwrap();
        let mut repo = PgPolicyDataRepository::new(&mut conn);

        // Nothing has been stored yet
        assert_eq!(repo.get().await.unwrap(), None);

        // First write: the returned entry carries the value and is what `get` sees
        let first_value = json!({"hello": "world"});
        let first = repo
            .set(&mut rng, &clock, first_value.clone())
            .await
            .unwrap();
        assert_eq!(first.data, first_value);
        assert_eq!(first, repo.get().await.unwrap().unwrap());

        // Second write, one second later, supersedes the first
        clock.advance(chrono::Duration::seconds(1));
        let second_value = json!({"foo": "bar"});
        let second = repo
            .set(&mut rng, &clock, second_value.clone())
            .await
            .unwrap();
        assert_eq!(second.data, second_value);
        assert_eq!(repo.get().await.unwrap().unwrap(), second);

        // Keep only the newest entry: exactly one row is removed
        let pruned = repo.prune(1).await.unwrap();
        let latest = repo.get().await.unwrap().unwrap();
        assert_eq!(latest, second);
        assert_eq!(pruned, 1);

        // Confirm with a raw query that a single row remains
        let remaining: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM policy_data")
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        assert_eq!(remaining, 1);
    }
}