mas_matrix/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7mod mock;
8mod readonly;
9
10use std::{collections::HashSet, sync::Arc};
11
12use ruma_common::UserId;
13
14pub use self::{
15    mock::HomeserverConnection as MockHomeserverConnection, readonly::ReadOnlyHomeserverConnection,
16};
17
/// The state of a user as reported by the homeserver.
///
/// This is what [`HomeserverConnection::query_user`] returns.
#[derive(Debug)]
pub struct MatrixUser {
    /// Display name of the user, or [`None`] when the homeserver reports none.
    pub displayname: Option<String>,

    /// Avatar URL of the user, or [`None`] when the homeserver reports none.
    pub avatar_url: Option<String>,

    /// Whether the homeserver reports the user as deactivated.
    pub deactivated: bool,
}
24
/// What a [`ProvisionRequest`] should do with one optional user attribute.
#[derive(Debug, Default)]
enum FieldAction<T> {
    /// Leave the attribute as it is. This is the default.
    #[default]
    DoNothing,
    /// Set the attribute to the given value.
    Set(T),
    /// Remove the attribute.
    Unset,
}

/// A request to provision a user on the homeserver, consumed by
/// [`HomeserverConnection::provision_user`].
///
/// Build it with [`ProvisionRequest::new`]. The displayname, avatar URL and
/// emails are left untouched unless the matching `set_*` / `unset_*` builder
/// method is called.
#[derive(Debug)]
pub struct ProvisionRequest {
    /// Localpart of the user to provision.
    localpart: String,
    /// The `sub` of the user, aka the internal ID.
    sub: String,
    /// Whether the user is locked.
    locked: bool,
    /// Requested change to the displayname.
    displayname: FieldAction<String>,
    /// Requested change to the avatar URL.
    avatar_url: FieldAction<String>,
    /// Requested change to the list of emails.
    emails: FieldAction<Vec<String>>,
}

impl ProvisionRequest {
    /// Create a new [`ProvisionRequest`] that leaves the displayname, avatar
    /// URL and emails untouched.
    ///
    /// # Parameters
    ///
    /// * `localpart` - The localpart of the user to provision.
    /// * `sub` - The `sub` of the user, aka the internal ID.
    /// * `locked` - Whether the user is locked.
    #[must_use]
    pub fn new(localpart: impl Into<String>, sub: impl Into<String>, locked: bool) -> Self {
        Self {
            localpart: localpart.into(),
            sub: sub.into(),
            locked,
            displayname: FieldAction::DoNothing,
            avatar_url: FieldAction::DoNothing,
            emails: FieldAction::DoNothing,
        }
    }

    /// Get the `sub` of the user to provision, aka the internal ID.
    #[must_use]
    pub fn sub(&self) -> &str {
        &self.sub
    }

    /// Get the localpart of the user to provision.
    #[must_use]
    pub fn localpart(&self) -> &str {
        &self.localpart
    }

    /// Get the locked flag of the user to provision.
    #[must_use]
    pub fn locked(&self) -> bool {
        self.locked
    }

    /// Ask to set the displayname of the user.
    ///
    /// Replaces any previous set/unset request for the displayname.
    ///
    /// # Parameters
    ///
    /// * `displayname` - The displayname to set.
    #[must_use]
    pub fn set_displayname(mut self, displayname: String) -> Self {
        self.displayname = FieldAction::Set(displayname);
        self
    }

    /// Ask to unset the displayname of the user.
    ///
    /// Replaces any previous set/unset request for the displayname.
    #[must_use]
    pub fn unset_displayname(mut self) -> Self {
        self.displayname = FieldAction::Unset;
        self
    }

    /// Call `callback` if the request asks to set or unset the displayname.
    ///
    /// The callback receives `Some(displayname)` when the displayname should
    /// be set and `None` when it should be unset. It is not called at all
    /// when the displayname should be left untouched.
    ///
    /// Returns `self` so that several `on_*` calls can be chained.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_displayname<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&str>),
    {
        match &self.displayname {
            FieldAction::Unset => callback(None),
            FieldAction::Set(displayname) => callback(Some(displayname)),
            FieldAction::DoNothing => {}
        }

        self
    }

    /// Ask to set the avatar URL of the user.
    ///
    /// Replaces any previous set/unset request for the avatar URL.
    ///
    /// # Parameters
    ///
    /// * `avatar_url` - The avatar URL to set.
    #[must_use]
    pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
        self.avatar_url = FieldAction::Set(avatar_url);
        self
    }

    /// Ask to unset the avatar URL of the user.
    ///
    /// Replaces any previous set/unset request for the avatar URL.
    #[must_use]
    pub fn unset_avatar_url(mut self) -> Self {
        self.avatar_url = FieldAction::Unset;
        self
    }

    /// Call `callback` if the request asks to set or unset the avatar URL.
    ///
    /// The callback receives `Some(avatar_url)` when the avatar URL should be
    /// set and `None` when it should be unset. It is not called at all when
    /// the avatar URL should be left untouched.
    ///
    /// Returns `self` so that several `on_*` calls can be chained.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_avatar_url<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&str>),
    {
        match &self.avatar_url {
            FieldAction::Unset => callback(None),
            FieldAction::Set(avatar_url) => callback(Some(avatar_url)),
            FieldAction::DoNothing => {}
        }

        self
    }

    /// Ask to set the emails of the user.
    ///
    /// Replaces any previous set/unset request for the emails.
    ///
    /// # Parameters
    ///
    /// * `emails` - The list of emails to set.
    #[must_use]
    pub fn set_emails(mut self, emails: Vec<String>) -> Self {
        self.emails = FieldAction::Set(emails);
        self
    }

    /// Ask to unset the emails of the user.
    ///
    /// Replaces any previous set/unset request for the emails.
    #[must_use]
    pub fn unset_emails(mut self) -> Self {
        self.emails = FieldAction::Unset;
        self
    }

    /// Call `callback` if the request asks to set or unset the emails.
    ///
    /// The callback receives `Some(emails)` when the emails should be set and
    /// `None` when they should be unset. It is not called at all when the
    /// emails should be left untouched.
    ///
    /// Returns `self` so that several `on_*` calls can be chained.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_emails<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&[String]>),
    {
        match &self.emails {
            FieldAction::Unset => callback(None),
            FieldAction::Set(emails) => callback(Some(emails)),
            FieldAction::DoNothing => {}
        }

        self
    }
}
188
189#[async_trait::async_trait]
190pub trait HomeserverConnection: Send + Sync {
191    /// Get the homeserver URL.
192    fn homeserver(&self) -> &str;
193
194    /// Get the Matrix ID of the user with the given localpart.
195    ///
196    /// # Parameters
197    ///
198    /// * `localpart` - The localpart of the user.
199    fn mxid(&self, localpart: &str) -> String {
200        format!("@{}:{}", localpart, self.homeserver())
201    }
202
203    /// Get the localpart of a Matrix ID if it has the right server name
204    ///
205    /// Returns [`None`] if the input isn't a valid MXID, or if the server name
206    /// doesn't match
207    ///
208    /// # Parameters
209    ///
210    /// * `mxid` - The MXID of the user
211    fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
212        let mxid = <&UserId>::try_from(mxid).ok()?;
213        if mxid.server_name() != self.homeserver() {
214            return None;
215        }
216        Some(mxid.localpart())
217    }
218
219    /// Verify a bearer token coming from the homeserver for homeserver to MAS
220    /// interactions
221    ///
222    /// Returns `true` if the token is valid, `false` otherwise.
223    ///
224    /// # Parameters
225    ///
226    /// * `token` - The token to verify.
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the token failed to verify.
231    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error>;
232
233    /// Query the state of a user on the homeserver.
234    ///
235    /// # Parameters
236    ///
237    /// * `localpart` - The localpart of the user to query.
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if the homeserver is unreachable or the user does not
242    /// exist.
243    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error>;
244
245    /// Provision a user on the homeserver.
246    ///
247    /// # Parameters
248    ///
249    /// * `request` - a [`ProvisionRequest`] containing the details of the user
250    ///   to provision.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if the homeserver is unreachable or the user could not
255    /// be provisioned.
256    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error>;
257
258    /// Check whether a given username is available on the homeserver.
259    ///
260    /// # Parameters
261    ///
262    /// * `localpart` - The localpart to check.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the homeserver is unreachable.
267    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error>;
268
269    /// Create a device for a user on the homeserver.
270    ///
271    /// # Parameters
272    ///
273    /// * `localpart` - The localpart of the user to create a device for.
274    /// * `device_id` - The device ID to create.
275    ///
276    /// # Errors
277    ///
278    /// Returns an error if the homeserver is unreachable or the device could
279    /// not be created.
280    async fn upsert_device(
281        &self,
282        localpart: &str,
283        device_id: &str,
284        initial_display_name: Option<&str>,
285    ) -> Result<(), anyhow::Error>;
286
287    /// Update the display name of a device for a user on the homeserver.
288    ///
289    /// # Parameters
290    ///
291    /// * `localpart` - The localpart of the user to update a device for.
292    /// * `device_id` - The device ID to update.
293    /// * `display_name` - The new display name to set
294    ///
295    /// # Errors
296    ///
297    /// Returns an error if the homeserver is unreachable or the device could
298    /// not be updated.
299    async fn update_device_display_name(
300        &self,
301        localpart: &str,
302        device_id: &str,
303        display_name: &str,
304    ) -> Result<(), anyhow::Error>;
305
306    /// Delete a device for a user on the homeserver.
307    ///
308    /// # Parameters
309    ///
310    /// * `localpart` - The localpart of the user to delete a device for.
311    /// * `device_id` - The device ID to delete.
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if the homeserver is unreachable or the device could
316    /// not be deleted.
317    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error>;
318
319    /// Sync the list of devices of a user with the homeserver.
320    ///
321    /// # Parameters
322    ///
323    /// * `localpart` - The localpart of the user to sync the devices for.
324    /// * `devices` - The list of devices to sync.
325    ///
326    /// # Errors
327    ///
328    /// Returns an error if the homeserver is unreachable or the devices could
329    /// not be synced.
330    async fn sync_devices(
331        &self,
332        localpart: &str,
333        devices: HashSet<String>,
334    ) -> Result<(), anyhow::Error>;
335
336    /// Delete a user on the homeserver.
337    ///
338    /// # Parameters
339    ///
340    /// * `localpart` - The localpart of the user to delete.
341    /// * `erase` - Whether to ask the homeserver to erase the user's data.
342    ///
343    /// # Errors
344    ///
345    /// Returns an error if the homeserver is unreachable or the user could not
346    /// be deleted.
347    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error>;
348
349    /// Reactivate a user on the homeserver.
350    ///
351    /// # Parameters
352    ///
353    /// * `localpart` - The localpart of the user to reactivate.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if the homeserver is unreachable or the user could not
358    /// be reactivated.
359    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error>;
360
361    /// Set the displayname of a user on the homeserver.
362    ///
363    /// # Parameters
364    ///
365    /// * `localpart` - The localpart of the user to set the displayname for.
366    /// * `displayname` - The displayname to set.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if the homeserver is unreachable or the displayname
371    /// could not be set.
372    async fn set_displayname(
373        &self,
374        localpart: &str,
375        displayname: &str,
376    ) -> Result<(), anyhow::Error>;
377
378    /// Unset the displayname of a user on the homeserver.
379    ///
380    /// # Parameters
381    ///
382    /// * `localpart` - The localpart of the user to unset the displayname for.
383    ///
384    /// # Errors
385    ///
386    /// Returns an error if the homeserver is unreachable or the displayname
387    /// could not be unset.
388    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error>;
389
390    /// Temporarily allow a user to reset their cross-signing keys.
391    ///
392    /// # Parameters
393    ///
394    /// * `localpart` - The localpart of the user to allow cross-signing key
395    ///   reset
396    ///
397    /// # Errors
398    ///
399    /// Returns an error if the homeserver is unreachable or the cross-signing
400    /// reset could not be allowed.
401    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error>;
402}
403
404#[async_trait::async_trait]
405impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T {
406    fn homeserver(&self) -> &str {
407        (**self).homeserver()
408    }
409
410    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
411        (**self).verify_token(token).await
412    }
413
414    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
415        (**self).query_user(localpart).await
416    }
417
418    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
419        (**self).provision_user(request).await
420    }
421
422    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
423        (**self).is_localpart_available(localpart).await
424    }
425
426    async fn upsert_device(
427        &self,
428        localpart: &str,
429        device_id: &str,
430        initial_display_name: Option<&str>,
431    ) -> Result<(), anyhow::Error> {
432        (**self)
433            .upsert_device(localpart, device_id, initial_display_name)
434            .await
435    }
436
437    async fn update_device_display_name(
438        &self,
439        localpart: &str,
440        device_id: &str,
441        display_name: &str,
442    ) -> Result<(), anyhow::Error> {
443        (**self)
444            .update_device_display_name(localpart, device_id, display_name)
445            .await
446    }
447
448    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
449        (**self).delete_device(localpart, device_id).await
450    }
451
452    async fn sync_devices(
453        &self,
454        localpart: &str,
455        devices: HashSet<String>,
456    ) -> Result<(), anyhow::Error> {
457        (**self).sync_devices(localpart, devices).await
458    }
459
460    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
461        (**self).delete_user(localpart, erase).await
462    }
463
464    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
465        (**self).reactivate_user(localpart).await
466    }
467
468    async fn set_displayname(
469        &self,
470        localpart: &str,
471        displayname: &str,
472    ) -> Result<(), anyhow::Error> {
473        (**self).set_displayname(localpart, displayname).await
474    }
475
476    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
477        (**self).unset_displayname(localpart).await
478    }
479
480    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
481        (**self).allow_cross_signing_reset(localpart).await
482    }
483}
484
485// Implement for Arc<T> where T: HomeserverConnection
486#[async_trait::async_trait]
487impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
488    fn homeserver(&self) -> &str {
489        (**self).homeserver()
490    }
491
492    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
493        (**self).verify_token(token).await
494    }
495
496    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
497        (**self).query_user(localpart).await
498    }
499
500    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
501        (**self).provision_user(request).await
502    }
503
504    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
505        (**self).is_localpart_available(localpart).await
506    }
507
508    async fn upsert_device(
509        &self,
510        localpart: &str,
511        device_id: &str,
512        initial_display_name: Option<&str>,
513    ) -> Result<(), anyhow::Error> {
514        (**self)
515            .upsert_device(localpart, device_id, initial_display_name)
516            .await
517    }
518
519    async fn update_device_display_name(
520        &self,
521        localpart: &str,
522        device_id: &str,
523        display_name: &str,
524    ) -> Result<(), anyhow::Error> {
525        (**self)
526            .update_device_display_name(localpart, device_id, display_name)
527            .await
528    }
529
530    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
531        (**self).delete_device(localpart, device_id).await
532    }
533
534    async fn sync_devices(
535        &self,
536        localpart: &str,
537        devices: HashSet<String>,
538    ) -> Result<(), anyhow::Error> {
539        (**self).sync_devices(localpart, devices).await
540    }
541
542    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
543        (**self).delete_user(localpart, erase).await
544    }
545
546    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
547        (**self).reactivate_user(localpart).await
548    }
549
550    async fn set_displayname(
551        &self,
552        localpart: &str,
553        displayname: &str,
554    ) -> Result<(), anyhow::Error> {
555        (**self).set_displayname(localpart, displayname).await
556    }
557
558    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
559        (**self).unset_displayname(localpart).await
560    }
561
562    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
563        (**self).allow_cross_signing_reset(localpart).await
564    }
565}