mas_matrix/
mock.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a single user held by the mock homeserver.
struct MockUser {
    /// Subject the user was first provisioned with. Later provisioning
    /// requests for the same MXID must carry the same value.
    sub: String,
    /// Avatar URL, set through provisioning and cleared when the user is
    /// deleted with `erase`.
    avatar_url: Option<String>,
    /// Display name, set through provisioning or `set_displayname`, cleared
    /// by `unset_displayname` or when the user is deleted with `erase`.
    displayname: Option<String>,
    /// IDs of the devices currently registered for this user.
    devices: HashSet<String>,
    /// Email addresses recorded through provisioning; cleared when the user
    /// is deleted.
    emails: Option<Vec<String>>,
    /// Set to `true` by `allow_cross_signing_reset`; nothing in this file
    /// reads it back.
    cross_signing_reset_allowed: bool,
    /// `true` after `delete_user`, `false` again after `reactivate_user`.
    deactivated: bool,
}
24
/// An in-memory mock implementation of the [`crate::HomeserverConnection`]
/// trait.
///
/// All state lives in memory and nothing is sent over the network. Operations
/// which target a user that has not been provisioned fail with a
/// "User not found" error.
pub struct HomeserverConnection {
    /// Homeserver name, as returned by `homeserver()`.
    homeserver: String,
    /// Provisioned users, keyed by MXID.
    users: RwLock<HashMap<String, MockUser>>,
    /// Localparts reported as unavailable even though no user exists for
    /// them.
    reserved_localparts: RwLock<HashSet<&'static str>>,
}
32
33impl HomeserverConnection {
34    /// Create a new mock connection.
35    pub fn new<H>(homeserver: H) -> Self
36    where
37        H: Into<String>,
38    {
39        Self {
40            homeserver: homeserver.into(),
41            users: RwLock::new(HashMap::new()),
42            reserved_localparts: RwLock::new(HashSet::new()),
43        }
44    }
45
46    pub async fn reserve_localpart(&self, localpart: &'static str) {
47        self.reserved_localparts.write().await.insert(localpart);
48    }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53    fn homeserver(&self) -> &str {
54        &self.homeserver
55    }
56
57    async fn query_user(&self, mxid: &str) -> Result<MatrixUser, anyhow::Error> {
58        let users = self.users.read().await;
59        let user = users.get(mxid).context("User not found")?;
60        Ok(MatrixUser {
61            displayname: user.displayname.clone(),
62            avatar_url: user.avatar_url.clone(),
63            deactivated: user.deactivated,
64        })
65    }
66
67    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
68        let mut users = self.users.write().await;
69        let inserted = !users.contains_key(request.mxid());
70        let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
71            sub: request.sub().to_owned(),
72            avatar_url: None,
73            displayname: None,
74            devices: HashSet::new(),
75            emails: None,
76            cross_signing_reset_allowed: false,
77            deactivated: false,
78        });
79
80        anyhow::ensure!(
81            user.sub == request.sub(),
82            "User already provisioned with different sub"
83        );
84
85        request.on_emails(|emails| {
86            user.emails = emails.map(ToOwned::to_owned);
87        });
88
89        request.on_displayname(|displayname| {
90            user.displayname = displayname.map(ToOwned::to_owned);
91        });
92
93        request.on_avatar_url(|avatar_url| {
94            user.avatar_url = avatar_url.map(ToOwned::to_owned);
95        });
96
97        Ok(inserted)
98    }
99
100    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
101        if self.reserved_localparts.read().await.contains(localpart) {
102            return Ok(false);
103        }
104
105        let mxid = self.mxid(localpart);
106        let users = self.users.read().await;
107        Ok(!users.contains_key(&mxid))
108    }
109
110    async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
111        let mut users = self.users.write().await;
112        let user = users.get_mut(mxid).context("User not found")?;
113        user.devices.insert(device_id.to_owned());
114        Ok(())
115    }
116
117    async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
118        let mut users = self.users.write().await;
119        let user = users.get_mut(mxid).context("User not found")?;
120        user.devices.remove(device_id);
121        Ok(())
122    }
123
124    async fn sync_devices(
125        &self,
126        mxid: &str,
127        devices: HashSet<String>,
128    ) -> Result<(), anyhow::Error> {
129        let mut users = self.users.write().await;
130        let user = users.get_mut(mxid).context("User not found")?;
131        user.devices = devices;
132        Ok(())
133    }
134
135    async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> {
136        let mut users = self.users.write().await;
137        let user = users.get_mut(mxid).context("User not found")?;
138        user.devices.clear();
139        user.emails = None;
140        user.deactivated = true;
141        if erase {
142            user.avatar_url = None;
143            user.displayname = None;
144        }
145
146        Ok(())
147    }
148
149    async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> {
150        let mut users = self.users.write().await;
151        let user = users.get_mut(mxid).context("User not found")?;
152        user.deactivated = false;
153
154        Ok(())
155    }
156
157    async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> {
158        let mut users = self.users.write().await;
159        let user = users.get_mut(mxid).context("User not found")?;
160        user.displayname = Some(displayname.to_owned());
161        Ok(())
162    }
163
164    async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> {
165        let mut users = self.users.write().await;
166        let user = users.get_mut(mxid).context("User not found")?;
167        user.displayname = None;
168        Ok(())
169    }
170
171    async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> {
172        let mut users = self.users.write().await;
173        let user = users.get_mut(mxid).context("User not found")?;
174        user.cross_signing_reset_allowed = true;
175        Ok(())
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    /// Walks through provisioning, profile updates, device management and
    /// localpart availability on a single user.
    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Nothing is provisioned yet, so user-scoped operations fail
        assert!(conn.query_user(mxid).await.is_err());
        assert!(conn.create_device(mxid, device).await.is_err());
        assert!(conn.delete_device(mxid, device).await.is_err());

        let request = ProvisionRequest::new("@test:example.org", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        // Set the displayname again
        assert!(conn.set_displayname(mxid, "John").await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        // Unset the displayname
        assert!(conn.unset_displayname(mxid).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a non-existent device should not fail
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // Create the device
        assert!(conn.create_device(mxid, device).await.is_ok());
        // Create the same device again
        assert!(conn.create_device(mxid, device).await.is_ok());

        // XXX: there is no API to query devices yet in the trait
        // Delete the device
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // The localpart of the user we just created should not be available
        assert!(!conn.is_localpart_available("test").await.unwrap());
        // But another user should be
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // Reserve the localpart, it should not be available anymore
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }

    /// Covers re-provisioning, the `sub` conflict check, device sync and the
    /// deactivation / reactivation / erasure cycle.
    #[tokio::test]
    async fn test_mock_user_lifecycle() {
        let conn = HomeserverConnection::new("example.org");
        let mxid = "@test:example.org";

        // Every user-scoped operation fails while the user is unknown
        assert!(conn.sync_devices(mxid, HashSet::new()).await.is_err());
        assert!(conn.delete_user(mxid, false).await.is_err());
        assert!(conn.reactivate_user(mxid).await.is_err());
        assert!(conn.set_displayname(mxid, "John").await.is_err());
        assert!(conn.unset_displayname(mxid).await.is_err());
        assert!(conn.allow_cross_signing_reset(mxid).await.is_err());

        // The first provisioning creates the user, the second one finds it
        let request = ProvisionRequest::new("@test:example.org", "test");
        assert!(conn.provision_user(&request).await.unwrap());
        assert!(!conn.provision_user(&request).await.unwrap());

        // The same MXID cannot be provisioned with a different sub
        let conflicting = ProvisionRequest::new("@test:example.org", "other");
        assert!(conn.provision_user(&conflicting).await.is_err());

        let user = conn.query_user(mxid).await.unwrap();
        assert!(!user.deactivated);

        // XXX: there is no API to query devices or the cross-signing reset
        // flag in the trait, so we can only check those calls succeed
        let devices = HashSet::from(["AAA".to_owned(), "BBB".to_owned()]);
        assert!(conn.sync_devices(mxid, devices).await.is_ok());
        assert!(conn.allow_cross_signing_reset(mxid).await.is_ok());

        // Deactivating without erasing keeps the profile
        assert!(conn.set_displayname(mxid, "John").await.is_ok());
        assert!(conn.delete_user(mxid, false).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("John".into()));

        // A deactivated user still holds on to its localpart
        assert!(!conn.is_localpart_available("test").await.unwrap());

        // Reactivating clears the flag
        assert!(conn.reactivate_user(mxid).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert!(!user.deactivated);

        // Deactivating with erasure also wipes the profile
        assert!(conn.delete_user(mxid, true).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);
    }
}