1use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{HomeserverConnection as _, MatrixUser, ProvisionRequest};
14
/// Snapshot of a user as stored by the mock homeserver.
///
/// Returned by `HomeserverConnection::query_user_raw` so tests can inspect
/// state that the `crate::HomeserverConnection` trait does not expose.
#[derive(Debug, Clone)]
pub struct MockUser {
    /// Subject the user was provisioned with; re-provisioning with a
    /// different `sub` is rejected.
    sub: String,
    avatar_url: Option<String>,
    displayname: Option<String>,
    /// IDs of the devices currently registered for this user.
    devices: HashSet<String>,
    emails: Option<Vec<String>>,
    /// Set by `allow_cross_signing_reset`; never cleared by the mock.
    cross_signing_reset_allowed: bool,
    /// Set by `delete_user`, cleared by `reactivate_user`.
    deactivated: bool,
    /// Lock state copied from the last provisioning request.
    pub locked: bool,
}
26
27pub struct HomeserverConnection {
30 homeserver: String,
31 users: RwLock<HashMap<String, MockUser>>,
32 reserved_localparts: RwLock<HashSet<&'static str>>,
33}
34
35impl HomeserverConnection {
36 pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
39
40 pub fn new<H>(homeserver: H) -> Self
42 where
43 H: Into<String>,
44 {
45 Self {
46 homeserver: homeserver.into(),
47 users: RwLock::new(HashMap::new()),
48 reserved_localparts: RwLock::new(HashSet::new()),
49 }
50 }
51
52 pub async fn reserve_localpart(&self, localpart: &'static str) {
53 self.reserved_localparts.write().await.insert(localpart);
54 }
55
56 pub async fn query_user_raw(&self, localpart: &str) -> Result<MockUser, anyhow::Error> {
62 let mxid = self.mxid(localpart);
63 let users = self.users.read().await;
64 let user = users.get(&mxid).context("User not found")?;
65 Ok(user.clone())
66 }
67}
68
69#[async_trait]
70impl crate::HomeserverConnection for HomeserverConnection {
71 fn homeserver(&self) -> &str {
72 &self.homeserver
73 }
74
75 async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
76 Ok(token == Self::VALID_BEARER_TOKEN)
77 }
78
79 async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
80 let mxid = self.mxid(localpart);
81 let users = self.users.read().await;
82 let user = users.get(&mxid).context("User not found")?;
83 Ok(MatrixUser {
84 displayname: user.displayname.clone(),
85 avatar_url: user.avatar_url.clone(),
86 deactivated: user.deactivated,
87 })
88 }
89
90 async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
91 let mut users = self.users.write().await;
92 let mxid = self.mxid(request.localpart());
93 let inserted = !users.contains_key(&mxid);
94 let user = users.entry(mxid).or_insert(MockUser {
95 sub: request.sub().to_owned(),
96 avatar_url: None,
97 displayname: None,
98 devices: HashSet::new(),
99 emails: None,
100 cross_signing_reset_allowed: false,
101 deactivated: false,
102 locked: false,
103 });
104
105 anyhow::ensure!(
106 user.sub == request.sub(),
107 "User already provisioned with different sub"
108 );
109
110 request.on_emails(|emails| {
111 user.emails = emails.map(ToOwned::to_owned);
112 });
113
114 request.on_displayname(|displayname| {
115 user.displayname = displayname.map(ToOwned::to_owned);
116 });
117
118 request.on_avatar_url(|avatar_url| {
119 user.avatar_url = avatar_url.map(ToOwned::to_owned);
120 });
121
122 user.locked = request.locked();
123
124 Ok(inserted)
125 }
126
127 async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
128 if self.reserved_localparts.read().await.contains(localpart) {
129 return Ok(false);
130 }
131
132 let mxid = self.mxid(localpart);
133 let users = self.users.read().await;
134 Ok(!users.contains_key(&mxid))
135 }
136
137 async fn upsert_device(
138 &self,
139 localpart: &str,
140 device_id: &str,
141 _initial_display_name: Option<&str>,
142 ) -> Result<(), anyhow::Error> {
143 let mxid = self.mxid(localpart);
144 let mut users = self.users.write().await;
145 let user = users.get_mut(&mxid).context("User not found")?;
146 user.devices.insert(device_id.to_owned());
147 Ok(())
148 }
149
150 async fn update_device_display_name(
151 &self,
152 localpart: &str,
153 device_id: &str,
154 _display_name: &str,
155 ) -> Result<(), anyhow::Error> {
156 let mxid = self.mxid(localpart);
157 let mut users = self.users.write().await;
158 let user = users.get_mut(&mxid).context("User not found")?;
159 user.devices.get(device_id).context("Device not found")?;
160 Ok(())
161 }
162
163 async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
164 let mxid = self.mxid(localpart);
165 let mut users = self.users.write().await;
166 let user = users.get_mut(&mxid).context("User not found")?;
167 user.devices.remove(device_id);
168 Ok(())
169 }
170
171 async fn sync_devices(
172 &self,
173 localpart: &str,
174 devices: HashSet<String>,
175 ) -> Result<(), anyhow::Error> {
176 let mxid = self.mxid(localpart);
177 let mut users = self.users.write().await;
178 let user = users.get_mut(&mxid).context("User not found")?;
179 user.devices = devices;
180 Ok(())
181 }
182
183 async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
184 let mxid = self.mxid(localpart);
185 let mut users = self.users.write().await;
186 let user = users.get_mut(&mxid).context("User not found")?;
187 user.devices.clear();
188 user.emails = None;
189 user.deactivated = true;
190 if erase {
191 user.avatar_url = None;
192 user.displayname = None;
193 }
194
195 Ok(())
196 }
197
198 async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
199 let mxid = self.mxid(localpart);
200 let mut users = self.users.write().await;
201 let user = users.get_mut(&mxid).context("User not found")?;
202 user.deactivated = false;
203
204 Ok(())
205 }
206
207 async fn set_displayname(
208 &self,
209 localpart: &str,
210 displayname: &str,
211 ) -> Result<(), anyhow::Error> {
212 let mxid = self.mxid(localpart);
213 let mut users = self.users.write().await;
214 let user = users.get_mut(&mxid).context("User not found")?;
215 user.displayname = Some(displayname.to_owned());
216 Ok(())
217 }
218
219 async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
220 let mxid = self.mxid(localpart);
221 let mut users = self.users.write().await;
222 let user = users.get_mut(&mxid).context("User not found")?;
223 user.displayname = None;
224 Ok(())
225 }
226
227 async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
228 let mxid = self.mxid(localpart);
229 let mut users = self.users.write().await;
230 let user = users.get_mut(&mxid).context("User not found")?;
231 user.cross_signing_reset_allowed = true;
232 Ok(())
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a single user through the whole lifecycle supported by the mock.
    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        // Per-user operations fail before the user is provisioned.
        assert!(conn.query_user("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());

        let request = ProvisionRequest::new("test", "test", false)
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        // Provisioning again with the same sub updates in place.
        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(!inserted);

        // The same localpart cannot be claimed by a different sub.
        let conflicting = ProvisionRequest::new("test", "someone-else", false);
        assert!(conn.provision_user(&conflicting).await.is_err());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(!user.deactivated);

        let raw = conn.query_user_raw("test").await.unwrap();
        assert_eq!(raw.emails, Some(vec!["test@example.org".to_owned()]));

        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a device that was never added is a no-op.
        assert!(conn.delete_device("test", device).await.is_ok());

        // Renaming requires the device to exist.
        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_err());

        assert!(conn.upsert_device("test", device, None).await.is_ok());
        // Upserting the same device twice keeps a single entry.
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        let raw = conn.query_user_raw("test").await.unwrap();
        assert_eq!(raw.devices, HashSet::from([device.to_owned()]));

        assert!(conn
            .update_device_display_name("test", device, "Phone")
            .await
            .is_ok());

        assert!(conn.delete_device("test", device).await.is_ok());
        assert!(conn.query_user_raw("test").await.unwrap().devices.is_empty());

        // Syncing replaces the whole device set.
        let synced = HashSet::from(["a".to_owned(), "b".to_owned()]);
        assert!(conn.sync_devices("test", synced.clone()).await.is_ok());
        assert_eq!(conn.query_user_raw("test").await.unwrap().devices, synced);

        assert!(conn.allow_cross_signing_reset("test").await.is_ok());
        assert!(
            conn.query_user_raw("test")
                .await
                .unwrap()
                .cross_signing_reset_allowed
        );

        // Deactivation with erasure drops devices, emails and profile data.
        assert!(conn.set_displayname("test", "John").await.is_ok());
        assert!(conn.delete_user("test", true).await.is_ok());
        let raw = conn.query_user_raw("test").await.unwrap();
        assert!(raw.deactivated);
        assert!(raw.devices.is_empty());
        assert_eq!(raw.emails, None);
        assert_eq!(raw.displayname, None);
        assert_eq!(raw.avatar_url, None);

        assert!(conn.reactivate_user("test").await.is_ok());
        assert!(!conn.query_user("test").await.unwrap().deactivated);

        // A deactivated-then-reactivated user still holds the localpart.
        assert!(!conn.is_localpart_available("test").await.unwrap());
        assert!(conn.is_localpart_available("alice").await.unwrap());

        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }
}