// mas_storage_pg/repository.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
use std::ops::{Deref, DerefMut};

use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
use mas_storage::{
    BoxRepository, MapErr, Repository, RepositoryAccess, RepositoryError, RepositoryTransaction,
    app_session::AppSessionRepository,
    compat::{
        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
        CompatSsoLoginRepository,
    },
    oauth2::{
        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
    },
    policy_data::PolicyDataRepository,
    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
    upstream_oauth2::{
        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
        UpstreamOAuthSessionRepository,
    },
    user::{
        BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
        UserRecoveryRepository, UserRegistrationRepository, UserRepository, UserTermsRepository,
    },
};
use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use tracing::Instrument;

use crate::{
    DatabaseError,
    app_session::PgAppSessionRepository,
    compat::{
        PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
        PgCompatSsoLoginRepository,
    },
    oauth2::{
        PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
        PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
        PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
    },
    policy_data::PgPolicyDataRepository,
    queue::{
        job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
        worker::PgQueueWorkerRepository,
    },
    upstream_oauth2::{
        PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
        PgUpstreamOAuthSessionRepository,
    },
    user::{
        PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
        PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRepository,
        PgUserTermsRepository,
    },
};
59
60/// An implementation of the [`Repository`] trait backed by a PostgreSQL
61/// transaction.
62pub struct PgRepository<C = Transaction<'static, Postgres>> {
63    conn: C,
64}
65
66impl PgRepository {
67    /// Create a new [`PgRepository`] from a PostgreSQL connection pool,
68    /// starting a transaction.
69    ///
70    /// # Errors
71    ///
72    /// Returns a [`DatabaseError`] if the transaction could not be started.
73    pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
74        let txn = pool.begin().await?;
75        Ok(Self::from_conn(txn))
76    }
77
78    /// Transform the repository into a type-erased [`BoxRepository`]
79    pub fn boxed(self) -> BoxRepository {
80        Box::new(MapErr::new(self, RepositoryError::from_error))
81    }
82}
83
84impl<C> PgRepository<C> {
85    /// Create a new [`PgRepository`] from an existing PostgreSQL connection
86    /// with a transaction
87    pub fn from_conn(conn: C) -> Self {
88        PgRepository { conn }
89    }
90
91    /// Consume this [`PgRepository`], returning the underlying connection.
92    pub fn into_inner(self) -> C {
93        self.conn
94    }
95}
96
97impl<C> AsRef<C> for PgRepository<C> {
98    fn as_ref(&self) -> &C {
99        &self.conn
100    }
101}
102
103impl<C> AsMut<C> for PgRepository<C> {
104    fn as_mut(&mut self) -> &mut C {
105        &mut self.conn
106    }
107}
108
109impl<C> Deref for PgRepository<C> {
110    type Target = C;
111
112    fn deref(&self) -> &Self::Target {
113        &self.conn
114    }
115}
116
117impl<C> DerefMut for PgRepository<C> {
118    fn deref_mut(&mut self) -> &mut Self::Target {
119        &mut self.conn
120    }
121}
122
123impl Repository<DatabaseError> for PgRepository {}
124
125impl RepositoryTransaction for PgRepository {
126    type Error = DatabaseError;
127
128    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
129        let span = tracing::info_span!("db.save");
130        self.conn
131            .commit()
132            .map_err(DatabaseError::from)
133            .instrument(span)
134            .boxed()
135    }
136
137    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
138        let span = tracing::info_span!("db.cancel");
139        self.conn
140            .rollback()
141            .map_err(DatabaseError::from)
142            .instrument(span)
143            .boxed()
144    }
145}
146
147impl<C> RepositoryAccess for PgRepository<C>
148where
149    C: AsMut<PgConnection> + Send,
150{
151    type Error = DatabaseError;
152
153    fn upstream_oauth_link<'c>(
154        &'c mut self,
155    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
156        Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
157    }
158
159    fn upstream_oauth_provider<'c>(
160        &'c mut self,
161    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
162        Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
163    }
164
165    fn upstream_oauth_session<'c>(
166        &'c mut self,
167    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
168        Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
169    }
170
171    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
172        Box::new(PgUserRepository::new(self.conn.as_mut()))
173    }
174
175    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
176        Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
177    }
178
179    fn user_password<'c>(
180        &'c mut self,
181    ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
182        Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
183    }
184
185    fn user_recovery<'c>(
186        &'c mut self,
187    ) -> Box<dyn mas_storage::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
188        Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
189    }
190
191    fn user_terms<'c>(
192        &'c mut self,
193    ) -> Box<dyn mas_storage::user::UserTermsRepository<Error = Self::Error> + 'c> {
194        Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
195    }
196
197    fn user_registration<'c>(
198        &'c mut self,
199    ) -> Box<dyn mas_storage::user::UserRegistrationRepository<Error = Self::Error> + 'c> {
200        Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
201    }
202
203    fn browser_session<'c>(
204        &'c mut self,
205    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
206        Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
207    }
208
209    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
210        Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
211    }
212
213    fn oauth2_client<'c>(
214        &'c mut self,
215    ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
216        Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
217    }
218
219    fn oauth2_authorization_grant<'c>(
220        &'c mut self,
221    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
222        Box::new(PgOAuth2AuthorizationGrantRepository::new(
223            self.conn.as_mut(),
224        ))
225    }
226
227    fn oauth2_session<'c>(
228        &'c mut self,
229    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
230        Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
231    }
232
233    fn oauth2_access_token<'c>(
234        &'c mut self,
235    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
236        Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
237    }
238
239    fn oauth2_refresh_token<'c>(
240        &'c mut self,
241    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
242        Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
243    }
244
245    fn oauth2_device_code_grant<'c>(
246        &'c mut self,
247    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
248        Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
249    }
250
251    fn compat_session<'c>(
252        &'c mut self,
253    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
254        Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
255    }
256
257    fn compat_sso_login<'c>(
258        &'c mut self,
259    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
260        Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
261    }
262
263    fn compat_access_token<'c>(
264        &'c mut self,
265    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
266        Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
267    }
268
269    fn compat_refresh_token<'c>(
270        &'c mut self,
271    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
272        Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
273    }
274
275    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
276        Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
277    }
278
279    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
280        Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
281    }
282
283    fn queue_schedule<'c>(
284        &'c mut self,
285    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
286        Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
287    }
288
289    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
290        Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
291    }
292}