Skip to main content

mas_policy/
model.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! Input and output types for policy evaluation.
//!
//! Keeping these types in one place makes it possible to generate a JSON schema
//! for each input type, which Open Policy Agent can then use for type checking.
12use std::net::IpAddr;
13
14use mas_data_model::{Client, User};
15use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18
/// Violation variants identified by a well-known policy code (under the `code`
/// key).
///
/// The enum is internally tagged: the variant name, converted to kebab-case,
/// is stored under `code`, and any fields of the variant sit next to it. For
/// example, `TooManySessions { need_to_remove: 2 }` corresponds to
/// `{"code": "too-many-sessions", "need_to_remove": 2}`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
#[serde(tag = "code", rename_all = "kebab-case")]
pub enum ViolationVariant {
    /// The username is too short.
    UsernameTooShort,

    /// The username is too long.
    UsernameTooLong,

    /// The username contains invalid characters.
    UsernameInvalidChars,

    /// The username contains only numeric characters.
    UsernameAllNumeric,

    /// The username is banned.
    UsernameBanned,

    /// The username is not allowed.
    UsernameNotAllowed,

    /// The email domain is not allowed.
    EmailDomainNotAllowed,

    /// The email domain is banned.
    EmailDomainBanned,

    /// The email address is not allowed.
    EmailNotAllowed,

    /// The email address is banned.
    EmailBanned,

    /// The user has reached their session limit.
    TooManySessions {
        /// How many devices need to be removed to make room for the new session
        need_to_remove: u32,
    },
}
60
61impl ViolationVariant {
62    /// Returns the code as a string
63    #[must_use]
64    pub fn as_str(&self) -> &'static str {
65        match self {
66            Self::UsernameTooShort => "username-too-short",
67            Self::UsernameTooLong => "username-too-long",
68            Self::UsernameInvalidChars => "username-invalid-chars",
69            Self::UsernameAllNumeric => "username-all-numeric",
70            Self::UsernameBanned => "username-banned",
71            Self::UsernameNotAllowed => "username-not-allowed",
72            Self::EmailDomainNotAllowed => "email-domain-not-allowed",
73            Self::EmailDomainBanned => "email-domain-banned",
74            Self::EmailNotAllowed => "email-not-allowed",
75            Self::EmailBanned => "email-banned",
76            Self::TooManySessions { .. } => "too-many-sessions",
77        }
78    }
79}
80
/// A single violation of a policy.
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
pub struct Violation {
    /// Message describing the violation. `EvaluationResult`'s `Display`
    /// output is these messages joined with `", "`.
    pub msg: String,
    // NOTE(review): presumably a URI the user should be sent to in order to
    // resolve the violation — confirm against the policies that set it.
    pub redirect_uri: Option<String>,
    // NOTE(review): presumably the name of the input field the violation
    // relates to — confirm against the policies that set it.
    pub field: Option<String>,

    // We flatten because policies emit `code` as a top-level field, next to
    // `msg`, rather than as a nested object.
    //
    // As a consequence, any extra fields of the variant (e.g. `need_to_remove`)
    // also live at this top level, which is fine (the choice is arbitrary).
    #[serde(flatten)]
    pub variant: Option<ViolationVariant>,
}
95
/// The result of a policy evaluation.
///
/// Deserialized from an object whose `result` key holds the list of
/// violations; an empty list means the evaluation succeeded (see `valid`).
#[derive(Deserialize, Debug)]
pub struct EvaluationResult {
    /// Every violation reported by the policy, read from the `result` key.
    #[serde(rename = "result")]
    pub violations: Vec<Violation>,
}
102
103impl std::fmt::Display for EvaluationResult {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        let mut first = true;
106        for violation in &self.violations {
107            if first {
108                first = false;
109            } else {
110                write!(f, ", ")?;
111            }
112            write!(f, "{}", violation.msg)?;
113        }
114        Ok(())
115    }
116}
117
118impl EvaluationResult {
119    /// Returns true if the policy evaluation was successful.
120    #[must_use]
121    pub fn valid(&self) -> bool {
122        self.violations.is_empty()
123    }
124}
125
/// Identity of the requester.
///
/// Embedded in the policy input types as their `requester` field. The derived
/// `Default` leaves both the IP address and the user agent as `None`.
#[derive(Serialize, Debug, Default, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct Requester {
    /// IP address of the entity making the request
    pub ip_address: Option<IpAddr>,

    /// User agent of the entity making the request
    pub user_agent: Option<String>,
}
136
/// How a user registration is being performed.
///
/// Serialized as a plain string: `"password"` or `"upstream-oauth2"`. The
/// explicit renames are needed because kebab-casing `UpstreamOAuth2`
/// mechanically would not produce `upstream-oauth2`.
#[derive(Serialize, Debug, JsonSchema)]
pub enum RegistrationMethod {
    /// Password-based registration.
    #[serde(rename = "password")]
    Password,

    /// Registration via an upstream OAuth 2.0 provider.
    #[serde(rename = "upstream-oauth2")]
    UpstreamOAuth2,
}
145
146/// Input for the user registration policy.
147#[derive(Serialize, Debug, JsonSchema)]
148#[serde(tag = "registration_method")]
149pub struct RegisterInput<'a> {
150    pub registration_method: RegistrationMethod,
151
152    pub username: &'a str,
153
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub email: Option<&'a str>,
156
157    pub requester: Requester,
158}
159
/// Input for the client registration policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct ClientRegistrationInput<'a> {
    /// Metadata of the client being registered.
    // The JSON schema describes this only as an arbitrary JSON object instead
    // of deriving it from `VerifiedClientMetadata`.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub client_metadata: &'a VerifiedClientMetadata,
    /// Identity of the entity making the request.
    pub requester: Requester,
}
168
/// The grant type passed to the authorization grant policy.
///
/// Serialized as a string: `"authorization_code"`, `"client_credentials"`, or
/// `"urn:ietf:params:oauth:grant-type:device_code"`.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// The authorization code grant.
    AuthorizationCode,
    /// The client credentials grant.
    ClientCredentials,
    /// The device code grant; its wire name is the full URN.
    #[serde(rename = "urn:ietf:params:oauth:grant-type:device_code")]
    DeviceCode,
}
177
/// Input for the authorization grant policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct AuthorizationGrantInput<'a> {
    /// The user the grant is for, if any.
    // NOTE(review): presumably `None` for grants without a user, such as the
    // client credentials grant — confirm at the call sites.
    // The JSON schema describes this only as an optional arbitrary object.
    #[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
    pub user: Option<&'a User>,

    /// How many sessions the user has.
    /// Not populated if it's not a user logging in.
    pub session_counts: Option<SessionCounts>,

    /// The client the grant is for.
    // The JSON schema describes this only as an arbitrary object.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub client: &'a Client,

    /// The requested scope.
    // Described as a plain string in the JSON schema.
    #[schemars(with = "String")]
    pub scope: &'a Scope,

    /// The grant type being used.
    pub grant_type: GrantType,

    /// Identity of the entity making the request.
    pub requester: Requester,
}
199
/// Input for the compatibility login policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct CompatLoginInput<'a> {
    /// The user logging in.
    // The JSON schema describes this only as an arbitrary object.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub user: &'a User,

    /// How many sessions the user has.
    pub session_counts: SessionCounts,

    /// Whether a session will be replaced by this login
    pub session_replaced: bool,

    /// What type of login is being performed.
    /// This also determines whether the login is interactive.
    pub login: CompatLogin,

    /// Identity of the entity making the request.
    pub requester: Requester,
}
219
/// The kind of compatibility login being performed.
///
/// Internally tagged under `type` with the Matrix login type string. For
/// example, `Sso { .. }` serializes as
/// `{"type": "m.login.sso", "redirect_uri": "..."}` and `Token` as
/// `{"type": "m.login.token"}`.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(tag = "type")]
pub enum CompatLogin {
    // NOTE(review): `redirect_uri` is presumably where the client is sent once
    // the SSO flow completes — confirm at the call site.
    /// Used as the interactive part of SSO login.
    #[serde(rename = "m.login.sso")]
    Sso { redirect_uri: String },

    /// Used as the final (non-interactive) stage of SSO login.
    #[serde(rename = "m.login.token")]
    Token,

    /// Non-interactive password-over-the-API login.
    #[serde(rename = "m.login.password")]
    Password,
}
235
/// Information about how many sessions the user has
#[derive(Serialize, Debug, JsonSchema)]
pub struct SessionCounts {
    /// Total number of sessions.
    // NOTE(review): presumably `oauth2 + compat + personal` — confirm where
    // this struct is built.
    pub total: u64,

    /// Number of OAuth 2.0 sessions.
    pub oauth2: u64,
    /// Number of compatibility sessions.
    pub compat: u64,
    /// Number of personal sessions.
    pub personal: u64,
}
245
/// Input for the email add policy.
///
/// Carries the email address being added together with the requester.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct EmailInput<'a> {
    /// The email address being added.
    pub email: &'a str,

    /// Identity of the entity making the request.
    pub requester: Requester,
}