mas_storage_pg/
policy_data.rs1use async_trait::async_trait;
10use mas_data_model::PolicyData;
11use mas_storage::{Clock, policy_data::PolicyDataRepository};
12use rand::RngCore;
13use serde_json::Value;
14use sqlx::{PgConnection, types::Json};
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20pub struct PgPolicyDataRepository<'c> {
22 conn: &'c mut PgConnection,
23}
24
25impl<'c> PgPolicyDataRepository<'c> {
26 #[must_use]
29 pub fn new(conn: &'c mut PgConnection) -> Self {
30 Self { conn }
31 }
32}
33
34struct PolicyDataLookup {
35 policy_data_id: Uuid,
36 created_at: chrono::DateTime<chrono::Utc>,
37 data: Json<Value>,
38}
39
40impl From<PolicyDataLookup> for PolicyData {
41 fn from(value: PolicyDataLookup) -> Self {
42 PolicyData {
43 id: value.policy_data_id.into(),
44 created_at: value.created_at,
45 data: value.data.0,
46 }
47 }
48}
49
50#[async_trait]
51impl PolicyDataRepository for PgPolicyDataRepository<'_> {
52 type Error = DatabaseError;
53
54 #[tracing::instrument(
55 name = "db.policy_data.get",
56 skip_all,
57 fields(
58 db.query.text,
59 ),
60 err,
61 )]
62 async fn get(&mut self) -> Result<Option<PolicyData>, Self::Error> {
63 let row = sqlx::query_as!(
64 PolicyDataLookup,
65 r#"
66 SELECT policy_data_id, created_at, data
67 FROM policy_data
68 ORDER BY policy_data_id DESC
69 LIMIT 1
70 "#
71 )
72 .traced()
73 .fetch_optional(&mut *self.conn)
74 .await?;
75
76 let Some(row) = row else {
77 return Ok(None);
78 };
79
80 Ok(Some(row.into()))
81 }
82
83 #[tracing::instrument(
84 name = "db.policy_data.set",
85 skip_all,
86 fields(
87 db.query.text,
88 ),
89 err,
90 )]
91 async fn set(
92 &mut self,
93 rng: &mut (dyn RngCore + Send),
94 clock: &dyn Clock,
95 data: Value,
96 ) -> Result<PolicyData, Self::Error> {
97 let created_at = clock.now();
98 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
99
100 sqlx::query!(
101 r#"
102 INSERT INTO policy_data (policy_data_id, created_at, data)
103 VALUES ($1, $2, $3)
104 "#,
105 Uuid::from(id),
106 created_at,
107 data,
108 )
109 .traced()
110 .execute(&mut *self.conn)
111 .await?;
112
113 Ok(PolicyData {
114 id,
115 created_at,
116 data,
117 })
118 }
119
120 #[tracing::instrument(
121 name = "db.policy_data.prune",
122 skip_all,
123 fields(
124 db.query.text,
125 ),
126 err,
127 )]
128 async fn prune(&mut self, keep: usize) -> Result<usize, Self::Error> {
129 let res = sqlx::query!(
130 r#"
131 DELETE FROM policy_data
132 WHERE policy_data_id IN (
133 SELECT policy_data_id
134 FROM policy_data
135 ORDER BY policy_data_id DESC
136 OFFSET $1
137 )
138 "#,
139 i64::try_from(keep).map_err(DatabaseError::to_invalid_operation)?
140 )
141 .traced()
142 .execute(&mut *self.conn)
143 .await?;
144
145 Ok(res
146 .rows_affected()
147 .try_into()
148 .map_err(DatabaseError::to_invalid_operation)?)
149 }
150}
151
#[cfg(test)]
mod tests {
    use mas_storage::{clock::MockClock, policy_data::PolicyDataRepository};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use serde_json::json;
    use sqlx::PgPool;

    use crate::policy_data::PgPolicyDataRepository;

    /// Exercises the repository lifecycle: empty lookup, insertion,
    /// "latest entry wins" lookup, and pruning (including its edge cases).
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_policy_data(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut conn = pool.acquire().await.unwrap();
        let mut repo = PgPolicyDataRepository::new(&mut conn);

        // The table starts empty
        let data = repo.get().await.unwrap();
        assert_eq!(data, None);

        // A first insertion becomes the current entry
        let value1 = json!({"hello": "world"});
        let policy_data1 = repo.set(&mut rng, &clock, value1.clone()).await.unwrap();
        assert_eq!(policy_data1.data, value1);

        let data_fetched1 = repo.get().await.unwrap().unwrap();
        assert_eq!(policy_data1, data_fetched1);

        // Advance the clock so the second ID sorts strictly after the first
        clock.advance(chrono::Duration::seconds(1));

        let value2 = json!({"foo": "bar"});
        let policy_data2 = repo.set(&mut rng, &clock, value2.clone()).await.unwrap();
        assert_eq!(policy_data2.data, value2);

        // The most recent entry is returned
        let data_fetched2 = repo.get().await.unwrap().unwrap();
        assert_eq!(data_fetched2, policy_data2);

        // Keeping only the latest entry deletes the older one
        let affected = repo.prune(1).await.unwrap();
        assert_eq!(affected, 1);
        let data_fetched3 = repo.get().await.unwrap().unwrap();
        assert_eq!(data_fetched3, policy_data2);

        // Pruning again with the same limit has nothing left to delete
        let affected = repo.prune(1).await.unwrap();
        assert_eq!(affected, 0);

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM policy_data")
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Keeping zero entries clears the table entirely
        let mut repo = PgPolicyDataRepository::new(&mut conn);
        let affected = repo.prune(0).await.unwrap();
        assert_eq!(affected, 1);
        assert_eq!(repo.get().await.unwrap(), None);
    }
}