syn2mas/mas_writer/
mod.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! # MAS Writer
7//!
8//! This module is responsible for writing new records to MAS' database.
9
10use std::{
11    fmt::Display,
12    net::IpAddr,
13    sync::{
14        Arc,
15        atomic::{AtomicU32, Ordering},
16    },
17};
18
19use chrono::{DateTime, Utc};
20use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
21use sqlx::{Executor, PgConnection, query, query_as};
22use thiserror::Error;
23use thiserror_ext::{Construct, ContextInto};
24use tokio::sync::mpsc::{self, Receiver, Sender};
25use tracing::{Instrument, Level, error, info, warn};
26use uuid::{NonNilUuid, Uuid};
27
28use self::{
29    constraint_pausing::{ConstraintDescription, IndexDescription},
30    locking::LockedMasDatabase,
31};
32use crate::Progress;
33
34pub mod checks;
35pub mod locking;
36
37mod constraint_pausing;
38
/// Errors which can occur whilst writing to the MAS database.
#[derive(Debug, Error, Construct, ContextInto)]
pub enum Error {
    /// An underlying database error, annotated (via `ContextInto`) with what
    /// we were doing at the time.
    #[error("database error whilst {context}")]
    Database {
        #[source]
        source: sqlx::Error,
        context: String,
    },

    /// The writer connection pool was poisoned by an earlier task failure;
    /// see the error logged by that task.
    #[error("writer connection pool shut down due to error")]
    #[allow(clippy::enum_variant_names)]
    WriterConnectionPoolError,

    /// The database is in a state that syn2mas cannot have produced on its
    /// own.
    #[error("inconsistent database: {0}")]
    Inconsistent(String),

    /// `MasWriter::finish` was called before all write buffers had been
    /// finished — a bug in syn2mas itself.
    #[error("bug in syn2mas: write buffers not finished")]
    WriteBuffersNotFinished,

    /// Several errors occurred; each is listed in the display output.
    #[error("{0}")]
    Multiple(MultipleErrors),
}
61
/// A collection of [`Error`]s, rendered as a bulleted list by its `Display`
/// implementation.
#[derive(Debug)]
pub struct MultipleErrors {
    /// The individual errors, in the order they were collected.
    errors: Vec<Error>,
}
66
67impl Display for MultipleErrors {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "multiple errors")?;
70        for error in &self.errors {
71            write!(f, "\n- {error}")?;
72        }
73        Ok(())
74    }
75}
76
77impl From<Vec<Error>> for MultipleErrors {
78    fn from(value: Vec<Error>) -> Self {
79        MultipleErrors { errors: value }
80    }
81}
82
/// A pool of writer connections, each of which holds an open transaction.
///
/// Connections circulate through an MPSC channel: a task receives a
/// connection, uses it, then sends it (or an error, which poisons the pool)
/// back.
struct WriterConnectionPool {
    /// How many connections are in circulation
    num_connections: usize,

    /// A receiver handle to get a writer connection
    /// The writer connection will be mid-transaction!
    connection_rx: Receiver<Result<PgConnection, Error>>,

    /// A sender handle to return a writer connection to the pool
    /// The connection should still be mid-transaction!
    connection_tx: Sender<Result<PgConnection, Error>>,
}
95
impl WriterConnectionPool {
    /// Creates a pool from the given connections, which should all already be
    /// mid-transaction.
    pub fn new(connections: Vec<PgConnection>) -> Self {
        let num_connections = connections.len();
        // The channel is sized to hold every connection at once, so sends
        // from returning tasks never block.
        let (connection_tx, connection_rx) = mpsc::channel(num_connections);
        for connection in connections {
            connection_tx
                .try_send(Ok(connection))
                .expect("there should be room for this connection");
        }

        WriterConnectionPool {
            num_connections,
            connection_rx,
            connection_tx,
        }
    }

    /// Spawns `task` on the tokio runtime with exclusive use of one pooled
    /// connection, waiting until a connection becomes available.
    ///
    /// The task's completion is not awaited here; failures surface on a
    /// *later* call (or on `finish`) as `WriterConnectionPoolError`.
    ///
    /// # Errors
    ///
    /// - If a previously-spawned task failed, poisoning the pool.
    pub async fn spawn_with_connection<F>(&mut self, task: F) -> Result<(), Error>
    where
        F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>>
            + Send
            + Sync
            + 'static,
    {
        match self.connection_rx.recv().await {
            Some(Ok(mut connection)) => {
                let connection_tx = self.connection_tx.clone();
                tokio::task::spawn(
                    async move {
                        // On success, return the connection to the pool;
                        // on failure, return the error instead (poisoning the pool).
                        let to_return = match task(&mut connection).await {
                            Ok(()) => Ok(connection),
                            Err(error) => {
                                error!("error in writer: {error}");
                                Err(error)
                            }
                        };
                        // This should always succeed in sending unless we're already shutting
                        // down for some other reason.
                        let _: Result<_, _> = connection_tx.send(to_return).await;
                    }
                    .instrument(tracing::debug_span!("spawn_with_connection")),
                );

                Ok(())
            }
            Some(Err(error)) => {
                // Put the error back so subsequent callers also see the pool
                // as poisoned.
                // This should always succeed in sending unless we're already shutting
                // down for some other reason.
                let _: Result<_, _> = self.connection_tx.send(Err(error)).await;

                Err(Error::WriterConnectionPoolError)
            }
            None => {
                unreachable!("we still hold a reference to the sender, so this shouldn't happen")
            }
        }
    }

    /// Finishes writing to the database, committing all changes.
    ///
    /// # Errors
    ///
    /// - If any errors were returned to the pool.
    /// - If committing the changes failed.
    ///
    /// # Panics
    ///
    /// - If connections were not returned to the pool. (This indicates a
    ///   serious bug.)
    pub async fn finish(self) -> Result<(), Vec<Error>> {
        let mut errors = Vec::new();

        let Self {
            num_connections,
            mut connection_rx,
            connection_tx,
        } = self;
        // Drop the sender handle so we gracefully allow the receiver to close
        drop(connection_tx);

        let mut finished_connections = 0;

        while let Some(connection_or_error) = connection_rx.recv().await {
            finished_connections += 1;

            match connection_or_error {
                Ok(mut connection) => {
                    // Commit each writer's long-running transaction.
                    if let Err(err) = query("COMMIT;").execute(&mut connection).await {
                        errors.push(err.into_database("commit writer transaction"));
                    }
                }
                Err(error) => {
                    errors.push(error);
                }
            }
        }
        assert_eq!(
            finished_connections, num_connections,
            "syn2mas had a bug: connections went missing {finished_connections} != {num_connections}"
        );

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
204
/// Small utility to make sure `finish()` is called on all write buffers
/// before committing to the database.
#[derive(Default)]
struct FinishChecker {
    /// Count of outstanding (not yet finished) handles.
    counter: Arc<AtomicU32>,
}
211
/// A handle vended by [`FinishChecker::handle`]; the owning task must call
/// `declare_finished` before the checker will pass.
struct FinishCheckerHandle {
    /// Counter shared with the `FinishChecker` that created this handle.
    counter: Arc<AtomicU32>,
}
215
216impl FinishChecker {
217    /// Acquire a new handle, for a task that should declare when it has
218    /// finished.
219    pub fn handle(&self) -> FinishCheckerHandle {
220        self.counter.fetch_add(1, Ordering::SeqCst);
221        FinishCheckerHandle {
222            counter: Arc::clone(&self.counter),
223        }
224    }
225
226    /// Check that all handles have been declared as finished.
227    pub fn check_all_finished(self) -> Result<(), Error> {
228        if self.counter.load(Ordering::SeqCst) == 0 {
229            Ok(())
230        } else {
231            Err(Error::WriteBuffersNotFinished)
232        }
233    }
234}
235
impl FinishCheckerHandle {
    /// Declare that the task this handle represents has been finished.
    // Takes `self` by value, so each handle can only be declared finished once.
    pub fn declare_finished(self) {
        self.counter.fetch_sub(1, Ordering::SeqCst);
    }
}
242
/// Writes new records into the MAS database as part of the syn2mas migration.
pub struct MasWriter {
    /// Main MAS database connection, holding the syn2mas lock.
    conn: LockedMasDatabase,
    /// Extra connections used for concurrent bulk writes.
    writer_pool: WriterConnectionPool,

    /// Indices dropped for bulk loading, to be rebuilt in `finish`.
    indices_to_restore: Vec<IndexDescription>,
    /// Constraints dropped for bulk loading, to be rebuilt in `finish`.
    constraints_to_restore: Vec<ConstraintDescription>,

    /// Ensures all write buffers are finished before committing.
    write_buffer_finish_checker: FinishChecker,
}
252
/// A user record to be inserted into the MAS `users` table.
pub struct MasNewUser {
    pub user_id: NonNilUuid,
    pub username: String,
    pub created_at: DateTime<Utc>,
    pub locked_at: Option<DateTime<Utc>>,
    pub deactivated_at: Option<DateTime<Utc>>,
    pub can_request_admin: bool,
    /// Whether the user was a Synapse guest.
    /// Although MAS doesn't support guest access, it's still useful to track
    /// for the future.
    pub is_guest: bool,
}
265
/// A password hash to be inserted into the MAS `user_passwords` table.
/// The version column is filled in with [`MIGRATED_PASSWORD_VERSION`].
pub struct MasNewUserPassword {
    pub user_password_id: Uuid,
    pub user_id: NonNilUuid,
    pub hashed_password: String,
    pub created_at: DateTime<Utc>,
}
272
/// An e-mail address association to be inserted into the MAS `user_emails`
/// table.
pub struct MasNewEmailThreepid {
    pub user_email_id: Uuid,
    pub user_id: NonNilUuid,
    pub email: String,
    pub created_at: DateTime<Utc>,
}
279
/// A third-party identifier of a medium other than e-mail, to be inserted
/// into the MAS `user_unsupported_third_party_ids` table.
pub struct MasNewUnsupportedThreepid {
    pub user_id: NonNilUuid,
    pub medium: String,
    pub address: String,
    pub created_at: DateTime<Utc>,
}
286
/// A link between a MAS user and a subject at an upstream OAuth 2.0
/// provider, to be inserted into the MAS `upstream_oauth_links` table.
pub struct MasNewUpstreamOauthLink {
    pub link_id: Uuid,
    pub user_id: NonNilUuid,
    pub upstream_provider_id: Uuid,
    pub subject: String,
    pub created_at: DateTime<Utc>,
}
294
/// A compatibility-layer session to be inserted into the MAS
/// `compat_sessions` table.
pub struct MasNewCompatSession {
    pub session_id: Uuid,
    pub user_id: NonNilUuid,
    pub device_id: Option<String>,
    pub human_name: Option<String>,
    pub created_at: DateTime<Utc>,
    pub is_synapse_admin: bool,
    pub last_active_at: Option<DateTime<Utc>>,
    pub last_active_ip: Option<IpAddr>,
    pub user_agent: Option<String>,
}
306
/// A compatibility-layer access token to be inserted into the MAS
/// `compat_access_tokens` table.
pub struct MasNewCompatAccessToken {
    pub token_id: Uuid,
    pub session_id: Uuid,
    pub access_token: String,
    pub created_at: DateTime<Utc>,
    /// `None` means the token does not expire.
    pub expires_at: Option<DateTime<Utc>>,
}
314
/// A compatibility-layer refresh token to be inserted into the MAS
/// `compat_refresh_tokens` table.
pub struct MasNewCompatRefreshToken {
    pub refresh_token_id: Uuid,
    pub session_id: Uuid,
    /// The access token this refresh token is paired with.
    pub access_token_id: Uuid,
    pub refresh_token: String,
    pub created_at: DateTime<Utc>,
}
322
/// The 'version' of the password hashing scheme used for passwords when they
/// are migrated from Synapse to MAS.
/// This is version 1, as in the previous syn2mas script.
/// Used to populate the `version` column of `user_passwords`.
// TODO hardcoding version to `1` may not be correct long-term?
pub const MIGRATED_PASSWORD_VERSION: u16 = 1;
328
/// List of all MAS tables that are written to by syn2mas.
///
/// During the migration these tables are addressed under a `syn2mas__`
/// prefix (e.g. `syn2mas__users`); see the temporary-table SQL scripts.
pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
    "users",
    "user_passwords",
    "user_emails",
    "user_unsupported_third_party_ids",
    "upstream_oauth_links",
    "compat_sessions",
    "compat_access_tokens",
    "compat_refresh_tokens",
];
340
/// Detect whether a syn2mas migration has started on the given database.
///
/// Concretely, this checks for the presence of syn2mas restoration tables.
///
/// Returns `true` if syn2mas has started, or `false` if it hasn't.
///
/// # Errors
///
/// Errors are returned under the following circumstances:
///
/// - If any database error occurs whilst querying the database.
/// - If some, but not all, syn2mas restoration tables are present. (This
///   shouldn't be possible without syn2mas having been sabotaged!)
pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Error> {
    // Names of tables used for syn2mas resumption
    // Must be `String`s, not just `&str`, for the query.
    let restore_table_names = vec![
        "syn2mas_restore_constraints".to_owned(),
        "syn2mas_restore_indices".to_owned(),
    ];

    let num_resumption_tables = query!(
        r#"
        SELECT 1 AS _dummy FROM pg_tables WHERE schemaname = current_schema
        AND tablename = ANY($1)
        "#,
        &restore_table_names,
    )
    .fetch_all(conn.as_mut())
    .await
    .into_database("failed to query count of resumption tables")?
    .len();

    if num_resumption_tables == 0 {
        Ok(false)
    } else if num_resumption_tables == restore_table_names.len() {
        Ok(true)
    } else {
        // A partial set of restoration tables means something external
        // tampered with the database mid-migration.
        Err(Error::inconsistent(
            "some, but not all, syn2mas resumption tables were found",
        ))
    }
}
384
385impl MasWriter {
386    /// Creates a new MAS writer.
387    ///
388    /// # Errors
389    ///
390    /// Errors are returned in the following conditions:
391    ///
392    /// - If the database connection experiences an error.
393    #[allow(clippy::missing_panics_doc)] // not real
394    #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
395    pub async fn new(
396        mut conn: LockedMasDatabase,
397        mut writer_connections: Vec<PgConnection>,
398    ) -> Result<Self, Error> {
399        // Given that we don't have any concurrent transactions here,
400        // the READ COMMITTED isolation level is sufficient.
401        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
402            .execute(conn.as_mut())
403            .await
404            .into_database("begin MAS transaction")?;
405
406        let syn2mas_started = is_syn2mas_in_progress(conn.as_mut()).await?;
407
408        let indices_to_restore;
409        let constraints_to_restore;
410
411        if syn2mas_started {
412            // We are resuming from a partially-done syn2mas migration
413            // We should reset the database so that we're starting from scratch.
414            warn!("Partial syn2mas migration has already been done; resetting.");
415            for table in MAS_TABLES_AFFECTED_BY_MIGRATION {
416                query(&format!("TRUNCATE syn2mas__{table};"))
417                    .execute(conn.as_mut())
418                    .await
419                    .into_database_with(|| format!("failed to truncate table syn2mas__{table}"))?;
420            }
421
422            indices_to_restore = query_as!(
423                IndexDescription,
424                "SELECT table_name, name, definition FROM syn2mas_restore_indices ORDER BY order_key"
425            )
426                .fetch_all(conn.as_mut())
427                .await
428                .into_database("failed to get syn2mas restore data (index descriptions)")?;
429            constraints_to_restore = query_as!(
430                ConstraintDescription,
431                "SELECT table_name, name, definition FROM syn2mas_restore_constraints ORDER BY order_key"
432            )
433                .fetch_all(conn.as_mut())
434                .await
435                .into_database("failed to get syn2mas restore data (constraint descriptions)")?;
436        } else {
437            info!("Starting new syn2mas migration");
438
439            conn.as_mut()
440                .execute_many(include_str!("syn2mas_temporary_tables.sql"))
441                // We don't care about any query results
442                .try_collect::<Vec<_>>()
443                .await
444                .into_database("could not create temporary tables")?;
445
446            // Pause (temporarily drop) indices and constraints in order to improve
447            // performance of bulk data loading.
448            (indices_to_restore, constraints_to_restore) =
449                Self::pause_indices(conn.as_mut()).await?;
450
451            // Persist these index and constraint definitions.
452            for IndexDescription {
453                name,
454                table_name,
455                definition,
456            } in &indices_to_restore
457            {
458                query!(
459                    r#"
460                    INSERT INTO syn2mas_restore_indices (name, table_name, definition)
461                    VALUES ($1, $2, $3)
462                    "#,
463                    name,
464                    table_name,
465                    definition
466                )
467                .execute(conn.as_mut())
468                .await
469                .into_database("failed to save restore data (index)")?;
470            }
471            for ConstraintDescription {
472                name,
473                table_name,
474                definition,
475            } in &constraints_to_restore
476            {
477                query!(
478                    r#"
479                    INSERT INTO syn2mas_restore_constraints (name, table_name, definition)
480                    VALUES ($1, $2, $3)
481                    "#,
482                    name,
483                    table_name,
484                    definition
485                )
486                .execute(conn.as_mut())
487                .await
488                .into_database("failed to save restore data (index)")?;
489            }
490        }
491
492        query("COMMIT;")
493            .execute(conn.as_mut())
494            .await
495            .into_database("begin MAS transaction")?;
496
497        // Now after all the schema changes have been done, begin writer transactions
498        for writer_connection in &mut writer_connections {
499            query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
500                .execute(&mut *writer_connection)
501                .await
502                .into_database("begin MAS writer transaction")?;
503        }
504
505        Ok(Self {
506            conn,
507
508            writer_pool: WriterConnectionPool::new(writer_connections),
509            indices_to_restore,
510            constraints_to_restore,
511            write_buffer_finish_checker: FinishChecker::default(),
512        })
513    }
514
    /// Temporarily drops the indices and constraints on every syn2mas table,
    /// returning their definitions so they can be recreated later.
    ///
    /// Drop order matters: incoming foreign keys must go before table-local
    /// constraints, which must go before indices.
    #[tracing::instrument(skip_all)]
    async fn pause_indices(
        conn: &mut PgConnection,
    ) -> Result<(Vec<IndexDescription>, Vec<ConstraintDescription>), Error> {
        let mut indices_to_restore = Vec::new();
        let mut constraints_to_restore = Vec::new();

        for &unprefixed_table in MAS_TABLES_AFFECTED_BY_MIGRATION {
            let table = format!("syn2mas__{unprefixed_table}");
            // First drop incoming foreign key constraints
            for constraint in
                constraint_pausing::describe_foreign_key_constraints_to_table(&mut *conn, &table)
                    .await?
            {
                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
                constraints_to_restore.push(constraint);
            }
            // After all incoming foreign key constraints have been removed,
            // we can now drop internal constraints.
            for constraint in
                constraint_pausing::describe_constraints_on_table(&mut *conn, &table).await?
            {
                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
                constraints_to_restore.push(constraint);
            }
            // After all constraints have been removed, we can drop indices.
            for index in constraint_pausing::describe_indices_on_table(&mut *conn, &table).await? {
                constraint_pausing::drop_index(&mut *conn, &index).await?;
                indices_to_restore.push(index);
            }
        }

        Ok((indices_to_restore, constraints_to_restore))
    }
549
    /// Recreates the indices and constraints previously dropped by
    /// `pause_indices`, reporting each rebuild to `progress`.
    async fn restore_indices(
        conn: &mut LockedMasDatabase,
        indices_to_restore: &[IndexDescription],
        constraints_to_restore: &[ConstraintDescription],
        progress: &Progress,
    ) -> Result<(), Error> {
        // First restore all indices. The order is not important as far as I know.
        // However the indices are needed before constraints.
        for index in indices_to_restore.iter().rev() {
            progress.rebuild_index(index.name.clone());
            constraint_pausing::restore_index(conn.as_mut(), index).await?;
        }
        // Then restore all constraints.
        // The order here is the reverse of drop order, since some constraints may rely
        // on other constraints to work.
        for constraint in constraints_to_restore.iter().rev() {
            progress.rebuild_constraint(constraint.name.clone());
            constraint_pausing::restore_constraint(conn.as_mut(), constraint).await?;
        }
        Ok(())
    }
571
    /// Finish writing to the MAS database, flushing and committing all changes.
    /// It returns the unlocked underlying connection.
    ///
    /// # Errors
    ///
    /// Errors are returned in the following conditions:
    ///
    /// - If the database connection experiences an error.
    /// - If any write buffer was not finished (a syn2mas bug).
    #[tracing::instrument(skip_all)]
    pub async fn finish(mut self, progress: &Progress) -> Result<PgConnection, Error> {
        // Catch the bug where a write buffer was never `finish()`ed before we
        // commit anything.
        self.write_buffer_finish_checker.check_all_finished()?;

        // Commit all writer transactions to the database.
        self.writer_pool
            .finish()
            .await
            .map_err(|errors| Error::Multiple(MultipleErrors::from(errors)))?;

        // Now all the data has been migrated, finish off by restoring indices and
        // constraints!

        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
            .execute(self.conn.as_mut())
            .await
            .into_database("begin MAS transaction")?;

        Self::restore_indices(
            &mut self.conn,
            &self.indices_to_restore,
            &self.constraints_to_restore,
            progress,
        )
        .await?;

        // Undo the temporary-table setup done in `new` — presumably moving
        // the `syn2mas__` tables back into place; see the SQL file for details.
        self.conn
            .as_mut()
            .execute_many(include_str!("syn2mas_revert_temporary_tables.sql"))
            // We don't care about any query results
            .try_collect::<Vec<_>>()
            .await
            .into_database("could not revert temporary tables")?;

        query("COMMIT;")
            .execute(self.conn.as_mut())
            .await
            .into_database("ending MAS transaction")?;

        let conn = self
            .conn
            .unlock()
            .await
            .into_database("could not unlock MAS database")?;

        Ok(conn)
    }
627
    /// Write a batch of users to the database.
    ///
    /// # Errors
    ///
    /// Errors are returned in the following conditions:
    ///
    /// - If the database writer connection pool had an error.
    #[allow(clippy::missing_panics_doc)] // not a real panic
    #[tracing::instrument(skip_all, level = Level::DEBUG)]
    pub fn write_users(&mut self, users: Vec<MasNewUser>) -> BoxFuture<'_, Result<(), Error>> {
        self.writer_pool
            .spawn_with_connection(move |conn| {
                Box::pin(async move {
                    // `UNNEST` is a fast way to do bulk inserts, as it lets us send multiple rows
                    // in one statement without having to change the statement
                    // SQL thus altering the query plan. See <https://github.com/launchbadge/sqlx/blob/main/FAQ.md#how-can-i-bind-an-array-to-a-values-clause-how-can-i-do-bulk-inserts>.
                    // In the future we could consider using sqlx's support for `PgCopyIn` / the
                    // `COPY FROM STDIN` statement, which is allegedly the best
                    // for insert performance, but is less simple to encode.
                    //
                    // Transpose the row structs into one vector per column.
                    let mut user_ids: Vec<Uuid> = Vec::with_capacity(users.len());
                    let mut usernames: Vec<String> = Vec::with_capacity(users.len());
                    let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(users.len());
                    let mut locked_ats: Vec<Option<DateTime<Utc>>> =
                        Vec::with_capacity(users.len());
                    let mut deactivated_ats: Vec<Option<DateTime<Utc>>> =
                        Vec::with_capacity(users.len());
                    let mut can_request_admins: Vec<bool> = Vec::with_capacity(users.len());
                    let mut is_guests: Vec<bool> = Vec::with_capacity(users.len());
                    for MasNewUser {
                        user_id,
                        username,
                        created_at,
                        locked_at,
                        deactivated_at,
                        can_request_admin,
                        is_guest,
                    } in users
                    {
                        // `.get()` unwraps the NonNilUuid into a plain Uuid for binding.
                        user_ids.push(user_id.get());
                        usernames.push(username);
                        created_ats.push(created_at);
                        locked_ats.push(locked_at);
                        deactivated_ats.push(deactivated_at);
                        can_request_admins.push(can_request_admin);
                        is_guests.push(is_guest);
                    }

                    sqlx::query!(
                        r#"
                        INSERT INTO syn2mas__users (
                          user_id, username,
                          created_at, locked_at,
                          deactivated_at,
                          can_request_admin, is_guest)
                        SELECT * FROM UNNEST(
                          $1::UUID[], $2::TEXT[],
                          $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
                          $5::TIMESTAMP WITH TIME ZONE[],
                          $6::BOOL[], $7::BOOL[])
                        "#,
                        &user_ids[..],
                        &usernames[..],
                        &created_ats[..],
                        // We need to override the typing for arrays of optionals (sqlx limitation)
                        &locked_ats[..] as &[Option<DateTime<Utc>>],
                        &deactivated_ats[..] as &[Option<DateTime<Utc>>],
                        &can_request_admins[..],
                        &is_guests[..],
                    )
                    .execute(&mut *conn)
                    .await
                    .into_database("writing users to MAS")?;

                    Ok(())
                })
            })
            .boxed()
    }
706
707    /// Write a batch of user passwords to the database.
708    ///
709    /// # Errors
710    ///
711    /// Errors are returned in the following conditions:
712    ///
713    /// - If the database writer connection pool had an error.
714    #[allow(clippy::missing_panics_doc)] // not a real panic
715    #[tracing::instrument(skip_all, level = Level::DEBUG)]
716    pub fn write_passwords(
717        &mut self,
718        passwords: Vec<MasNewUserPassword>,
719    ) -> BoxFuture<'_, Result<(), Error>> {
720        self.writer_pool.spawn_with_connection(move |conn| Box::pin(async move {
721            let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
722            let mut user_ids: Vec<Uuid> = Vec::with_capacity(passwords.len());
723            let mut hashed_passwords: Vec<String> = Vec::with_capacity(passwords.len());
724            let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(passwords.len());
725            let mut versions: Vec<i32> = Vec::with_capacity(passwords.len());
726            for MasNewUserPassword {
727                user_password_id,
728                user_id,
729                hashed_password,
730                created_at,
731            } in passwords
732            {
733                user_password_ids.push(user_password_id);
734                user_ids.push(user_id.get());
735                hashed_passwords.push(hashed_password);
736                created_ats.push(created_at);
737                versions.push(MIGRATED_PASSWORD_VERSION.into());
738            }
739
740            sqlx::query!(
741                r#"
742                INSERT INTO syn2mas__user_passwords
743                (user_password_id, user_id, hashed_password, created_at, version)
744                SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $5::INTEGER[])
745                "#,
746                &user_password_ids[..],
747                &user_ids[..],
748                &hashed_passwords[..],
749                &created_ats[..],
750                &versions[..],
751            ).execute(&mut *conn).await.into_database("writing users to MAS")?;
752
753            Ok(())
754        })).boxed()
755    }
756
    /// Write a batch of e-mail third-party identifiers to the database.
    ///
    /// # Errors
    ///
    /// Errors are returned in the following conditions:
    ///
    /// - If the database writer connection pool had an error.
    #[tracing::instrument(skip_all, level = Level::DEBUG)]
    pub fn write_email_threepids(
        &mut self,
        threepids: Vec<MasNewEmailThreepid>,
    ) -> BoxFuture<'_, Result<(), Error>> {
        self.writer_pool.spawn_with_connection(move |conn| {
            Box::pin(async move {
                // Transpose the row structs into one vector per column.
                let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
                let mut user_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
                let mut emails: Vec<String> = Vec::with_capacity(threepids.len());
                let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(threepids.len());

                for MasNewEmailThreepid {
                    user_email_id,
                    user_id,
                    email,
                    created_at,
                } in threepids
                {
                    user_email_ids.push(user_email_id);
                    user_ids.push(user_id.get());
                    emails.push(email);
                    created_ats.push(created_at);
                }

                // `confirmed_at` is going to get removed in a future MAS release,
                // so just populate with `created_at` (note `$4` is bound twice).
                sqlx::query!(
                    r#"
                    INSERT INTO syn2mas__user_emails
                    (user_email_id, user_id, email, created_at, confirmed_at)
                    SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[])
                    "#,
                    &user_email_ids[..],
                    &user_ids[..],
                    &emails[..],
                    &created_ats[..],
                ).execute(&mut *conn).await.into_database("writing emails to MAS")?;

                Ok(())
            })
        }).boxed()
    }
800
    /// Write a batch of non-e-mail third-party identifiers to the database.
    ///
    /// # Errors
    ///
    /// Errors are returned in the following conditions:
    ///
    /// - If the database writer connection pool had an error.
    #[tracing::instrument(skip_all, level = Level::DEBUG)]
    pub fn write_unsupported_threepids(
        &mut self,
        threepids: Vec<MasNewUnsupportedThreepid>,
    ) -> BoxFuture<'_, Result<(), Error>> {
        self.writer_pool.spawn_with_connection(move |conn| {
            Box::pin(async move {
                // Transpose the row structs into one vector per column.
                let mut user_ids: Vec<Uuid> = Vec::with_capacity(threepids.len());
                let mut mediums: Vec<String> = Vec::with_capacity(threepids.len());
                let mut addresses: Vec<String> = Vec::with_capacity(threepids.len());
                let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(threepids.len());

                for MasNewUnsupportedThreepid {
                    user_id,
                    medium,
                    address,
                    created_at,
                } in threepids
                {
                    user_ids.push(user_id.get());
                    mediums.push(medium);
                    addresses.push(address);
                    created_ats.push(created_at);
                }

                sqlx::query!(
                    r#"
                    INSERT INTO syn2mas__user_unsupported_third_party_ids
                    (user_id, medium, address, created_at)
                    SELECT * FROM UNNEST($1::UUID[], $2::TEXT[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
                    "#,
                    &user_ids[..],
                    &mediums[..],
                    &addresses[..],
                    &created_ats[..],
                ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;

                Ok(())
            })
        }).boxed()
    }
842
843    #[tracing::instrument(skip_all, level = Level::DEBUG)]
844    pub fn write_upstream_oauth_links(
845        &mut self,
846        links: Vec<MasNewUpstreamOauthLink>,
847    ) -> BoxFuture<'_, Result<(), Error>> {
848        self.writer_pool.spawn_with_connection(move |conn| {
849            Box::pin(async move {
850                let mut link_ids: Vec<Uuid> = Vec::with_capacity(links.len());
851                let mut user_ids: Vec<Uuid> = Vec::with_capacity(links.len());
852                let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(links.len());
853                let mut subjects: Vec<String> = Vec::with_capacity(links.len());
854                let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(links.len());
855
856                for MasNewUpstreamOauthLink {
857                    link_id,
858                    user_id,
859                    upstream_provider_id,
860                    subject,
861                    created_at,
862                } in links
863                {
864                    link_ids.push(link_id);
865                    user_ids.push(user_id.get());
866                    upstream_provider_ids.push(upstream_provider_id);
867                    subjects.push(subject);
868                    created_ats.push(created_at);
869                }
870
871                sqlx::query!(
872                    r#"
873                    INSERT INTO syn2mas__upstream_oauth_links
874                    (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
875                    SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
876                    "#,
877                    &link_ids[..],
878                    &user_ids[..],
879                    &upstream_provider_ids[..],
880                    &subjects[..],
881                    &created_ats[..],
882                ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
883
884                Ok(())
885            })
886        }).boxed()
887    }
888
    /// Writes a batch of compat sessions (Synapse devices) to the MAS
    /// database.
    ///
    /// The rows are transposed into one column vector per field so the whole
    /// batch can be inserted by a single `UNNEST`-based `INSERT`. The column
    /// vectors are positional: their order must match the column list in the
    /// `INSERT` statement below.
    #[tracing::instrument(skip_all, level = Level::DEBUG)]
    pub fn write_compat_sessions(
        &mut self,
        sessions: Vec<MasNewCompatSession>,
    ) -> BoxFuture<'_, Result<(), Error>> {
        self.writer_pool
            .spawn_with_connection(move |conn| {
                Box::pin(async move {
                    // One column vector per destination column.
                    let mut session_ids: Vec<Uuid> = Vec::with_capacity(sessions.len());
                    let mut user_ids: Vec<Uuid> = Vec::with_capacity(sessions.len());
                    let mut device_ids: Vec<Option<String>> = Vec::with_capacity(sessions.len());
                    let mut human_names: Vec<Option<String>> = Vec::with_capacity(sessions.len());
                    let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(sessions.len());
                    let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(sessions.len());
                    let mut last_active_ats: Vec<Option<DateTime<Utc>>> =
                        Vec::with_capacity(sessions.len());
                    let mut last_active_ips: Vec<Option<IpAddr>> =
                        Vec::with_capacity(sessions.len());
                    let mut user_agents: Vec<Option<String>> = Vec::with_capacity(sessions.len());

                    for MasNewCompatSession {
                        session_id,
                        user_id,
                        device_id,
                        human_name,
                        created_at,
                        is_synapse_admin,
                        last_active_at,
                        last_active_ip,
                        user_agent,
                    } in sessions
                    {
                        session_ids.push(session_id);
                        user_ids.push(user_id.get());
                        device_ids.push(device_id);
                        human_names.push(human_name);
                        created_ats.push(created_at);
                        is_synapse_admins.push(is_synapse_admin);
                        last_active_ats.push(last_active_at);
                        last_active_ips.push(last_active_ip);
                        user_agents.push(user_agent);
                    }

                    sqlx::query!(
                        r#"
                        INSERT INTO syn2mas__compat_sessions (
                          compat_session_id, user_id,
                          device_id, human_name,
                          created_at, is_synapse_admin,
                          last_active_at, last_active_ip,
                          user_agent)
                        SELECT * FROM UNNEST(
                          $1::UUID[], $2::UUID[],
                          $3::TEXT[], $4::TEXT[],
                          $5::TIMESTAMP WITH TIME ZONE[], $6::BOOLEAN[],
                          $7::TIMESTAMP WITH TIME ZONE[], $8::INET[],
                          $9::TEXT[])
                        "#,
                        &session_ids[..],
                        &user_ids[..],
                        &device_ids[..] as &[Option<String>],
                        &human_names[..] as &[Option<String>],
                        &created_ats[..],
                        &is_synapse_admins[..],
                        // We need to override the typing for arrays of optionals (sqlx limitation)
                        &last_active_ats[..] as &[Option<DateTime<Utc>>],
                        &last_active_ips[..] as &[Option<IpAddr>],
                        &user_agents[..] as &[Option<String>],
                    )
                    .execute(&mut *conn)
                    .await
                    .into_database("writing compat sessions to MAS")?;

                    Ok(())
                })
            })
            .boxed()
    }
967
968    #[tracing::instrument(skip_all, level = Level::DEBUG)]
969    pub fn write_compat_access_tokens(
970        &mut self,
971        tokens: Vec<MasNewCompatAccessToken>,
972    ) -> BoxFuture<'_, Result<(), Error>> {
973        self.writer_pool
974            .spawn_with_connection(move |conn| {
975                Box::pin(async move {
976                    let mut token_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
977                    let mut session_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
978                    let mut access_tokens: Vec<String> = Vec::with_capacity(tokens.len());
979                    let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(tokens.len());
980                    let mut expires_ats: Vec<Option<DateTime<Utc>>> =
981                        Vec::with_capacity(tokens.len());
982
983                    for MasNewCompatAccessToken {
984                        token_id,
985                        session_id,
986                        access_token,
987                        created_at,
988                        expires_at,
989                    } in tokens
990                    {
991                        token_ids.push(token_id);
992                        session_ids.push(session_id);
993                        access_tokens.push(access_token);
994                        created_ats.push(created_at);
995                        expires_ats.push(expires_at);
996                    }
997
998                    sqlx::query!(
999                        r#"
1000                        INSERT INTO syn2mas__compat_access_tokens (
1001                          compat_access_token_id,
1002                          compat_session_id,
1003                          access_token,
1004                          created_at,
1005                          expires_at)
1006                        SELECT * FROM UNNEST(
1007                          $1::UUID[],
1008                          $2::UUID[],
1009                          $3::TEXT[],
1010                          $4::TIMESTAMP WITH TIME ZONE[],
1011                          $5::TIMESTAMP WITH TIME ZONE[])
1012                        "#,
1013                        &token_ids[..],
1014                        &session_ids[..],
1015                        &access_tokens[..],
1016                        &created_ats[..],
1017                        // We need to override the typing for arrays of optionals (sqlx limitation)
1018                        &expires_ats[..] as &[Option<DateTime<Utc>>],
1019                    )
1020                    .execute(&mut *conn)
1021                    .await
1022                    .into_database("writing compat access tokens to MAS")?;
1023
1024                    Ok(())
1025                })
1026            })
1027            .boxed()
1028    }
1029
1030    #[tracing::instrument(skip_all, level = Level::DEBUG)]
1031    pub fn write_compat_refresh_tokens(
1032        &mut self,
1033        tokens: Vec<MasNewCompatRefreshToken>,
1034    ) -> BoxFuture<'_, Result<(), Error>> {
1035        self.writer_pool
1036            .spawn_with_connection(move |conn| {
1037                Box::pin(async move {
1038                    let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
1039                    let mut session_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
1040                    let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(tokens.len());
1041                    let mut refresh_tokens: Vec<String> = Vec::with_capacity(tokens.len());
1042                    let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(tokens.len());
1043
1044                    for MasNewCompatRefreshToken {
1045                        refresh_token_id,
1046                        session_id,
1047                        access_token_id,
1048                        refresh_token,
1049                        created_at,
1050                    } in tokens
1051                    {
1052                        refresh_token_ids.push(refresh_token_id);
1053                        session_ids.push(session_id);
1054                        access_token_ids.push(access_token_id);
1055                        refresh_tokens.push(refresh_token);
1056                        created_ats.push(created_at);
1057                    }
1058
1059                    sqlx::query!(
1060                        r#"
1061                        INSERT INTO syn2mas__compat_refresh_tokens (
1062                          compat_refresh_token_id,
1063                          compat_session_id,
1064                          compat_access_token_id,
1065                          refresh_token,
1066                          created_at)
1067                        SELECT * FROM UNNEST(
1068                          $1::UUID[],
1069                          $2::UUID[],
1070                          $3::UUID[],
1071                          $4::TEXT[],
1072                          $5::TIMESTAMP WITH TIME ZONE[])
1073                        "#,
1074                        &refresh_token_ids[..],
1075                        &session_ids[..],
1076                        &access_token_ids[..],
1077                        &refresh_tokens[..],
1078                        &created_ats[..],
1079                    )
1080                    .execute(&mut *conn)
1081                    .await
1082                    .into_database("writing compat refresh tokens to MAS")?;
1083
1084                    Ok(())
1085                })
1086            })
1087            .boxed()
1088    }
1089}
1090
// How many entries to buffer at once, before writing a batch of rows to the
// database. Larger batches mean fewer round-trips but more memory held per
// in-flight `MasWriteBuffer`.
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;

/// A function that can accept and flush buffers from a `MasWriteBuffer`.
/// Intended uses are the methods on `MasWriter` such as `write_users`.
///
/// The higher-ranked lifetime (`for<'a>`) ties the returned future to the
/// mutable borrow of the `MasWriter`, for any caller-chosen lifetime.
type WriteBufferFlusher<T> =
    for<'a> fn(&'a mut MasWriter, Vec<T>) -> BoxFuture<'a, Result<(), Error>>;
1099
/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
pub struct MasWriteBuffer<T> {
    /// Rows accumulated since the last flush; flushed automatically once
    /// `WRITE_BUFFER_BATCH_SIZE` rows have been buffered.
    rows: Vec<T>,
    /// Callback that performs the actual batched database write.
    flusher: WriteBufferFlusher<T>,
    /// Marked as finished in `finish`; presumably checked at writer shutdown
    /// (see `Error::WriteBuffersNotFinished`) — confirm in `FinishChecker`.
    finish_checker_handle: FinishCheckerHandle,
}
1107
1108impl<T> MasWriteBuffer<T> {
1109    pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
1110        MasWriteBuffer {
1111            rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
1112            flusher,
1113            finish_checker_handle: writer.write_buffer_finish_checker.handle(),
1114        }
1115    }
1116
1117    pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1118        self.flush(writer).await?;
1119        self.finish_checker_handle.declare_finished();
1120        Ok(())
1121    }
1122
1123    pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
1124        if self.rows.is_empty() {
1125            return Ok(());
1126        }
1127        let rows = std::mem::take(&mut self.rows);
1128        self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
1129        (self.flusher)(writer, rows).await?;
1130        Ok(())
1131    }
1132
1133    pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
1134        self.rows.push(row);
1135        if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
1136            self.flush(writer).await?;
1137        }
1138        Ok(())
1139    }
1140}
1141
1142#[cfg(test)]
1143mod test {
1144    use std::collections::{BTreeMap, BTreeSet};
1145
1146    use chrono::DateTime;
1147    use futures_util::TryStreamExt;
1148    use serde::Serialize;
1149    use sqlx::{Column, PgConnection, PgPool, Row};
1150    use uuid::{NonNilUuid, Uuid};
1151
1152    use crate::{
1153        LockedMasDatabase, MasWriter, Progress,
1154        mas_writer::{
1155            MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
1156            MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
1157            MasNewUserPassword,
1158        },
1159    };
1160
    /// A snapshot of a whole database
    #[derive(Default, Serialize)]
    #[serde(transparent)]
    struct DatabaseSnapshot {
        /// Map from table name to that table's rows.
        tables: BTreeMap<String, TableSnapshot>,
    }

    /// A snapshot of one table's contents.
    #[derive(Serialize)]
    #[serde(transparent)]
    struct TableSnapshot {
        /// Rows are held in an ordered set so the snapshot is deterministic.
        rows: BTreeSet<RowSnapshot>,
    }

    /// A snapshot of a single row.
    #[derive(PartialEq, Eq, PartialOrd, Ord, Serialize)]
    #[serde(transparent)]
    struct RowSnapshot {
        /// Map from column name to the stringified value (`None` for SQL
        /// `NULL`).
        columns_to_values: BTreeMap<String, Option<String>>,
    }

    /// Tables that are never included in snapshots.
    const SKIPPED_TABLES: &[&str] = &["_sqlx_migrations"];
1181
    /// Produces a serialisable snapshot of a database, usable for snapshot
    /// testing
    ///
    /// For brevity, empty tables, as well as [`SKIPPED_TABLES`], will not be
    /// included in the snapshot.
    async fn snapshot_database(conn: &mut PgConnection) -> DatabaseSnapshot {
        let mut out = DatabaseSnapshot::default();
        // Enumerate every table in the current schema.
        let table_names: Vec<String> = sqlx::query_scalar(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();",
        )
        .fetch_all(&mut *conn)
        .await
        .unwrap();

        for table_name in table_names {
            if SKIPPED_TABLES.contains(&table_name.as_str()) {
                continue;
            }

            let column_names: Vec<String> = sqlx::query_scalar(
                "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
            ).bind(&table_name).fetch_all(&mut *conn).await.expect("failed to get column names for table for snapshotting");

            // Build a `col::TEXT AS "col"` list so every value comes back as
            // a string. Note the identifiers are interpolated into the SQL
            // rather than bound; that's acceptable here (test-only code and
            // the names come from information_schema itself).
            let column_name_list = column_names
                .iter()
                // stringify all the values for simplicity
                .map(|column_name| format!("{column_name}::TEXT AS \"{column_name}\""))
                .collect::<Vec<_>>()
                .join(", ");

            let table_rows = sqlx::query(&format!("SELECT {column_name_list} FROM {table_name};"))
                .fetch(&mut *conn)
                .map_ok(|row| {
                    let mut columns_to_values = BTreeMap::new();
                    for (idx, column) in row.columns().iter().enumerate() {
                        columns_to_values.insert(column.name().to_owned(), row.get(idx));
                    }
                    RowSnapshot { columns_to_values }
                })
                .try_collect::<BTreeSet<RowSnapshot>>()
                .await
                .expect("failed to fetch rows from table for snapshotting");

            // Empty tables are omitted to keep snapshots brief.
            if !table_rows.is_empty() {
                out.tables
                    .insert(table_name, TableSnapshot { rows: table_rows });
            }
        }

        out
    }
1233
    /// Make a snapshot assertion against the database.
    ///
    /// Takes a `&mut PgConnection`, snapshots every non-empty, non-skipped
    /// table, and asserts the result against an `insta` YAML snapshot.
    /// Must be invoked from an async context.
    macro_rules! assert_db_snapshot {
        ($db: expr) => {
            let db_snapshot = snapshot_database($db).await;
            ::insta::assert_yaml_snapshot!(db_snapshot);
        };
    }
1241
    /// Constructs a `MasWriter` for use in tests, locking the main connection
    /// and detaching two extra pool connections for the writer pool.
    ///
    /// The caller is responsible for `finish`ing the `MasWriter`.
    async fn make_mas_writer(pool: &PgPool) -> MasWriter {
        let main_conn = pool.acquire().await.unwrap().detach();
        let mut writer_conns = Vec::new();
        // Two extra connections for parallel writes.
        for _ in 0..2 {
            writer_conns.push(
                pool.acquire()
                    .await
                    .expect("failed to acquire MasWriter writer connection")
                    .detach(),
            );
        }
        // The test expects to be the only holder of the lock.
        let locked_main_conn = LockedMasDatabase::try_new(main_conn)
            .await
            .expect("failed to lock MAS database")
            .expect_left("MAS database is already locked");
        MasWriter::new(locked_main_conn, writer_conns)
            .await
            .expect("failed to construct MasWriter")
    }
1264
1265    /// Tests writing a single user, without a password.
1266    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1267    async fn test_write_user(pool: PgPool) {
1268        let mut writer = make_mas_writer(&pool).await;
1269
1270        writer
1271            .write_users(vec![MasNewUser {
1272                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1273                username: "alice".to_owned(),
1274                created_at: DateTime::default(),
1275                locked_at: None,
1276                deactivated_at: None,
1277                can_request_admin: false,
1278                is_guest: false,
1279            }])
1280            .await
1281            .expect("failed to write user");
1282
1283        let mut conn = writer
1284            .finish(&Progress::default())
1285            .await
1286            .expect("failed to finish MasWriter");
1287
1288        assert_db_snapshot!(&mut conn);
1289    }
1290
    /// Tests writing a single user, with a password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_password(pool: PgPool) {
        const USER_ID: NonNilUuid = NonNilUuid::new(Uuid::from_u128(1u128)).unwrap();

        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: USER_ID,
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");
        // The password row references the user written above.
        writer
            .write_passwords(vec![MasNewUserPassword {
                user_password_id: Uuid::from_u128(42u128),
                user_id: USER_ID,
                hashed_password: "$bcrypt$aaaaaaaaaaa".to_owned(),
                created_at: DateTime::default(),
            }])
            .await
            .expect("failed to write password");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1327
1328    /// Tests writing a single user, with an e-mail address associated.
1329    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1330    async fn test_write_user_with_email(pool: PgPool) {
1331        let mut writer = make_mas_writer(&pool).await;
1332
1333        writer
1334            .write_users(vec![MasNewUser {
1335                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1336                username: "alice".to_owned(),
1337                created_at: DateTime::default(),
1338                locked_at: None,
1339                deactivated_at: None,
1340                can_request_admin: false,
1341                is_guest: false,
1342            }])
1343            .await
1344            .expect("failed to write user");
1345
1346        writer
1347            .write_email_threepids(vec![MasNewEmailThreepid {
1348                user_email_id: Uuid::from_u128(2u128),
1349                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1350                email: "alice@example.org".to_owned(),
1351                created_at: DateTime::default(),
1352            }])
1353            .await
1354            .expect("failed to write e-mail");
1355
1356        let mut conn = writer
1357            .finish(&Progress::default())
1358            .await
1359            .expect("failed to finish MasWriter");
1360
1361        assert_db_snapshot!(&mut conn);
1362    }
1363
    /// Tests writing a single user, with an unsupported third-party ID
    /// associated.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");

        // A phone number is a threepid medium MAS does not support natively.
        writer
            .write_unsupported_threepids(vec![MasNewUnsupportedThreepid {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                medium: "msisdn".to_owned(),
                address: "441189998819991197253".to_owned(),
                created_at: DateTime::default(),
            }])
            .await
            .expect("failed to write phone number (unsupported threepid)");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1400
    /// Tests writing a single user, with a link to an upstream provider.
    /// There needs to be an upstream provider in the database already — in the
    /// real migration, this is done by running a provider sync first.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
    async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");

        writer
            .write_upstream_oauth_links(vec![MasNewUpstreamOauthLink {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                link_id: Uuid::from_u128(3u128),
                // Must match the provider created by the `upstream_provider`
                // fixture.
                upstream_provider_id: Uuid::from_u128(4u128),
                subject: "12345.67890".to_owned(),
                created_at: DateTime::default(),
            }])
            .await
            .expect("failed to write link");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1439
    /// Tests writing a single user, with a device (compat session).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_device(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");

        // A session with every optional field populated.
        writer
            .write_compat_sessions(vec![MasNewCompatSession {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                session_id: Uuid::from_u128(5u128),
                created_at: DateTime::default(),
                device_id: Some("ADEVICE".to_owned()),
                human_name: Some("alice's pinephone".to_owned()),
                is_synapse_admin: true,
                last_active_at: Some(DateTime::default()),
                last_active_ip: Some("203.0.113.1".parse().unwrap()),
                user_agent: Some("Browser/5.0".to_owned()),
            }])
            .await
            .expect("failed to write compat session");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1480
    /// Tests writing a single user, with a device and an access token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_access_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");

        writer
            .write_compat_sessions(vec![MasNewCompatSession {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                session_id: Uuid::from_u128(5u128),
                created_at: DateTime::default(),
                device_id: Some("ADEVICE".to_owned()),
                human_name: None,
                is_synapse_admin: false,
                last_active_at: None,
                last_active_ip: None,
                user_agent: None,
            }])
            .await
            .expect("failed to write compat session");

        // The token references the compat session written above.
        writer
            .write_compat_access_tokens(vec![MasNewCompatAccessToken {
                token_id: Uuid::from_u128(6u128),
                session_id: Uuid::from_u128(5u128),
                access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                created_at: DateTime::default(),
                expires_at: None,
            }])
            .await
            .expect("failed to write access token");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1532
    /// Tests writing a single user, with a device, an access token and a
    /// refresh token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_refresh_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        writer
            .write_users(vec![MasNewUser {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                username: "alice".to_owned(),
                created_at: DateTime::default(),
                locked_at: None,
                deactivated_at: None,
                can_request_admin: false,
                is_guest: false,
            }])
            .await
            .expect("failed to write user");

        writer
            .write_compat_sessions(vec![MasNewCompatSession {
                user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                session_id: Uuid::from_u128(5u128),
                created_at: DateTime::default(),
                device_id: Some("ADEVICE".to_owned()),
                human_name: None,
                is_synapse_admin: false,
                last_active_at: None,
                last_active_ip: None,
                user_agent: None,
            }])
            .await
            .expect("failed to write compat session");

        writer
            .write_compat_access_tokens(vec![MasNewCompatAccessToken {
                token_id: Uuid::from_u128(6u128),
                session_id: Uuid::from_u128(5u128),
                access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                created_at: DateTime::default(),
                expires_at: None,
            }])
            .await
            .expect("failed to write access token");

        // The refresh token references both the session and the access token
        // written above.
        writer
            .write_compat_refresh_tokens(vec![MasNewCompatRefreshToken {
                refresh_token_id: Uuid::from_u128(7u128),
                session_id: Uuid::from_u128(5u128),
                access_token_id: Uuid::from_u128(6u128),
                refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                created_at: DateTime::default(),
            }])
            .await
            .expect("failed to write refresh token");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1596}