1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22 AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
23 EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24 ViolationVariant,
25};
26
27#[derive(Debug, Error)]
28pub enum LoadError {
29 #[error("failed to read module")]
30 Read(#[from] tokio::io::Error),
31
32 #[error("failed to create WASM engine")]
33 Engine(#[source] opa_wasm::wasmtime::Error),
34
35 #[error("module compilation task crashed")]
36 CompilationTask(#[from] tokio::task::JoinError),
37
38 #[error("failed to compile WASM module")]
39 Compilation(#[source] anyhow::Error),
40
41 #[error("invalid policy data")]
42 InvalidData(#[source] anyhow::Error),
43
44 #[error("failed to instantiate a test instance")]
45 Instantiate(#[source] InstantiateError),
46}
47
48impl LoadError {
49 #[doc(hidden)]
52 #[must_use]
53 pub fn invalid_data_example() -> Self {
54 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55 }
56}
57
58#[derive(Debug, Error)]
59pub enum InstantiateError {
60 #[error("failed to create WASM runtime")]
61 Runtime(#[source] anyhow::Error),
62
63 #[error("missing entrypoint {entrypoint}")]
64 MissingEntrypoint { entrypoint: String },
65
66 #[error("failed to load policy data")]
67 LoadData(#[source] anyhow::Error),
68}
69
/// Names of the policy entrypoints evaluated for each kind of request, e.g.
/// `"register/violation"`.
#[derive(Clone, Debug)]
pub struct Entrypoints {
    /// Entrypoint evaluated by [`Policy::evaluate_register`].
    pub register: String,
    /// Entrypoint evaluated by [`Policy::evaluate_client_registration`].
    pub client_registration: String,
    /// Entrypoint evaluated by [`Policy::evaluate_authorization_grant`].
    pub authorization_grant: String,
    /// Entrypoint evaluated by [`Policy::evaluate_compat_login`].
    pub compat_login: String,
    /// Entrypoint evaluated by [`Policy::evaluate_email`].
    pub email: String,
}

impl Entrypoints {
    /// Every configured entrypoint name, in field declaration order.
    fn all(&self) -> [&str; 5] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        } = self;

        [
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        ]
        .map(String::as_str)
    }
}
91
92#[derive(Debug)]
94pub struct Data {
95 base: BaseData,
96
97 rest: Option<serde_json::Value>,
99}
100
101#[derive(Serialize, Debug)]
102pub struct BaseData {
103 pub server_name: String,
104
105 pub session_limit: Option<SessionLimitConfig>,
107}
108
109impl Data {
110 #[must_use]
111 pub fn new(base_data: BaseData) -> Self {
112 Self {
113 base: base_data,
114
115 rest: None,
116 }
117 }
118
119 #[must_use]
120 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
121 self.rest = Some(rest);
122 self
123 }
124
125 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
126 let base = serde_json::to_value(&self.base)?;
127
128 if let Some(rest) = &self.rest {
129 merge_data(base, rest.clone())
130 } else {
131 Ok(base)
132 }
133 }
134}
135
136fn value_kind(value: &serde_json::Value) -> &'static str {
137 match value {
138 serde_json::Value::Object(_) => "object",
139 serde_json::Value::Array(_) => "array",
140 serde_json::Value::String(_) => "string",
141 serde_json::Value::Number(_) => "number",
142 serde_json::Value::Bool(_) => "boolean",
143 serde_json::Value::Null => "null",
144 }
145}
146
147fn merge_data(
148 mut left: serde_json::Value,
149 right: serde_json::Value,
150) -> Result<serde_json::Value, anyhow::Error> {
151 merge_data_rec(&mut left, right)?;
152 Ok(left)
153}
154
155fn merge_data_rec(
156 left: &mut serde_json::Value,
157 right: serde_json::Value,
158) -> Result<(), anyhow::Error> {
159 match (left, right) {
160 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
161 for (key, value) in right {
162 if let Some(left_value) = left.get_mut(&key) {
163 merge_data_rec(left_value, value)?;
164 } else {
165 left.insert(key, value);
166 }
167 }
168 }
169 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
170 left.extend(right);
171 }
172 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
174 *left = right;
175 }
176 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
177 *left = right;
178 }
179 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
180 *left = right;
181 }
182
183 (left, right) if left.is_null() => *left = right,
185
186 (left, right) if right.is_null() => *left = right,
188
189 (left, right) => anyhow::bail!(
190 "Cannot merge a {} into a {}",
191 value_kind(&right),
192 value_kind(left),
193 ),
194 }
195
196 Ok(())
197}
198
199struct DynamicData {
204 version: Option<Ulid>,
205 merged: serde_json::Value,
206}
207
208pub struct PolicyFactory {
209 engine: Engine,
210 module: Module,
211 data: Data,
212 dynamic_data: ArcSwap<DynamicData>,
213 entrypoints: Entrypoints,
214}
215
216impl PolicyFactory {
217 #[tracing::instrument(name = "policy.load", skip(source))]
223 pub async fn load(
224 mut source: impl AsyncRead + std::marker::Unpin,
225 data: Data,
226 entrypoints: Entrypoints,
227 ) -> Result<Self, LoadError> {
228 let mut config = Config::default();
229 config.cranelift_opt_level(OptLevel::SpeedAndSize);
230
231 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
232
233 let mut buf = Vec::new();
235 source.read_to_end(&mut buf).await?;
236 let (engine, module) = tokio::task::spawn_blocking(move || {
238 let module = Module::new(&engine, buf)?;
239 anyhow::Ok((engine, module))
240 })
241 .await?
242 .map_err(LoadError::Compilation)?;
243
244 let merged = data.to_value().map_err(LoadError::InvalidData)?;
245 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
246 version: None,
247 merged,
248 }));
249
250 let factory = Self {
251 engine,
252 module,
253 data,
254 dynamic_data,
255 entrypoints,
256 };
257
258 factory
260 .instantiate()
261 .await
262 .map_err(LoadError::Instantiate)?;
263
264 Ok(factory)
265 }
266
267 pub async fn set_dynamic_data(
280 &self,
281 dynamic_data: mas_data_model::PolicyData,
282 ) -> Result<bool, LoadError> {
283 if self.dynamic_data.load().version == Some(dynamic_data.id) {
286 return Ok(false);
288 }
289
290 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
291 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
292
293 self.instantiate_with_data(&merged)
295 .await
296 .map_err(LoadError::Instantiate)?;
297
298 self.dynamic_data.store(Arc::new(DynamicData {
300 version: Some(dynamic_data.id),
301 merged,
302 }));
303
304 Ok(true)
305 }
306
307 #[tracing::instrument(name = "policy.instantiate", skip_all)]
314 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
315 let data = self.dynamic_data.load();
316 self.instantiate_with_data(&data.merged).await
317 }
318
319 async fn instantiate_with_data(
320 &self,
321 data: &serde_json::Value,
322 ) -> Result<Policy, InstantiateError> {
323 tracing::debug!("Instantiating policy with data={}", data);
324 let mut store = Store::new(&self.engine, ());
325 let runtime = Runtime::new(&mut store, &self.module)
326 .await
327 .map_err(InstantiateError::Runtime)?;
328
329 let policy_entrypoints = runtime.entrypoints();
331
332 for e in self.entrypoints.all() {
333 if !policy_entrypoints.contains(e) {
334 return Err(InstantiateError::MissingEntrypoint {
335 entrypoint: e.to_owned(),
336 });
337 }
338 }
339
340 let instance = runtime
341 .with_data(&mut store, data)
342 .await
343 .map_err(InstantiateError::LoadData)?;
344
345 Ok(Policy {
346 store,
347 instance,
348 entrypoints: self.entrypoints.clone(),
349 })
350 }
351}
352
353pub struct Policy {
354 store: Store<()>,
355 instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
356 entrypoints: Entrypoints,
357}
358
359#[derive(Debug, Error)]
360#[error("failed to evaluate policy")]
361pub enum EvaluationError {
362 Serialization(#[from] serde_json::Error),
363 Evaluation(#[from] anyhow::Error),
364}
365
366impl Policy {
367 #[tracing::instrument(
373 name = "policy.evaluate_email",
374 skip_all,
375 fields(
376 %input.email,
377 ),
378 )]
379 pub async fn evaluate_email(
380 &mut self,
381 input: EmailInput<'_>,
382 ) -> Result<EvaluationResult, EvaluationError> {
383 let [res]: [EvaluationResult; 1] = self
384 .instance
385 .evaluate(&mut self.store, &self.entrypoints.email, &input)
386 .await?;
387
388 Ok(res)
389 }
390
391 #[tracing::instrument(
397 name = "policy.evaluate.register",
398 skip_all,
399 fields(
400 ?input.registration_method,
401 input.username = input.username,
402 input.email = input.email,
403 ),
404 )]
405 pub async fn evaluate_register(
406 &mut self,
407 input: RegisterInput<'_>,
408 ) -> Result<EvaluationResult, EvaluationError> {
409 let [res]: [EvaluationResult; 1] = self
410 .instance
411 .evaluate(&mut self.store, &self.entrypoints.register, &input)
412 .await?;
413
414 Ok(res)
415 }
416
417 #[tracing::instrument(skip(self))]
423 pub async fn evaluate_client_registration(
424 &mut self,
425 input: ClientRegistrationInput<'_>,
426 ) -> Result<EvaluationResult, EvaluationError> {
427 let [res]: [EvaluationResult; 1] = self
428 .instance
429 .evaluate(
430 &mut self.store,
431 &self.entrypoints.client_registration,
432 &input,
433 )
434 .await?;
435
436 Ok(res)
437 }
438
439 #[tracing::instrument(
445 name = "policy.evaluate.authorization_grant",
446 skip_all,
447 fields(
448 %input.scope,
449 %input.client.id,
450 ),
451 )]
452 pub async fn evaluate_authorization_grant(
453 &mut self,
454 input: AuthorizationGrantInput<'_>,
455 ) -> Result<EvaluationResult, EvaluationError> {
456 let [res]: [EvaluationResult; 1] = self
457 .instance
458 .evaluate(
459 &mut self.store,
460 &self.entrypoints.authorization_grant,
461 &input,
462 )
463 .await?;
464
465 Ok(res)
466 }
467
468 #[tracing::instrument(
474 name = "policy.evaluate.compat_login",
475 skip_all,
476 fields(
477 %input.user.id,
478 ),
479 )]
480 pub async fn evaluate_compat_login(
481 &mut self,
482 input: CompatLoginInput<'_>,
483 ) -> Result<EvaluationResult, EvaluationError> {
484 let [res]: [EvaluationResult; 1] = self
485 .instance
486 .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
487 .await?;
488
489 Ok(res)
490 }
491}
492
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names used by the bundled test policy.
    fn make_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            compat_login: "compat_login/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Static policy data for `example.com`, without a session limit.
    fn example_data() -> Data {
        Data::new(BaseData {
            server_name: "example.com".to_owned(),
            session_limit: None,
        })
    }

    /// Loads `../../policies/policy.wasm` (relative to the crate) with `data`
    /// and the test entrypoints.
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// A password registration for the user `hello` with the given email.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Dynamic policy data setting `emails.banned_addresses.substrings`.
    fn banned_substrings(substrings: Vec<String>) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data: serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": substrings
                    }
                }
            }),
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = example_data().with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        let cases = [
            ("hello@example.com", false),
            ("hello@foo.element.io", true),
            ("hello@staging.element.io", false),
        ];
        for (email, expected_valid) in cases {
            let res = policy
                .evaluate_register(password_registration(email))
                .await
                .unwrap();
            assert_eq!(res.valid(), expected_valid, "unexpected verdict for {email}");
        }
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(example_data()).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        factory
            .set_dynamic_data(banned_substrings(vec!["hello".to_owned()]))
            .await
            .unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(example_data()).await;

        // Roughly 1 MiB of JSON: each `"NNNNN",` entry is 8 bytes.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        factory
            .set_dynamic_data(banned_substrings(substrings))
            .await
            .unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        let cases = [
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            (
                j!({"hello": "world"}),
                j!({"hello": "john"}),
                j!({"hello": "john"}),
            ),
            (
                j!({"hello": true}),
                j!({"hello": false}),
                j!({"hello": false}),
            ),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            (
                j!({"hello": "world"}),
                j!({"hello": null}),
                j!({"hello": null}),
            ),
            (
                j!({"hello": null}),
                j!({"hello": "world"}),
                j!({"hello": "world"}),
            ),
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];
        for (left, right, expected) in cases {
            assert_eq!(merge_data(left, right).unwrap(), expected);
        }

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}