1use std::fmt::Display;
12
13use chrono::{DateTime, Utc};
14use futures_util::{Stream, TryStreamExt};
15use sqlx::{Acquire, FromRow, PgConnection, Postgres, Transaction, Type, query};
16use thiserror::Error;
17use thiserror_ext::ContextInto;
18
19pub mod checks;
20pub mod config;
21
/// Errors produced while reading from the Synapse database.
#[derive(Debug, Error, ContextInto)]
pub enum Error {
    /// A database operation failed.
    ///
    /// `context` describes what was being attempted (e.g. "begin transaction"). This module
    /// attaches it through the `into_database` / `into_database_with` helpers that the
    /// `ContextInto` derive provides.
    #[error("database error whilst {context}")]
    Database {
        #[source]
        source: sqlx::Error,
        context: String,
    },
}
31
32#[derive(Clone, Debug, sqlx::Decode, PartialEq, Eq, PartialOrd, Ord)]
33pub struct FullUserId(pub String);
34
35impl Display for FullUserId {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 self.0.fmt(f)
38 }
39}
40
41impl Type<Postgres> for FullUserId {
42 fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
43 <String as Type<Postgres>>::type_info()
44 }
45}
46
/// Reasons why [`FullUserId::extract_localpart`] can reject a user ID.
#[derive(Debug, Error)]
pub enum ExtractLocalpartError {
    /// The user ID does not begin with `@`.
    #[error("user ID does not start with `@` sigil")]
    NoAtSigil,
    /// No `:` separates the localpart from the server name.
    #[error("user ID does not have a `:` separator")]
    NoSeparator,
    /// The server name (everything after the first `:`) is not the expected one.
    #[error("wrong server name: expected {expected:?}, got {found:?}")]
    WrongServerName {
        /// The server name the caller asked for.
        expected: String,
        /// The server name actually present in the user ID.
        found: String,
    },
}
56
57impl FullUserId {
58 pub fn extract_localpart(
69 &self,
70 expected_server_name: &str,
71 ) -> Result<&str, ExtractLocalpartError> {
72 let Some(without_sigil) = self.0.strip_prefix('@') else {
73 return Err(ExtractLocalpartError::NoAtSigil);
74 };
75
76 let Some((localpart, server_name)) = without_sigil.split_once(':') else {
77 return Err(ExtractLocalpartError::NoSeparator);
78 };
79
80 if server_name != expected_server_name {
81 return Err(ExtractLocalpartError::WrongServerName {
82 expected: expected_server_name.to_owned(),
83 found: server_name.to_owned(),
84 });
85 }
86
87 Ok(localpart)
88 }
89}
90
91#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
95pub struct SynapseBool(bool);
96
97impl<'r> sqlx::Decode<'r, Postgres> for SynapseBool {
98 fn decode(
99 value: <Postgres as sqlx::Database>::ValueRef<'r>,
100 ) -> Result<Self, sqlx::error::BoxDynError> {
101 <i16 as sqlx::Decode<Postgres>>::decode(value)
102 .map(|boolean_int| SynapseBool(boolean_int != 0))
103 }
104}
105
106impl sqlx::Type<Postgres> for SynapseBool {
107 fn type_info() -> <Postgres as sqlx::Database>::TypeInfo {
108 <i16 as sqlx::Type<Postgres>>::type_info()
109 }
110}
111
112impl From<SynapseBool> for bool {
113 fn from(SynapseBool(value): SynapseBool) -> Self {
114 value
115 }
116}
117
118#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
122pub struct SecondsTimestamp(DateTime<Utc>);
123
124impl From<SecondsTimestamp> for DateTime<Utc> {
125 fn from(SecondsTimestamp(value): SecondsTimestamp) -> Self {
126 value
127 }
128}
129
130impl<'r> sqlx::Decode<'r, Postgres> for SecondsTimestamp {
131 fn decode(
132 value: <Postgres as sqlx::Database>::ValueRef<'r>,
133 ) -> Result<Self, sqlx::error::BoxDynError> {
134 <i64 as sqlx::Decode<Postgres>>::decode(value).map(|seconds_since_epoch| {
135 SecondsTimestamp(DateTime::from_timestamp_nanos(
136 seconds_since_epoch * 1_000_000_000,
137 ))
138 })
139 }
140}
141
142impl sqlx::Type<Postgres> for SecondsTimestamp {
143 fn type_info() -> <Postgres as sqlx::Database>::TypeInfo {
144 <i64 as sqlx::Type<Postgres>>::type_info()
145 }
146}
147
148#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
151pub struct MillisecondsTimestamp(DateTime<Utc>);
152
153impl From<MillisecondsTimestamp> for DateTime<Utc> {
154 fn from(MillisecondsTimestamp(value): MillisecondsTimestamp) -> Self {
155 value
156 }
157}
158
159impl<'r> sqlx::Decode<'r, Postgres> for MillisecondsTimestamp {
160 fn decode(
161 value: <Postgres as sqlx::Database>::ValueRef<'r>,
162 ) -> Result<Self, sqlx::error::BoxDynError> {
163 <i64 as sqlx::Decode<Postgres>>::decode(value).map(|milliseconds_since_epoch| {
164 MillisecondsTimestamp(DateTime::from_timestamp_nanos(
165 milliseconds_since_epoch * 1_000_000,
166 ))
167 })
168 }
169}
170
171impl sqlx::Type<Postgres> for MillisecondsTimestamp {
172 fn type_info() -> <Postgres as sqlx::Database>::TypeInfo {
173 <i64 as sqlx::Type<Postgres>>::type_info()
174 }
175}
176
/// A row of Synapse's `users` table, as selected by [`SynapseReader::read_users`].
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseUser {
    /// Full user ID (`name` column).
    pub name: FullUserId,
    /// `password_hash` column; `None` when NULL.
    pub password_hash: Option<String>,
    /// `admin` column, stored by Synapse as an integer flag.
    pub admin: SynapseBool,
    /// `deactivated` column, stored by Synapse as an integer flag.
    pub deactivated: SynapseBool,
    /// `locked` column; unlike the other flags, this one is decoded as a native `bool`.
    pub locked: bool,
    /// `creation_ts` column, stored in seconds since the Unix epoch.
    pub creation_ts: SecondsTimestamp,
    /// `is_guest` column, stored by Synapse as an integer flag.
    pub is_guest: SynapseBool,
    /// `appservice_id` column; `None` when NULL.
    pub appservice_id: Option<String>,
}
199
/// A row of Synapse's `user_threepids` table, as selected by
/// [`SynapseReader::read_threepids`].
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseThreepid {
    /// The user this third-party identifier belongs to.
    pub user_id: FullUserId,
    /// `medium` column.
    pub medium: String,
    /// `address` column.
    pub address: String,
    /// `added_at` column, stored in milliseconds since the Unix epoch.
    pub added_at: MillisecondsTimestamp,
}
208
/// A row of Synapse's `user_external_ids` table, as selected by
/// [`SynapseReader::read_user_external_ids`].
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseExternalId {
    /// The user this external ID is linked to.
    pub user_id: FullUserId,
    /// `auth_provider` column.
    pub auth_provider: String,
    /// `external_id` column.
    pub external_id: String,
}
216
/// A row of Synapse's `devices` table, as selected by [`SynapseReader::read_devices`]
/// (which excludes hidden devices and `guest_device`).
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseDevice {
    /// Owner of the device.
    pub user_id: FullUserId,
    /// `device_id` column.
    pub device_id: String,
    /// `display_name` column; `None` when NULL.
    pub display_name: Option<String>,
    /// `last_seen` column, stored in milliseconds since the Unix epoch; `None` when NULL.
    pub last_seen: Option<MillisecondsTimestamp>,
    /// `ip` column; `None` when NULL.
    pub ip: Option<String>,
    /// `user_agent` column; `None` when NULL.
    pub user_agent: Option<String>,
}
227
/// An access token that is not paired with a refresh token, as selected by
/// [`SynapseReader::read_unrefreshable_access_tokens`].
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseAccessToken {
    /// The user the token authenticates.
    pub user_id: FullUserId,
    /// The device the token is bound to; `None` for tokens with no device.
    pub device_id: Option<String>,
    /// `token` column.
    pub token: String,
    /// `valid_until_ms` column, in milliseconds since the Unix epoch; `None` when NULL.
    pub valid_until_ms: Option<MillisecondsTimestamp>,
    /// `last_validated` column, in milliseconds since the Unix epoch; `None` when NULL.
    pub last_validated: Option<MillisecondsTimestamp>,
}
237
/// A refresh token together with the access token issued alongside it, as selected by
/// [`SynapseReader::read_refreshable_token_pairs`].
#[derive(Clone, Debug, FromRow, PartialEq, Eq, PartialOrd, Ord)]
pub struct SynapseRefreshableTokenPair {
    /// The user both tokens belong to.
    pub user_id: FullUserId,
    /// The device both tokens are bound to.
    pub device_id: String,
    /// The access token (`access_tokens.token`).
    pub access_token: String,
    /// The refresh token (`refresh_tokens.token`).
    pub refresh_token: String,
    /// Access-token expiry, in milliseconds since the Unix epoch; `None` when NULL.
    pub valid_until_ms: Option<MillisecondsTimestamp>,
    /// Access-token `last_validated`, in milliseconds since the Unix epoch; `None` when NULL.
    pub last_validated: Option<MillisecondsTimestamp>,
}
248
/// Synapse tables that [`SynapseReader::new`] locks, in this order, for the lifetime of its
/// transaction.
///
/// These names are interpolated directly into `LOCK TABLE` statements, so they must remain
/// trusted constants.
const TABLES_TO_LOCK: &[&str] = &[
    "users",
    "user_threepids",
    "user_external_ids",
    "devices",
    "access_tokens",
    "refresh_tokens",
];
262
/// Approximate row counts for the Synapse tables being migrated, as returned by
/// [`SynapseReader::count_rows`].
///
/// These are planner estimates (`pg_class.reltuples`), not exact counts.
#[derive(Clone, Debug)]
pub struct SynapseRowCounts {
    /// Estimate for `users`.
    pub users: usize,
    /// Estimate for `devices`.
    pub devices: usize,
    /// Estimate for `user_threepids`.
    pub threepids: usize,
    /// Estimate for `user_external_ids`.
    pub external_ids: usize,
    /// Estimate for `access_tokens`.
    pub access_tokens: usize,
    /// Estimate for `refresh_tokens`.
    pub refresh_tokens: usize,
}
274
/// Reads migration data from a Synapse database inside a single transaction.
///
/// The transaction is opened by [`SynapseReader::new`] and committed by
/// [`SynapseReader::finish`].
pub struct SynapseReader<'c> {
    txn: Transaction<'c, Postgres>,
}
278
279impl<'conn> SynapseReader<'conn> {
280 pub async fn new(
291 synapse_connection: &'conn mut PgConnection,
292 dry_run: bool,
293 ) -> Result<Self, Error> {
294 let mut txn = synapse_connection
295 .begin()
296 .await
297 .into_database("begin transaction")?;
298
299 query("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE;")
300 .execute(&mut *txn)
301 .await
302 .into_database("set transaction")?;
303
304 let lock_type = if dry_run {
305 "ACCESS SHARE"
308 } else {
309 "EXCLUSIVE"
310 };
311 for table in TABLES_TO_LOCK {
312 query(&format!("LOCK TABLE {table} IN {lock_type} MODE NOWAIT;"))
313 .execute(&mut *txn)
314 .await
315 .into_database_with(|| format!("locking Synapse table `{table}`"))?;
316 }
317
318 Ok(Self { txn })
319 }
320
321 pub async fn finish(self) -> Result<(), Error> {
329 self.txn.commit().await.into_database("end transaction")?;
330 Ok(())
331 }
332
333 pub async fn count_rows(&mut self) -> Result<SynapseRowCounts, Error> {
342 let users = sqlx::query_scalar::<_, i64>(
348 "
349 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'users'::regclass;
350 ",
351 )
352 .fetch_one(&mut *self.txn)
353 .await
354 .into_database("estimating count of users")?
355 .max(0)
356 .try_into()
357 .unwrap_or(usize::MAX);
358
359 let devices = sqlx::query_scalar::<_, i64>(
360 "
361 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'devices'::regclass;
362 ",
363 )
364 .fetch_one(&mut *self.txn)
365 .await
366 .into_database("estimating count of devices")?
367 .max(0)
368 .try_into()
369 .unwrap_or(usize::MAX);
370
371 let threepids = sqlx::query_scalar::<_, i64>(
372 "
373 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'user_threepids'::regclass;
374 "
375 )
376 .fetch_one(&mut *self.txn)
377 .await
378 .into_database("estimating count of threepids")?
379 .max(0)
380 .try_into()
381 .unwrap_or(usize::MAX);
382
383 let access_tokens = sqlx::query_scalar::<_, i64>(
384 "
385 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'access_tokens'::regclass;
386 "
387 )
388 .fetch_one(&mut *self.txn)
389 .await
390 .into_database("estimating count of access tokens")?
391 .max(0)
392 .try_into()
393 .unwrap_or(usize::MAX);
394
395 let refresh_tokens = sqlx::query_scalar::<_, i64>(
396 "
397 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'refresh_tokens'::regclass;
398 "
399 )
400 .fetch_one(&mut *self.txn)
401 .await
402 .into_database("estimating count of refresh tokens")?
403 .max(0)
404 .try_into()
405 .unwrap_or(usize::MAX);
406
407 let external_ids = sqlx::query_scalar::<_, i64>(
408 "
409 SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'user_external_ids'::regclass;
410 "
411 )
412 .fetch_one(&mut *self.txn)
413 .await
414 .into_database("estimating count of external IDs")?
415 .max(0)
416 .try_into()
417 .unwrap_or(usize::MAX);
418
419 Ok(SynapseRowCounts {
420 users,
421 devices,
422 threepids,
423 external_ids,
424 access_tokens,
425 refresh_tokens,
426 })
427 }
428
429 pub fn read_users(&mut self) -> impl Stream<Item = Result<SynapseUser, Error>> + '_ {
432 sqlx::query_as(
433 "
434 SELECT
435 name, password_hash, admin, deactivated, locked, creation_ts, is_guest, appservice_id
436 FROM users
437 ",
438 )
439 .fetch(&mut *self.txn)
440 .map_err(|err| err.into_database("reading Synapse users"))
441 }
442
443 pub fn read_threepids(&mut self) -> impl Stream<Item = Result<SynapseThreepid, Error>> + '_ {
446 sqlx::query_as(
447 "
448 SELECT
449 user_id, medium, address, added_at
450 FROM user_threepids
451 ",
452 )
453 .fetch(&mut *self.txn)
454 .map_err(|err| err.into_database("reading Synapse threepids"))
455 }
456
457 pub fn read_user_external_ids(
459 &mut self,
460 ) -> impl Stream<Item = Result<SynapseExternalId, Error>> + '_ {
461 sqlx::query_as(
462 "
463 SELECT
464 user_id, auth_provider, external_id
465 FROM user_external_ids
466 ",
467 )
468 .fetch(&mut *self.txn)
469 .map_err(|err| err.into_database("reading Synapse user external IDs"))
470 }
471
472 pub fn read_devices(&mut self) -> impl Stream<Item = Result<SynapseDevice, Error>> + '_ {
476 sqlx::query_as(
477 "
478 SELECT
479 user_id, device_id, display_name, last_seen, ip, user_agent
480 FROM devices
481 WHERE NOT hidden AND device_id != 'guest_device'
482 ",
483 )
484 .fetch(&mut *self.txn)
485 .map_err(|err| err.into_database("reading Synapse devices"))
486 }
487
488 pub fn read_unrefreshable_access_tokens(
498 &mut self,
499 ) -> impl Stream<Item = Result<SynapseAccessToken, Error>> + '_ {
500 sqlx::query_as(
501 "
502 SELECT
503 at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
504 FROM access_tokens at0
505 INNER JOIN devices USING (user_id, device_id)
506 WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL
507
508 UNION ALL
509
510 SELECT
511 at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
512 FROM access_tokens at0
513 WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL AND at0.device_id IS NULL
514 ",
515 )
516 .fetch(&mut *self.txn)
517 .map_err(|err| err.into_database("reading Synapse access tokens"))
518 }
519
520 pub fn read_refreshable_token_pairs(
530 &mut self,
531 ) -> impl Stream<Item = Result<SynapseRefreshableTokenPair, Error>> + '_ {
532 sqlx::query_as(
533 "
534 SELECT
535 rt0.user_id, rt0.device_id, at0.token AS access_token, rt0.token AS refresh_token, at0.valid_until_ms, at0.last_validated
536 FROM refresh_tokens rt0
537 INNER JOIN devices USING (user_id, device_id)
538 INNER JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id
539 LEFT JOIN access_tokens at1 ON at1.refresh_token_id = rt0.next_token_id
540 WHERE NOT at1.used OR at1.used IS NULL
541 ",
542 )
543 .fetch(&mut *self.txn)
544 .map_err(|err| err.into_database("reading Synapse refresh tokens"))
545 }
546}
547
/// Integration tests for `SynapseReader`.
///
/// Each test gets a database migrated with `./test_synapse_migrations` and seeded with the
/// named sqlx fixtures. Results are collected into `BTreeSet`s so that snapshots do not
/// depend on row order.
#[cfg(test)]
mod test {
    use std::collections::BTreeSet;

    use futures_util::TryStreamExt;
    use insta::assert_debug_snapshot;
    use sqlx::{PgPool, migrate::Migrator};

    use crate::{
        SynapseReader,
        synapse_reader::{
            SynapseAccessToken, SynapseDevice, SynapseExternalId, SynapseRefreshableTokenPair,
            SynapseThreepid, SynapseUser,
        },
    };

    // Schema migrations applied to each test database.
    static MIGRATOR: Migrator = sqlx::migrate!("./test_synapse_migrations");

    // Reads the users from the `user_alice` fixture and snapshots them.
    #[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice"))]
    async fn test_read_users(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let users: BTreeSet<SynapseUser> = reader
            .read_users()
            .try_collect()
            .await
            .expect("failed to read Synapse users");

        assert_debug_snapshot!(users);
    }

    // Reads the threepids from the `threepids_alice` fixture and snapshots them.
    #[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "threepids_alice"))]
    async fn test_read_threepids(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let threepids: BTreeSet<SynapseThreepid> = reader
            .read_threepids()
            .try_collect()
            .await
            .expect("failed to read Synapse threepids");

        assert_debug_snapshot!(threepids);
    }

    // Reads the external IDs from the `external_ids_alice` fixture and snapshots them.
    #[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "external_ids_alice"))]
    async fn test_read_external_ids(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let external_ids: BTreeSet<SynapseExternalId> = reader
            .read_user_external_ids()
            .try_collect()
            .await
            .expect("failed to read Synapse external user IDs");

        assert_debug_snapshot!(external_ids);
    }

    // Reads the devices from the `devices_alice` fixture and snapshots them.
    #[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "devices_alice"))]
    async fn test_read_devices(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let devices: BTreeSet<SynapseDevice> = reader
            .read_devices()
            .try_collect()
            .await
            .expect("failed to read Synapse devices");

        assert_debug_snapshot!(devices);
    }

    // Reads a plain (non-refreshable) access token and snapshots it.
    #[sqlx::test(
        migrator = "MIGRATOR",
        fixtures("user_alice", "devices_alice", "access_token_alice")
    )]
    async fn test_read_access_token(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let access_tokens: BTreeSet<SynapseAccessToken> = reader
            .read_unrefreshable_access_tokens()
            .try_collect()
            .await
            .expect("failed to read Synapse access tokens");

        assert_debug_snapshot!(access_tokens);
    }

    // A token with `puppets_user_id` set must not be returned as an unrefreshable token.
    #[sqlx::test(
        migrator = "MIGRATOR",
        fixtures("user_alice", "devices_alice", "access_token_alice_with_puppet")
    )]
    async fn test_read_access_token_puppet(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let access_tokens: BTreeSet<SynapseAccessToken> = reader
            .read_unrefreshable_access_tokens()
            .try_collect()
            .await
            .expect("failed to read Synapse access tokens");

        assert!(access_tokens.is_empty());
    }

    // An access token linked to a refresh token appears only as a refreshable pair.
    #[sqlx::test(
        migrator = "MIGRATOR",
        fixtures("user_alice", "devices_alice", "access_token_alice_with_refresh_token")
    )]
    async fn test_read_access_and_refresh_tokens(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let access_tokens: BTreeSet<SynapseAccessToken> = reader
            .read_unrefreshable_access_tokens()
            .try_collect()
            .await
            .expect("failed to read Synapse access tokens");

        let refresh_tokens: BTreeSet<SynapseRefreshableTokenPair> = reader
            .read_refreshable_token_pairs()
            .try_collect()
            .await
            .expect("failed to read Synapse refresh tokens");

        assert!(
            access_tokens.is_empty(),
            "there are no unrefreshable access tokens"
        );
        assert_debug_snapshot!(refresh_tokens);
    }

    // Same as above, using the `..._with_unused_refresh_token` fixture; the resulting pairs
    // are snapshotted.
    #[sqlx::test(
        migrator = "MIGRATOR",
        fixtures(
            "user_alice",
            "devices_alice",
            "access_token_alice_with_unused_refresh_token"
        )
    )]
    async fn test_read_access_and_unused_refresh_tokens(pool: PgPool) {
        let mut conn = pool.acquire().await.expect("failed to get connection");
        let mut reader = SynapseReader::new(&mut conn, false)
            .await
            .expect("failed to make SynapseReader");

        let access_tokens: BTreeSet<SynapseAccessToken> = reader
            .read_unrefreshable_access_tokens()
            .try_collect()
            .await
            .expect("failed to read Synapse access tokens");

        let refresh_tokens: BTreeSet<SynapseRefreshableTokenPair> = reader
            .read_refreshable_token_pairs()
            .try_collect()
            .await
            .expect("failed to read Synapse refresh tokens");

        assert!(
            access_tokens.is_empty(),
            "there are no unrefreshable access tokens"
        );
        assert_debug_snapshot!(refresh_tokens);
    }
}