Skip to main content

mas_policy/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22    AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
23    EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24    ViolationVariant,
25};
26
/// Errors returned when loading a policy module or updating its data.
#[derive(Debug, Error)]
pub enum LoadError {
    /// The policy source could not be read
    #[error("failed to read module")]
    Read(#[from] tokio::io::Error),

    /// The WASM engine could not be created from its configuration
    #[error("failed to create WASM engine")]
    Engine(#[source] opa_wasm::wasmtime::Error),

    /// The blocking task compiling the module could not be joined
    #[error("module compilation task crashed")]
    CompilationTask(#[from] tokio::task::JoinError),

    /// The module bytes could not be compiled
    #[error("failed to compile WASM module")]
    Compilation(#[source] anyhow::Error),

    /// The policy data could not be serialized or merged
    #[error("invalid policy data")]
    InvalidData(#[source] anyhow::Error),

    /// The policy could not be instantiated with the data
    #[error("failed to instantiate a test instance")]
    Instantiate(#[source] InstantiateError),
}
47
48impl LoadError {
49    /// Creates an example of an invalid data error, used for API response
50    /// documentation
51    #[doc(hidden)]
52    #[must_use]
53    pub fn invalid_data_example() -> Self {
54        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55    }
56}
57
/// Errors returned when creating a policy instance from a compiled module.
#[derive(Debug, Error)]
pub enum InstantiateError {
    /// The runtime could not be created for the module
    #[error("failed to create WASM runtime")]
    Runtime(#[source] anyhow::Error),

    /// The module does not expose one of the configured entrypoints
    #[error("missing entrypoint {entrypoint}")]
    MissingEntrypoint { entrypoint: String },

    /// The data could not be loaded into the runtime
    #[error("failed to load policy data")]
    LoadData(#[source] anyhow::Error),
}
69
/// Names of the policy entrypoints, one per kind of evaluation
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint used when evaluating a registration
    pub register: String,
    /// Entrypoint used when evaluating a client registration
    pub client_registration: String,
    /// Entrypoint used when evaluating an authorization grant
    pub authorization_grant: String,
    /// Entrypoint used when evaluating a compatibility login
    pub compat_login: String,
    /// Entrypoint used when evaluating an email address
    pub email: String,
}

impl Entrypoints {
    /// Lists every entrypoint name, in field declaration order.
    fn all(&self) -> [&str; 5] {
        // Destructuring without `..` makes the compiler complain here if a
        // field gets added to the struct but not to this list
        let Self {
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        } = self;

        [
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        ]
        .map(String::as_str)
    }
}
91
/// Global static data that stays the same for the life of the [`PolicyFactory`]
#[derive(Debug)]
pub struct Data {
    // Typed part of the data, serialized to JSON when the data is built
    base: BaseData,

    // We will merge this in a custom way, so don't emit as part of the base
    rest: Option<serde_json::Value>,
}
100
/// Typed part of the static policy data, serialized with its field names as
/// keys
#[derive(Serialize, Debug)]
pub struct BaseData {
    /// Serialized under the `server_name` key
    // NOTE(review): presumably the name of the homeserver; confirm with callers
    pub server_name: String,

    /// Limits on the number of application sessions that each user can have
    pub session_limit: Option<SessionLimitConfig>,
}
108
109impl Data {
110    #[must_use]
111    pub fn new(base_data: BaseData) -> Self {
112        Self {
113            base: base_data,
114
115            rest: None,
116        }
117    }
118
119    #[must_use]
120    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
121        self.rest = Some(rest);
122        self
123    }
124
125    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
126        let base = serde_json::to_value(&self.base)?;
127
128        if let Some(rest) = &self.rest {
129            merge_data(base, rest.clone())
130        } else {
131            Ok(base)
132        }
133    }
134}
135
136fn value_kind(value: &serde_json::Value) -> &'static str {
137    match value {
138        serde_json::Value::Object(_) => "object",
139        serde_json::Value::Array(_) => "array",
140        serde_json::Value::String(_) => "string",
141        serde_json::Value::Number(_) => "number",
142        serde_json::Value::Bool(_) => "boolean",
143        serde_json::Value::Null => "null",
144    }
145}
146
147fn merge_data(
148    mut left: serde_json::Value,
149    right: serde_json::Value,
150) -> Result<serde_json::Value, anyhow::Error> {
151    merge_data_rec(&mut left, right)?;
152    Ok(left)
153}
154
155fn merge_data_rec(
156    left: &mut serde_json::Value,
157    right: serde_json::Value,
158) -> Result<(), anyhow::Error> {
159    match (left, right) {
160        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
161            for (key, value) in right {
162                if let Some(left_value) = left.get_mut(&key) {
163                    merge_data_rec(left_value, value)?;
164                } else {
165                    left.insert(key, value);
166                }
167            }
168        }
169        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
170            left.extend(right);
171        }
172        // Other values override
173        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
174            *left = right;
175        }
176        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
177            *left = right;
178        }
179        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
180            *left = right;
181        }
182
183        // Null gets overridden by anything
184        (left, right) if left.is_null() => *left = right,
185
186        // Null on the right makes the left value null
187        (left, right) if right.is_null() => *left = right,
188
189        (left, right) => anyhow::bail!(
190            "Cannot merge a {} into a {}",
191            value_kind(&right),
192            value_kind(left),
193        ),
194    }
195
196    Ok(())
197}
198
/// Global dynamic data
///
/// Hint: there is an admin API to manage this, see
/// `crates/handlers/src/admin/v1/policy_data/set.rs`
struct DynamicData {
    // ID of the policy data this was built from, `None` for the value created
    // when the factory is loaded
    version: Option<Ulid>,
    // Data handed to the policy when it gets instantiated
    merged: serde_json::Value,
}
207
/// Holds a compiled policy module along with its data, and creates [`Policy`]
/// instances out of them
pub struct PolicyFactory {
    // Engine the module was compiled with; instance stores are created from it
    engine: Engine,
    // Compiled policy module
    module: Module,
    // Static data given on load, merged again with each new dynamic data
    data: Data,
    // Data currently in use, swapped when new dynamic data is accepted
    dynamic_data: ArcSwap<DynamicData>,
    // Entrypoints every instance is checked against
    entrypoints: Entrypoints,
}
215
216impl PolicyFactory {
217    /// Load the policy from the given data source.
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if the policy can't be loaded or instantiated.
222    #[tracing::instrument(name = "policy.load", skip(source))]
223    pub async fn load(
224        mut source: impl AsyncRead + std::marker::Unpin,
225        data: Data,
226        entrypoints: Entrypoints,
227    ) -> Result<Self, LoadError> {
228        let mut config = Config::default();
229        config.cranelift_opt_level(OptLevel::SpeedAndSize);
230
231        let engine = Engine::new(&config).map_err(LoadError::Engine)?;
232
233        // Read and compile the module
234        let mut buf = Vec::new();
235        source.read_to_end(&mut buf).await?;
236        // Compilation is CPU-bound, so spawn that in a blocking task
237        let (engine, module) = tokio::task::spawn_blocking(move || {
238            let module = Module::new(&engine, buf)?;
239            anyhow::Ok((engine, module))
240        })
241        .await?
242        .map_err(LoadError::Compilation)?;
243
244        let merged = data.to_value().map_err(LoadError::InvalidData)?;
245        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
246            version: None,
247            merged,
248        }));
249
250        let factory = Self {
251            engine,
252            module,
253            data,
254            dynamic_data,
255            entrypoints,
256        };
257
258        // Try to instantiate
259        factory
260            .instantiate()
261            .await
262            .map_err(LoadError::Instantiate)?;
263
264        Ok(factory)
265    }
266
267    /// Set the dynamic data for the policy.
268    ///
269    /// The `dynamic_data` object is merged with the static data given when the
270    /// policy was loaded.
271    ///
272    /// Returns `true` if the data was updated, `false` if the version
273    /// of the dynamic data was the same as the one we already have.
274    ///
275    /// # Errors
276    ///
277    /// Returns an error if the data can't be merged with the static data, or if
278    /// the policy can't be instantiated with the new data.
279    pub async fn set_dynamic_data(
280        &self,
281        dynamic_data: mas_data_model::PolicyData,
282    ) -> Result<bool, LoadError> {
283        // Check if the version of the dynamic data we have is the same as the one we're
284        // trying to set
285        if self.dynamic_data.load().version == Some(dynamic_data.id) {
286            // Don't do anything if the version is the same
287            return Ok(false);
288        }
289
290        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
291        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
292
293        // Try to instantiate with the new data
294        self.instantiate_with_data(&merged)
295            .await
296            .map_err(LoadError::Instantiate)?;
297
298        // If instantiation succeeds, swap the data
299        self.dynamic_data.store(Arc::new(DynamicData {
300            version: Some(dynamic_data.id),
301            merged,
302        }));
303
304        Ok(true)
305    }
306
307    /// Create a new policy instance.
308    ///
309    /// # Errors
310    ///
311    /// Returns an error if the policy can't be instantiated with the current
312    /// dynamic data.
313    #[tracing::instrument(name = "policy.instantiate", skip_all)]
314    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
315        let data = self.dynamic_data.load();
316        self.instantiate_with_data(&data.merged).await
317    }
318
319    async fn instantiate_with_data(
320        &self,
321        data: &serde_json::Value,
322    ) -> Result<Policy, InstantiateError> {
323        tracing::debug!("Instantiating policy with data={}", data);
324        let mut store = Store::new(&self.engine, ());
325        let runtime = Runtime::new(&mut store, &self.module)
326            .await
327            .map_err(InstantiateError::Runtime)?;
328
329        // Check that we have the required entrypoints
330        let policy_entrypoints = runtime.entrypoints();
331
332        for e in self.entrypoints.all() {
333            if !policy_entrypoints.contains(e) {
334                return Err(InstantiateError::MissingEntrypoint {
335                    entrypoint: e.to_owned(),
336                });
337            }
338        }
339
340        let instance = runtime
341            .with_data(&mut store, data)
342            .await
343            .map_err(InstantiateError::LoadData)?;
344
345        Ok(Policy {
346            store,
347            instance,
348            entrypoints: self.entrypoints.clone(),
349        })
350    }
351}
352
/// A policy instance, as created by [`PolicyFactory::instantiate`]
pub struct Policy {
    // Store handed to every evaluation of `instance`
    store: Store<()>,
    // The instantiated policy, with its data loaded
    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
    // Names of the entrypoints used by the `evaluate_*` methods
    entrypoints: Entrypoints,
}
358
/// Error returned by the `evaluate_*` methods of [`Policy`]
///
/// Both variants share the same display message; the underlying error is
/// exposed as the source.
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
    /// Converted from a [`serde_json::Error`]
    Serialization(#[from] serde_json::Error),
    /// Converted from an [`anyhow::Error`]
    Evaluation(#[from] anyhow::Error),
}
365
366impl Policy {
367    /// Evaluate the 'email' entrypoint.
368    ///
369    /// # Errors
370    ///
371    /// Returns an error if the policy engine fails to evaluate the entrypoint.
372    #[tracing::instrument(
373        name = "policy.evaluate_email",
374        skip_all,
375        fields(
376            %input.email,
377        ),
378    )]
379    pub async fn evaluate_email(
380        &mut self,
381        input: EmailInput<'_>,
382    ) -> Result<EvaluationResult, EvaluationError> {
383        let [res]: [EvaluationResult; 1] = self
384            .instance
385            .evaluate(&mut self.store, &self.entrypoints.email, &input)
386            .await?;
387
388        Ok(res)
389    }
390
391    /// Evaluate the 'register' entrypoint.
392    ///
393    /// # Errors
394    ///
395    /// Returns an error if the policy engine fails to evaluate the entrypoint.
396    #[tracing::instrument(
397        name = "policy.evaluate.register",
398        skip_all,
399        fields(
400            ?input.registration_method,
401            input.username = input.username,
402            input.email = input.email,
403        ),
404    )]
405    pub async fn evaluate_register(
406        &mut self,
407        input: RegisterInput<'_>,
408    ) -> Result<EvaluationResult, EvaluationError> {
409        let [res]: [EvaluationResult; 1] = self
410            .instance
411            .evaluate(&mut self.store, &self.entrypoints.register, &input)
412            .await?;
413
414        Ok(res)
415    }
416
417    /// Evaluate the `client_registration` entrypoint.
418    ///
419    /// # Errors
420    ///
421    /// Returns an error if the policy engine fails to evaluate the entrypoint.
422    #[tracing::instrument(skip(self))]
423    pub async fn evaluate_client_registration(
424        &mut self,
425        input: ClientRegistrationInput<'_>,
426    ) -> Result<EvaluationResult, EvaluationError> {
427        let [res]: [EvaluationResult; 1] = self
428            .instance
429            .evaluate(
430                &mut self.store,
431                &self.entrypoints.client_registration,
432                &input,
433            )
434            .await?;
435
436        Ok(res)
437    }
438
439    /// Evaluate the `authorization_grant` entrypoint.
440    ///
441    /// # Errors
442    ///
443    /// Returns an error if the policy engine fails to evaluate the entrypoint.
444    #[tracing::instrument(
445        name = "policy.evaluate.authorization_grant",
446        skip_all,
447        fields(
448            %input.scope,
449            %input.client.id,
450        ),
451    )]
452    pub async fn evaluate_authorization_grant(
453        &mut self,
454        input: AuthorizationGrantInput<'_>,
455    ) -> Result<EvaluationResult, EvaluationError> {
456        let [res]: [EvaluationResult; 1] = self
457            .instance
458            .evaluate(
459                &mut self.store,
460                &self.entrypoints.authorization_grant,
461                &input,
462            )
463            .await?;
464
465        Ok(res)
466    }
467
468    /// Evaluate the `compat_login` entrypoint.
469    ///
470    /// # Errors
471    ///
472    /// Returns an error if the policy engine fails to evaluate the entrypoint.
473    #[tracing::instrument(
474        name = "policy.evaluate.compat_login",
475        skip_all,
476        fields(
477            %input.user.id,
478        ),
479    )]
480    pub async fn evaluate_compat_login(
481        &mut self,
482        input: CompatLoginInput<'_>,
483    ) -> Result<EvaluationResult, EvaluationError> {
484        let [res]: [EvaluationResult; 1] = self
485            .instance
486            .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
487            .await?;
488
489        Ok(res)
490    }
491}
492
#[cfg(test)]
mod tests {

    use std::time::SystemTime;

    use super::*;

    fn make_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            compat_login: "compat_login/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Static data with only the base part set
    fn base_data() -> Data {
        Data::new(BaseData {
            server_name: "example.com".to_owned(),
            session_limit: None,
        })
    }

    /// Loads the policy built in the `policies` directory of the repository
    /// with the given static data
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// Whether a password registration of `hello` with the given email
    /// address is accepted by the policy
    async fn accepts_registration(policy: &mut Policy, email: &'static str) -> bool {
        let input = RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        };

        policy.evaluate_register(input).await.unwrap().valid()
    }

    /// Dynamic data banning the given substrings from email addresses
    fn banned_substrings(substrings: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data: serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": substrings
                    }
                }
            }),
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = base_data().with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));

        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Not in the allowed domains
        assert!(!accepts_registration(&mut policy, "hello@example.com").await);

        // Matches the allowed wildcard
        assert!(accepts_registration(&mut policy, "hello@foo.element.io").await);

        // Matches the allowed wildcard, but is explicitly banned
        assert!(!accepts_registration(&mut policy, "hello@staging.element.io").await);
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(base_data()).await;

        let mut policy = factory.instantiate().await.unwrap();
        assert!(accepts_registration(&mut policy, "hello@example.com").await);

        // Update the policy data
        factory
            .set_dynamic_data(banned_substrings(serde_json::json!(["hello"])))
            .await
            .unwrap();

        // A new instance must see the updated data
        let mut policy = factory.instantiate().await.unwrap();
        assert!(!accepts_registration(&mut policy, "hello@example.com").await);
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(base_data()).await;

        // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
        // characters including the quotes and a comma.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        factory
            .set_dynamic_data(banned_substrings(serde_json::json!(substrings)))
            .await
            .unwrap();

        // Try instantiating the policy, make sure 5-digit numbers are banned from email
        // addresses
        let mut policy = factory.instantiate().await.unwrap();
        assert!(!accepts_registration(&mut policy, "12345@example.com").await);
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // (left, right, expected result of merging right into left)
        let cases = [
            // Merging objects
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            // Override a value of the same type
            (
                j!({"hello": "world"}),
                j!({"hello": "john"}),
                j!({"hello": "john"}),
            ),
            (
                j!({"hello": true}),
                j!({"hello": false}),
                j!({"hello": false}),
            ),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            // Merge arrays
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            // Null overrides a value
            (
                j!({"hello": "world"}),
                j!({"hello": null}),
                j!({"hello": null}),
            ),
            // Null gets overridden by a value
            (
                j!({"hello": null}),
                j!({"hello": "world"}),
                j!({"hello": "world"}),
            ),
            // Objects get deeply merged
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];

        for (left, right, expected) in cases {
            let res = merge_data(left, right).unwrap();
            assert_eq!(res, expected);
        }

        // Override a value of a different type
        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}