// mas_matrix/mock.rs
use std::collections::{hash_map::Entry, HashMap, HashSet};

use anyhow::Context;
use async_trait::async_trait;
use tokio::sync::RwLock;

use crate::{MatrixUser, ProvisionRequest};
/// State kept for one provisioned user of the mock homeserver.
struct MockUser {
    /// Subject the user was first provisioned with; later provisioning
    /// requests for the same MXID must carry the same value.
    sub: String,

    // Profile information.
    displayname: Option<String>,
    avatar_url: Option<String>,
    emails: Option<Vec<String>>,

    /// Device IDs currently registered for this user.
    devices: HashSet<String>,

    // Account flags.
    deactivated: bool,
    cross_signing_reset_allowed: bool,
}
24
25pub struct HomeserverConnection {
28 homeserver: String,
29 users: RwLock<HashMap<String, MockUser>>,
30 reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34 pub fn new<H>(homeserver: H) -> Self
36 where
37 H: Into<String>,
38 {
39 Self {
40 homeserver: homeserver.into(),
41 users: RwLock::new(HashMap::new()),
42 reserved_localparts: RwLock::new(HashSet::new()),
43 }
44 }
45
46 pub async fn reserve_localpart(&self, localpart: &'static str) {
47 self.reserved_localparts.write().await.insert(localpart);
48 }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53 fn homeserver(&self) -> &str {
54 &self.homeserver
55 }
56
57 async fn query_user(&self, mxid: &str) -> Result<MatrixUser, anyhow::Error> {
58 let users = self.users.read().await;
59 let user = users.get(mxid).context("User not found")?;
60 Ok(MatrixUser {
61 displayname: user.displayname.clone(),
62 avatar_url: user.avatar_url.clone(),
63 deactivated: user.deactivated,
64 })
65 }
66
67 async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
68 let mut users = self.users.write().await;
69 let inserted = !users.contains_key(request.mxid());
70 let user = users.entry(request.mxid().to_owned()).or_insert(MockUser {
71 sub: request.sub().to_owned(),
72 avatar_url: None,
73 displayname: None,
74 devices: HashSet::new(),
75 emails: None,
76 cross_signing_reset_allowed: false,
77 deactivated: false,
78 });
79
80 anyhow::ensure!(
81 user.sub == request.sub(),
82 "User already provisioned with different sub"
83 );
84
85 request.on_emails(|emails| {
86 user.emails = emails.map(ToOwned::to_owned);
87 });
88
89 request.on_displayname(|displayname| {
90 user.displayname = displayname.map(ToOwned::to_owned);
91 });
92
93 request.on_avatar_url(|avatar_url| {
94 user.avatar_url = avatar_url.map(ToOwned::to_owned);
95 });
96
97 Ok(inserted)
98 }
99
100 async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
101 if self.reserved_localparts.read().await.contains(localpart) {
102 return Ok(false);
103 }
104
105 let mxid = self.mxid(localpart);
106 let users = self.users.read().await;
107 Ok(!users.contains_key(&mxid))
108 }
109
110 async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
111 let mut users = self.users.write().await;
112 let user = users.get_mut(mxid).context("User not found")?;
113 user.devices.insert(device_id.to_owned());
114 Ok(())
115 }
116
117 async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> {
118 let mut users = self.users.write().await;
119 let user = users.get_mut(mxid).context("User not found")?;
120 user.devices.remove(device_id);
121 Ok(())
122 }
123
124 async fn sync_devices(
125 &self,
126 mxid: &str,
127 devices: HashSet<String>,
128 ) -> Result<(), anyhow::Error> {
129 let mut users = self.users.write().await;
130 let user = users.get_mut(mxid).context("User not found")?;
131 user.devices = devices;
132 Ok(())
133 }
134
135 async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> {
136 let mut users = self.users.write().await;
137 let user = users.get_mut(mxid).context("User not found")?;
138 user.devices.clear();
139 user.emails = None;
140 user.deactivated = true;
141 if erase {
142 user.avatar_url = None;
143 user.displayname = None;
144 }
145
146 Ok(())
147 }
148
149 async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> {
150 let mut users = self.users.write().await;
151 let user = users.get_mut(mxid).context("User not found")?;
152 user.deactivated = false;
153
154 Ok(())
155 }
156
157 async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> {
158 let mut users = self.users.write().await;
159 let user = users.get_mut(mxid).context("User not found")?;
160 user.displayname = Some(displayname.to_owned());
161 Ok(())
162 }
163
164 async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> {
165 let mut users = self.users.write().await;
166 let user = users.get_mut(mxid).context("User not found")?;
167 user.displayname = None;
168 Ok(())
169 }
170
171 async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> {
172 let mut users = self.users.write().await;
173 let user = users.get_mut(mxid).context("User not found")?;
174 user.cross_signing_reset_allowed = true;
175 Ok(())
176 }
177}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    /// End-to-end walk through the basic per-user operations.
    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Nothing works before the user is provisioned.
        assert!(conn.query_user(mxid).await.is_err());
        assert!(conn.create_device(mxid, device).await.is_err());
        assert!(conn.delete_device(mxid, device).await.is_err());

        let request = ProvisionRequest::new("@test:example.org", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        assert!(conn.set_displayname(mxid, "John").await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        assert!(conn.unset_displayname(mxid).await.is_ok());

        let user = conn.query_user(mxid).await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting an unknown device succeeds.
        assert!(conn.delete_device(mxid, device).await.is_ok());

        // Creating the same device twice is fine.
        assert!(conn.create_device(mxid, device).await.is_ok());
        assert!(conn.create_device(mxid, device).await.is_ok());

        assert!(conn.delete_device(mxid, device).await.is_ok());

        // A provisioned user makes its localpart unavailable.
        assert!(!conn.is_localpart_available("test").await.unwrap());
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // So does reserving the localpart.
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }

    /// Re-provisioning is idempotent, but the MXID is bound to its first sub.
    #[tokio::test]
    async fn test_provision_is_idempotent_and_checks_sub() {
        let conn = HomeserverConnection::new("example.org");
        let mxid = "@alice:example.org";

        let request = ProvisionRequest::new(mxid, "alice-sub");
        assert!(conn.provision_user(&request).await.unwrap());
        assert!(!conn.provision_user(&request).await.unwrap());

        let other = ProvisionRequest::new(mxid, "someone-else");
        assert!(conn.provision_user(&other).await.is_err());
    }

    /// Deactivation keeps the profile unless `erase` is set; reactivation
    /// clears the flag.
    #[tokio::test]
    async fn test_deactivate_and_reactivate() {
        let conn = HomeserverConnection::new("example.org");
        let mxid = "@bob:example.org";

        let request = ProvisionRequest::new(mxid, "bob-sub")
            .set_displayname("Bob".into())
            .set_avatar_url("mxc://example.org/bob".into());
        conn.provision_user(&request).await.unwrap();

        conn.delete_user(mxid, false).await.unwrap();
        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("Bob".into()));

        // The deactivated user still holds its localpart.
        assert!(!conn.is_localpart_available("bob").await.unwrap());

        conn.reactivate_user(mxid).await.unwrap();
        assert!(!conn.query_user(mxid).await.unwrap().deactivated);

        conn.delete_user(mxid, true).await.unwrap();
        let user = conn.query_user(mxid).await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);
    }
}