mas_policy/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22    AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
23    EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24};
25
/// Errors that can occur while loading a policy with [`PolicyFactory::load`],
/// or while swapping its data with [`PolicyFactory::set_dynamic_data`].
#[derive(Debug, Error)]
pub enum LoadError {
    /// Reading the WASM module bytes from the source failed.
    #[error("failed to read module")]
    Read(#[from] tokio::io::Error),

    /// The wasmtime engine could not be created from its configuration.
    #[error("failed to create WASM engine")]
    Engine(#[source] anyhow::Error),

    /// The blocking task that compiles the module could not be joined
    /// (e.g. it panicked).
    #[error("module compilation task crashed")]
    CompilationTask(#[from] tokio::task::JoinError),

    /// The WASM module failed to compile.
    #[error("failed to compile WASM module")]
    Compilation(#[source] anyhow::Error),

    /// The policy data could not be serialized, or static and dynamic data
    /// could not be merged.
    #[error("invalid policy data")]
    InvalidData(#[source] anyhow::Error),

    /// The trial instantiation done to validate the module and data failed.
    #[error("failed to instantiate a test instance")]
    Instantiate(#[source] InstantiateError),
}
46
47impl LoadError {
48    /// Creates an example of an invalid data error, used for API response
49    /// documentation
50    #[doc(hidden)]
51    #[must_use]
52    pub fn invalid_data_example() -> Self {
53        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
54    }
55}
56
/// Errors that can occur while creating a [`Policy`] instance.
#[derive(Debug, Error)]
pub enum InstantiateError {
    /// The OPA WASM runtime could not be created from the compiled module.
    #[error("failed to create WASM runtime")]
    Runtime(#[source] anyhow::Error),

    /// One of the configured [`Entrypoints`] is not exposed by the policy.
    #[error("missing entrypoint {entrypoint}")]
    MissingEntrypoint { entrypoint: String },

    /// Loading the merged policy data into the instance failed.
    #[error("failed to load policy data")]
    LoadData(#[source] anyhow::Error),
}
68
/// Names of the policy entrypoints queried for each kind of evaluation.
///
/// Every one of them must be exposed by the compiled policy, otherwise
/// instantiation fails with a missing-entrypoint error.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint queried when evaluating a user registration.
    pub register: String,
    /// Entrypoint queried when evaluating a client registration.
    pub client_registration: String,
    /// Entrypoint queried when evaluating an authorization grant.
    pub authorization_grant: String,
    /// Entrypoint queried when evaluating an email address.
    pub email: String,
}

impl Entrypoints {
    /// Returns every entrypoint name, in field declaration order.
    fn all(&self) -> [&str; 4] {
        // Destructuring (rather than `..`) forces this list to be revisited
        // whenever a field is added.
        let Self {
            register,
            client_registration,
            authorization_grant,
            email,
        } = self;

        [register, client_registration, authorization_grant, email].map(String::as_str)
    }
}
88
/// Static data handed to the policy.
///
/// Made of a fixed set of base fields, optionally deep-merged with an
/// arbitrary JSON value supplied through [`Data::with_rest`].
#[derive(Debug)]
pub struct Data {
    base: BaseData,

    // We will merge this in a custom way, so don't emit as part of the base
    rest: Option<serde_json::Value>,
}

/// Base fields of the policy data, serialized to a JSON object whose keys
/// are the field names.
#[derive(Serialize, Debug)]
struct BaseData {
    /// Exposed to the policy under the `server_name` key.
    server_name: String,

    /// Limits on the number of application sessions that each user can have
    session_limit: Option<SessionLimitConfig>,
}
104
105impl Data {
106    #[must_use]
107    pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
108        Self {
109            base: BaseData {
110                server_name,
111                session_limit,
112            },
113
114            rest: None,
115        }
116    }
117
118    #[must_use]
119    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
120        self.rest = Some(rest);
121        self
122    }
123
124    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
125        let base = serde_json::to_value(&self.base)?;
126
127        if let Some(rest) = &self.rest {
128            merge_data(base, rest.clone())
129        } else {
130            Ok(base)
131        }
132    }
133}
134
135fn value_kind(value: &serde_json::Value) -> &'static str {
136    match value {
137        serde_json::Value::Object(_) => "object",
138        serde_json::Value::Array(_) => "array",
139        serde_json::Value::String(_) => "string",
140        serde_json::Value::Number(_) => "number",
141        serde_json::Value::Bool(_) => "boolean",
142        serde_json::Value::Null => "null",
143    }
144}
145
146fn merge_data(
147    mut left: serde_json::Value,
148    right: serde_json::Value,
149) -> Result<serde_json::Value, anyhow::Error> {
150    merge_data_rec(&mut left, right)?;
151    Ok(left)
152}
153
154fn merge_data_rec(
155    left: &mut serde_json::Value,
156    right: serde_json::Value,
157) -> Result<(), anyhow::Error> {
158    match (left, right) {
159        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
160            for (key, value) in right {
161                if let Some(left_value) = left.get_mut(&key) {
162                    merge_data_rec(left_value, value)?;
163                } else {
164                    left.insert(key, value);
165                }
166            }
167        }
168        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
169            left.extend(right);
170        }
171        // Other values override
172        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
173            *left = right;
174        }
175        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
176            *left = right;
177        }
178        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
179            *left = right;
180        }
181
182        // Null gets overridden by anything
183        (left, right) if left.is_null() => *left = right,
184
185        // Null on the right makes the left value null
186        (left, right) if right.is_null() => *left = right,
187
188        (left, right) => anyhow::bail!(
189            "Cannot merge a {} into a {}",
190            value_kind(&right),
191            value_kind(left),
192        ),
193    }
194
195    Ok(())
196}
197
/// The data currently used to instantiate policies.
struct DynamicData {
    /// ID of the dynamic data merged in, or `None` while only the static
    /// data given at load time is in use.
    version: Option<Ulid>,
    /// Static data merged with the dynamic data (if any), ready to be loaded
    /// into a policy instance.
    merged: serde_json::Value,
}
202
/// A compiled policy module, plus everything needed to create [`Policy`]
/// instances from it.
pub struct PolicyFactory {
    engine: Engine,
    module: Module,
    /// Static data given at load time; dynamic data is always merged on top of
    /// this, never on top of a previous dynamic version.
    data: Data,
    /// Current merged data, replaced as a whole by `set_dynamic_data`.
    dynamic_data: ArcSwap<DynamicData>,
    entrypoints: Entrypoints,
}
210
impl PolicyFactory {
    /// Load the policy from the given data source.
    ///
    /// The whole module is read from `source` and compiled on a blocking
    /// thread. The static `data` becomes the initial policy data. A trial
    /// instance is then created before returning, so a module that cannot be
    /// instantiated, lacks one of the `entrypoints`, or rejects the data is
    /// reported here rather than on first use.
    ///
    /// # Errors
    ///
    /// Returns an error if the policy can't be loaded or instantiated:
    /// [`LoadError::Engine`] if the WASM engine can't be created,
    /// [`LoadError::Read`] if reading `source` fails,
    /// [`LoadError::CompilationTask`] or [`LoadError::Compilation`] if
    /// compiling the module fails, [`LoadError::InvalidData`] if `data` can't
    /// be serialized or merged, and [`LoadError::Instantiate`] if the trial
    /// instantiation fails.
    // NOTE(review): `skip(source)` still records `data` and `entrypoints` as
    // span fields through their `Debug` impls, which includes the whole `rest`
    // JSON value of `data`; consider also skipping `data` if it can be large.
    #[tracing::instrument(name = "policy.load", skip(source))]
    pub async fn load(
        mut source: impl AsyncRead + std::marker::Unpin,
        data: Data,
        entrypoints: Entrypoints,
    ) -> Result<Self, LoadError> {
        // NOTE(review): async support is presumably required by the async
        // `Runtime::new` / `with_data` calls made when instantiating — confirm
        let mut config = Config::default();
        config.async_support(true);
        config.cranelift_opt_level(OptLevel::SpeedAndSize);

        let engine = Engine::new(&config).map_err(LoadError::Engine)?;

        // Read and compile the module
        let mut buf = Vec::new();
        source.read_to_end(&mut buf).await?;
        // Compilation is CPU-bound, so spawn that in a blocking task.
        // The engine is moved into the task and handed back alongside the
        // module, since the `move` closure owns everything it uses; a join
        // failure converts to `LoadError::CompilationTask` through `?`.
        let (engine, module) = tokio::task::spawn_blocking(move || {
            let module = Module::new(&engine, buf)?;
            anyhow::Ok((engine, module))
        })
        .await?
        .map_err(LoadError::Compilation)?;

        // Initially only the static data is in use, hence no version
        let merged = data.to_value().map_err(LoadError::InvalidData)?;
        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
            version: None,
            merged,
        }));

        let factory = Self {
            engine,
            module,
            data,
            dynamic_data,
            entrypoints,
        };

        // Try to instantiate
        factory
            .instantiate()
            .await
            .map_err(LoadError::Instantiate)?;

        Ok(factory)
    }

    /// Set the dynamic data for the policy.
    ///
    /// The `dynamic_data` object is merged with the static data given when the
    /// policy was loaded. Any previously set dynamic data is discarded, not
    /// merged with the new one. The new data only becomes visible to future
    /// [`PolicyFactory::instantiate`] calls once a trial instantiation with it
    /// succeeds.
    ///
    /// Returns `true` if the data was updated, `false` if the version
    /// of the dynamic data was the same as the one we already have.
    ///
    /// # Errors
    ///
    /// Returns an error if the data can't be merged with the static data, or if
    /// the policy can't be instantiated with the new data.
    // NOTE(review): the version check and the final `store` are not atomic and
    // an `.await` separates them; two concurrent calls with different versions
    // can both pass the check, and whichever finishes last wins, even if it is
    // the older version.
    pub async fn set_dynamic_data(
        &self,
        dynamic_data: mas_data_model::PolicyData,
    ) -> Result<bool, LoadError> {
        // Check if the version of the dynamic data we have is the same as the one we're
        // trying to set
        if self.dynamic_data.load().version == Some(dynamic_data.id) {
            // Don't do anything if the version is the same
            return Ok(false);
        }

        // Always start again from the static data
        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;

        // Try to instantiate with the new data
        self.instantiate_with_data(&merged)
            .await
            .map_err(LoadError::Instantiate)?;

        // If instantiation succeeds, swap the data
        self.dynamic_data.store(Arc::new(DynamicData {
            version: Some(dynamic_data.id),
            merged,
        }));

        Ok(true)
    }

    /// Create a new policy instance.
    ///
    /// The instance is loaded with the merged data current at the time of the
    /// call; later calls to [`PolicyFactory::set_dynamic_data`] do not affect
    /// it.
    ///
    /// # Errors
    ///
    /// Returns an error if the policy can't be instantiated with the current
    /// dynamic data.
    #[tracing::instrument(name = "policy.instantiate", skip_all)]
    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
        let data = self.dynamic_data.load();
        self.instantiate_with_data(&data.merged).await
    }

    /// Create a policy instance loaded with `data`.
    ///
    /// Each instance gets its own wasmtime `Store`, so instances do not share
    /// WASM state.
    ///
    /// # Errors
    ///
    /// [`InstantiateError::Runtime`] if the runtime can't be created,
    /// [`InstantiateError::MissingEntrypoint`] for the first configured
    /// entrypoint the policy doesn't expose, and
    /// [`InstantiateError::LoadData`] if `data` can't be loaded.
    async fn instantiate_with_data(
        &self,
        data: &serde_json::Value,
    ) -> Result<Policy, InstantiateError> {
        let mut store = Store::new(&self.engine, ());
        let runtime = Runtime::new(&mut store, &self.module)
            .await
            .map_err(InstantiateError::Runtime)?;

        // Check that we have the required entrypoints
        let policy_entrypoints = runtime.entrypoints();

        for e in self.entrypoints.all() {
            if !policy_entrypoints.contains(e) {
                return Err(InstantiateError::MissingEntrypoint {
                    entrypoint: e.to_owned(),
                });
            }
        }

        let instance = runtime
            .with_data(&mut store, data)
            .await
            .map_err(InstantiateError::LoadData)?;

        Ok(Policy {
            store,
            instance,
            entrypoints: self.entrypoints.clone(),
        })
    }
}
347
/// A single instantiated policy, with its own WASM store and data.
///
/// The evaluation methods take `&mut self`, so one instance serves one
/// evaluation at a time.
pub struct Policy {
    store: Store<()>,
    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
    entrypoints: Entrypoints,
}
353
/// Error returned when evaluating a policy entrypoint.
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
    /// JSON serialization or deserialization failed.
    // NOTE(review): nothing in this file constructs this variant explicitly;
    // it is only reachable through its `From<serde_json::Error>` impl.
    Serialization(#[from] serde_json::Error),
    /// The policy engine failed to evaluate the entrypoint.
    Evaluation(#[from] anyhow::Error),
}
360
361impl Policy {
362    /// Evaluate the 'email' entrypoint.
363    ///
364    /// # Errors
365    ///
366    /// Returns an error if the policy engine fails to evaluate the entrypoint.
367    #[tracing::instrument(
368        name = "policy.evaluate_email",
369        skip_all,
370        fields(
371            %input.email,
372        ),
373    )]
374    pub async fn evaluate_email(
375        &mut self,
376        input: EmailInput<'_>,
377    ) -> Result<EvaluationResult, EvaluationError> {
378        let [res]: [EvaluationResult; 1] = self
379            .instance
380            .evaluate(&mut self.store, &self.entrypoints.email, &input)
381            .await?;
382
383        Ok(res)
384    }
385
386    /// Evaluate the 'register' entrypoint.
387    ///
388    /// # Errors
389    ///
390    /// Returns an error if the policy engine fails to evaluate the entrypoint.
391    #[tracing::instrument(
392        name = "policy.evaluate.register",
393        skip_all,
394        fields(
395            ?input.registration_method,
396            input.username = input.username,
397            input.email = input.email,
398        ),
399    )]
400    pub async fn evaluate_register(
401        &mut self,
402        input: RegisterInput<'_>,
403    ) -> Result<EvaluationResult, EvaluationError> {
404        let [res]: [EvaluationResult; 1] = self
405            .instance
406            .evaluate(&mut self.store, &self.entrypoints.register, &input)
407            .await?;
408
409        Ok(res)
410    }
411
412    /// Evaluate the `client_registration` entrypoint.
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if the policy engine fails to evaluate the entrypoint.
417    #[tracing::instrument(skip(self))]
418    pub async fn evaluate_client_registration(
419        &mut self,
420        input: ClientRegistrationInput<'_>,
421    ) -> Result<EvaluationResult, EvaluationError> {
422        let [res]: [EvaluationResult; 1] = self
423            .instance
424            .evaluate(
425                &mut self.store,
426                &self.entrypoints.client_registration,
427                &input,
428            )
429            .await?;
430
431        Ok(res)
432    }
433
434    /// Evaluate the `authorization_grant` entrypoint.
435    ///
436    /// # Errors
437    ///
438    /// Returns an error if the policy engine fails to evaluate the entrypoint.
439    #[tracing::instrument(
440        name = "policy.evaluate.authorization_grant",
441        skip_all,
442        fields(
443            %input.scope,
444            %input.client.id,
445        ),
446    )]
447    pub async fn evaluate_authorization_grant(
448        &mut self,
449        input: AuthorizationGrantInput<'_>,
450    ) -> Result<EvaluationResult, EvaluationError> {
451        let [res]: [EvaluationResult; 1] = self
452            .instance
453            .evaluate(
454                &mut self.store,
455                &self.entrypoints.authorization_grant,
456                &input,
457            )
458            .await?;
459
460        Ok(res)
461    }
462}
463
#[cfg(test)]
mod tests {

    use std::time::SystemTime;

    use super::*;

    /// Path to the compiled policy bundle, relative to this crate's manifest.
    // Presumably the workspace clippy config disallows `std::path` types in
    // favour of another path type; a plain std path is fine for a test fixture.
    #[allow(clippy::disallowed_types)]
    fn policy_path() -> std::path::PathBuf {
        std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm")
    }

    /// The entrypoints exposed by the bundled policy.
    fn entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads the bundled policy with the given static data.
    async fn load_factory(data: Data) -> PolicyFactory {
        let file = tokio::fs::File::open(policy_path()).await.unwrap();
        PolicyFactory::load(file, data, entrypoints()).await.unwrap()
    }

    /// A password registration for user `hello` with the given email address.
    fn register_input(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));

        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Not in the allowed domains
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());

        // Allowed through the `*.element.io` wildcard
        let res = policy
            .evaluate_register(register_input("hello@foo.element.io"))
            .await
            .unwrap();
        assert!(res.valid());

        // Matches the wildcard, but is explicitly banned
        let res = policy
            .evaluate_register(register_input("hello@staging.element.io"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        // `PolicyData` is rebuilt for each call rather than cloned
        let policy_data = || mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data: serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": ["hello"]
                    }
                }
            }),
        };

        // Update the policy data
        let updated = factory.set_dynamic_data(policy_data()).await.unwrap();
        assert!(updated, "a new version should update the data");

        // Setting the same version again is a no-op
        let updated = factory.set_dynamic_data(policy_data()).await.unwrap();
        assert!(!updated, "the same version should not update the data");

        // An instance created before the update keeps its original data
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        // New instances see the updated data
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
        // characters including the quotes and a comma.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });
        let updated = factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data: json,
            })
            .await
            .unwrap();
        assert!(updated);

        // Try instantiating the policy, make sure 5-digit numbers are banned from email
        // addresses
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // Merging objects
        let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(res, j!({"hello": "world", "foo": "bar"}));

        // Override a value of the same type
        let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(res, j!({"hello": "john"}));

        let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(res, j!({"hello": false}));

        let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(res, j!({"hello": 42}));

        // Override a value of a different type
        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        // Incompatible kinds are rejected even when deeply nested
        merge_data(j!({"a": {"b": [1]}}), j!({"a": {"b": {"c": 1}}}))
            .expect_err("Can't merge an object into an array");

        // Merge arrays
        let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(res, j!({"hello": ["world", "john"]}));

        // Null overrides a value
        let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(res, j!({"hello": null}));

        // Null gets overridden by a value
        let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(res, j!({"hello": "world"}));

        // Objects get deeply merged
        let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));
    }
}