mas_matrix/
mock.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{HomeserverConnection as _, MatrixUser, ProvisionRequest};
14
/// Raw, in-memory state of a user known to the mock homeserver.
///
/// Returned by [`HomeserverConnection::query_user_raw`] so tests can inspect
/// state that the [`crate::HomeserverConnection`] trait does not expose.
#[derive(Clone, Debug)]
pub struct MockUser {
    /// The `sub` the user was provisioned with; provisioning the same user
    /// again with a different `sub` is rejected.
    sub: String,
    /// The user's avatar URL, if any.
    avatar_url: Option<String>,
    /// The user's display name, if any.
    displayname: Option<String>,
    /// IDs of the devices currently registered for the user.
    devices: HashSet<String>,
    /// Email addresses applied from the provisioning request; cleared when
    /// the user is deleted.
    emails: Option<Vec<String>>,
    /// Whether `allow_cross_signing_reset` has been called for this user.
    cross_signing_reset_allowed: bool,
    /// Whether the user has been deactivated through `delete_user`.
    deactivated: bool,
    /// Whether the user is locked, as set by the last provisioning request.
    pub locked: bool,
}
26
27/// A mock implementation of a [`HomeserverConnection`], which never fails and
28/// doesn't do anything.
29pub struct HomeserverConnection {
30    homeserver: String,
31    users: RwLock<HashMap<String, MockUser>>,
32    reserved_localparts: RwLock<HashSet<&'static str>>,
33}
34
35impl HomeserverConnection {
36    /// A valid bearer token that will be accepted by
37    /// [`crate::HomeserverConnection::verify_token`].
38    pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
39
40    /// Create a new mock connection.
41    pub fn new<H>(homeserver: H) -> Self
42    where
43        H: Into<String>,
44    {
45        Self {
46            homeserver: homeserver.into(),
47            users: RwLock::new(HashMap::new()),
48            reserved_localparts: RwLock::new(HashSet::new()),
49        }
50    }
51
52    pub async fn reserve_localpart(&self, localpart: &'static str) {
53        self.reserved_localparts.write().await.insert(localpart);
54    }
55
56    /// Like `query_user` but get the raw test state of the user.
57    ///
58    /// # Errors
59    ///
60    /// Will fail if the user doesn't exist.
61    pub async fn query_user_raw(&self, localpart: &str) -> Result<MockUser, anyhow::Error> {
62        let mxid = self.mxid(localpart);
63        let users = self.users.read().await;
64        let user = users.get(&mxid).context("User not found")?;
65        Ok(user.clone())
66    }
67}
68
69#[async_trait]
70impl crate::HomeserverConnection for HomeserverConnection {
71    fn homeserver(&self) -> &str {
72        &self.homeserver
73    }
74
75    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
76        Ok(token == Self::VALID_BEARER_TOKEN)
77    }
78
79    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
80        let mxid = self.mxid(localpart);
81        let users = self.users.read().await;
82        let user = users.get(&mxid).context("User not found")?;
83        Ok(MatrixUser {
84            displayname: user.displayname.clone(),
85            avatar_url: user.avatar_url.clone(),
86            deactivated: user.deactivated,
87        })
88    }
89
90    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
91        let mut users = self.users.write().await;
92        let mxid = self.mxid(request.localpart());
93        let inserted = !users.contains_key(&mxid);
94        let user = users.entry(mxid).or_insert(MockUser {
95            sub: request.sub().to_owned(),
96            avatar_url: None,
97            displayname: None,
98            devices: HashSet::new(),
99            emails: None,
100            cross_signing_reset_allowed: false,
101            deactivated: false,
102            locked: false,
103        });
104
105        anyhow::ensure!(
106            user.sub == request.sub(),
107            "User already provisioned with different sub"
108        );
109
110        request.on_emails(|emails| {
111            user.emails = emails.map(ToOwned::to_owned);
112        });
113
114        request.on_displayname(|displayname| {
115            user.displayname = displayname.map(ToOwned::to_owned);
116        });
117
118        request.on_avatar_url(|avatar_url| {
119            user.avatar_url = avatar_url.map(ToOwned::to_owned);
120        });
121
122        user.locked = request.locked();
123
124        Ok(inserted)
125    }
126
127    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
128        if self.reserved_localparts.read().await.contains(localpart) {
129            return Ok(false);
130        }
131
132        let mxid = self.mxid(localpart);
133        let users = self.users.read().await;
134        Ok(!users.contains_key(&mxid))
135    }
136
137    async fn upsert_device(
138        &self,
139        localpart: &str,
140        device_id: &str,
141        _initial_display_name: Option<&str>,
142    ) -> Result<(), anyhow::Error> {
143        let mxid = self.mxid(localpart);
144        let mut users = self.users.write().await;
145        let user = users.get_mut(&mxid).context("User not found")?;
146        user.devices.insert(device_id.to_owned());
147        Ok(())
148    }
149
150    async fn update_device_display_name(
151        &self,
152        localpart: &str,
153        device_id: &str,
154        _display_name: &str,
155    ) -> Result<(), anyhow::Error> {
156        let mxid = self.mxid(localpart);
157        let mut users = self.users.write().await;
158        let user = users.get_mut(&mxid).context("User not found")?;
159        user.devices.get(device_id).context("Device not found")?;
160        Ok(())
161    }
162
163    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
164        let mxid = self.mxid(localpart);
165        let mut users = self.users.write().await;
166        let user = users.get_mut(&mxid).context("User not found")?;
167        user.devices.remove(device_id);
168        Ok(())
169    }
170
171    async fn sync_devices(
172        &self,
173        localpart: &str,
174        devices: HashSet<String>,
175    ) -> Result<(), anyhow::Error> {
176        let mxid = self.mxid(localpart);
177        let mut users = self.users.write().await;
178        let user = users.get_mut(&mxid).context("User not found")?;
179        user.devices = devices;
180        Ok(())
181    }
182
183    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
184        let mxid = self.mxid(localpart);
185        let mut users = self.users.write().await;
186        let user = users.get_mut(&mxid).context("User not found")?;
187        user.devices.clear();
188        user.emails = None;
189        user.deactivated = true;
190        if erase {
191            user.avatar_url = None;
192            user.displayname = None;
193        }
194
195        Ok(())
196    }
197
198    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
199        let mxid = self.mxid(localpart);
200        let mut users = self.users.write().await;
201        let user = users.get_mut(&mxid).context("User not found")?;
202        user.deactivated = false;
203
204        Ok(())
205    }
206
207    async fn set_displayname(
208        &self,
209        localpart: &str,
210        displayname: &str,
211    ) -> Result<(), anyhow::Error> {
212        let mxid = self.mxid(localpart);
213        let mut users = self.users.write().await;
214        let user = users.get_mut(&mxid).context("User not found")?;
215        user.displayname = Some(displayname.to_owned());
216        Ok(())
217    }
218
219    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
220        let mxid = self.mxid(localpart);
221        let mut users = self.users.write().await;
222        let user = users.get_mut(&mxid).context("User not found")?;
223        user.displayname = None;
224        Ok(())
225    }
226
227    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
228        let mxid = self.mxid(localpart);
229        let mut users = self.users.write().await;
230        let user = users.get_mut(&mxid).context("User not found")?;
231        user.cross_signing_reset_allowed = true;
232        Ok(())
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Nothing works on a user that was never provisioned
        assert!(conn.query_user("test").await.is_err());
        assert!(conn.query_user_raw("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());

        let request = ProvisionRequest::new("test", "test", false)
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        // Provisioning the same user again updates it instead of creating it
        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(!inserted);

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(!user.deactivated);

        let raw = conn.query_user_raw("test").await.unwrap();
        assert_eq!(raw.emails, Some(vec!["test@example.org".to_owned()]));
        assert!(raw.devices.is_empty());

        // Set the displayname again
        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        // Unset the displayname
        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a non-existent device should not fail
        assert!(conn.delete_device("test", device).await.is_ok());

        // Updating the display name of an unknown device fails
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_err());

        // Create the device
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        // Create the same device again
        assert!(conn.upsert_device("test", device, None).await.is_ok());

        let raw = conn.query_user_raw("test").await.unwrap();
        assert!(raw.devices.contains(device));
        assert_eq!(raw.devices.len(), 1);

        // Updating the display name of an existing device works
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_ok());

        // Delete the device
        assert!(conn.delete_device("test", device).await.is_ok());
        let raw = conn.query_user_raw("test").await.unwrap();
        assert!(!raw.devices.contains(device));

        // Syncing replaces the whole device set
        let devices = HashSet::from(["a".to_owned(), "b".to_owned()]);
        assert!(conn.sync_devices("test", devices.clone()).await.is_ok());
        assert_eq!(conn.query_user_raw("test").await.unwrap().devices, devices);

        // Allow a cross-signing reset
        assert!(!conn
            .query_user_raw("test")
            .await
            .unwrap()
            .cross_signing_reset_allowed);
        assert!(conn.allow_cross_signing_reset("test").await.is_ok());
        assert!(conn
            .query_user_raw("test")
            .await
            .unwrap()
            .cross_signing_reset_allowed);

        // The user we just created should not be available
        assert!(!conn.is_localpart_available("test").await.unwrap());
        // But another user should be
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // Reserve the localpart, it should not be available anymore
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());

        // Deactivate and erase the user
        assert!(conn.delete_user("test", true).await.is_ok());
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);
        let raw = conn.query_user_raw("test").await.unwrap();
        assert!(raw.devices.is_empty());
        assert_eq!(raw.emails, None);

        // A deactivated user still holds its localpart
        assert!(!conn.is_localpart_available("test").await.unwrap());

        // Reactivate the user
        assert!(conn.reactivate_user("test").await.is_ok());
        assert!(!conn.query_user("test").await.unwrap().deactivated);
    }
}