mas_policy/
model.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! Input and output types for policy evaluation.
8//!
//! Each input type derives `JsonSchema`, so a JSON schema can be generated for
//! it; Open Policy Agent can then use those schemas to type-check the policies.
11
12use std::net::IpAddr;
13
14use mas_data_model::{Client, User};
15use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18
19/// A well-known policy code.
20#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
21#[serde(rename_all = "kebab-case")]
22pub enum Code {
23    /// The username is too short.
24    UsernameTooShort,
25
26    /// The username is too long.
27    UsernameTooLong,
28
29    /// The username contains invalid characters.
30    UsernameInvalidChars,
31
32    /// The username contains only numeric characters.
33    UsernameAllNumeric,
34
35    /// The username is banned.
36    UsernameBanned,
37
38    /// The username is not allowed.
39    UsernameNotAllowed,
40
41    /// The email domain is not allowed.
42    EmailDomainNotAllowed,
43
44    /// The email domain is banned.
45    EmailDomainBanned,
46
47    /// The email address is not allowed.
48    EmailNotAllowed,
49
50    /// The email address is banned.
51    EmailBanned,
52
53    /// The user has reached their session limit.
54    TooManySessions,
55}
56
57impl Code {
58    /// Returns the code as a string
59    #[must_use]
60    pub fn as_str(&self) -> &'static str {
61        match self {
62            Self::UsernameTooShort => "username-too-short",
63            Self::UsernameTooLong => "username-too-long",
64            Self::UsernameInvalidChars => "username-invalid-chars",
65            Self::UsernameAllNumeric => "username-all-numeric",
66            Self::UsernameBanned => "username-banned",
67            Self::UsernameNotAllowed => "username-not-allowed",
68            Self::EmailDomainNotAllowed => "email-domain-not-allowed",
69            Self::EmailDomainBanned => "email-domain-banned",
70            Self::EmailNotAllowed => "email-not-allowed",
71            Self::EmailBanned => "email-banned",
72            Self::TooManySessions => "too-many-sessions",
73        }
74    }
75}
76
77/// A single violation of a policy.
78#[derive(Deserialize, Debug, JsonSchema)]
79pub struct Violation {
80    pub msg: String,
81    pub redirect_uri: Option<String>,
82    pub field: Option<String>,
83    pub code: Option<Code>,
84}
85
86/// The result of a policy evaluation.
87#[derive(Deserialize, Debug)]
88pub struct EvaluationResult {
89    #[serde(rename = "result")]
90    pub violations: Vec<Violation>,
91}
92
93impl std::fmt::Display for EvaluationResult {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        let mut first = true;
96        for violation in &self.violations {
97            if first {
98                first = false;
99            } else {
100                write!(f, ", ")?;
101            }
102            write!(f, "{}", violation.msg)?;
103        }
104        Ok(())
105    }
106}
107
108impl EvaluationResult {
109    /// Returns true if the policy evaluation was successful.
110    #[must_use]
111    pub fn valid(&self) -> bool {
112        self.violations.is_empty()
113    }
114}
115
116/// Identity of the requester
117#[derive(Serialize, Debug, Default, JsonSchema)]
118#[serde(rename_all = "snake_case")]
119pub struct Requester {
120    /// IP address of the entity making the request
121    pub ip_address: Option<IpAddr>,
122
123    /// User agent of the entity making the request
124    pub user_agent: Option<String>,
125}
126
127#[derive(Serialize, Debug, JsonSchema)]
128pub enum RegistrationMethod {
129    #[serde(rename = "password")]
130    Password,
131
132    #[serde(rename = "upstream-oauth2")]
133    UpstreamOAuth2,
134}
135
136/// Input for the user registration policy.
137#[derive(Serialize, Debug, JsonSchema)]
138#[serde(tag = "registration_method")]
139pub struct RegisterInput<'a> {
140    pub registration_method: RegistrationMethod,
141
142    pub username: &'a str,
143
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub email: Option<&'a str>,
146
147    pub requester: Requester,
148}
149
150/// Input for the client registration policy.
151#[derive(Serialize, Debug, JsonSchema)]
152#[serde(rename_all = "snake_case")]
153pub struct ClientRegistrationInput<'a> {
154    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
155    pub client_metadata: &'a VerifiedClientMetadata,
156    pub requester: Requester,
157}
158
159#[derive(Serialize, Debug, JsonSchema)]
160#[serde(rename_all = "snake_case")]
161pub enum GrantType {
162    AuthorizationCode,
163    ClientCredentials,
164    #[serde(rename = "urn:ietf:params:oauth:grant-type:device_code")]
165    DeviceCode,
166}
167
168/// Input for the authorization grant policy.
169#[derive(Serialize, Debug, JsonSchema)]
170#[serde(rename_all = "snake_case")]
171pub struct AuthorizationGrantInput<'a> {
172    #[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
173    pub user: Option<&'a User>,
174
175    /// How many sessions the user has.
176    /// Not populated if it's not a user logging in.
177    pub session_counts: Option<SessionCounts>,
178
179    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
180    pub client: &'a Client,
181
182    #[schemars(with = "String")]
183    pub scope: &'a Scope,
184
185    pub grant_type: GrantType,
186
187    pub requester: Requester,
188}
189
190/// Information about how many sessions the user has
191#[derive(Serialize, Debug, JsonSchema)]
192pub struct SessionCounts {
193    pub total: u64,
194
195    pub oauth2: u64,
196    pub compat: u64,
197    pub personal: u64,
198}
199
200/// Input for the email add policy.
201#[derive(Serialize, Debug, JsonSchema)]
202#[serde(rename_all = "snake_case")]
203pub struct EmailInput<'a> {
204    pub email: &'a str,
205
206    pub requester: Requester,
207}