mas_handlers/views/shared.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::str::FromStr as _;
8
9use anyhow::Context;
10use mas_router::{PostAuthAction, Route, UrlBuilder};
11use mas_storage::{
12    RepositoryAccess,
13    compat::CompatSsoLoginRepository,
14    oauth2::OAuth2AuthorizationGrantRepository,
15    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
16};
17use mas_templates::{PostAuthContext, PostAuthContextInner};
18use ruma_common::UserId;
19use serde::{Deserialize, Serialize};
20use tracing::warn;
21
22#[derive(Serialize, Deserialize, Default, Debug, Clone)]
23pub(crate) struct OptionalPostAuthAction {
24    #[serde(flatten)]
25    pub post_auth_action: Option<PostAuthAction>,
26}
27
28impl From<Option<PostAuthAction>> for OptionalPostAuthAction {
29    fn from(post_auth_action: Option<PostAuthAction>) -> Self {
30        Self { post_auth_action }
31    }
32}
33
34impl OptionalPostAuthAction {
35    pub fn go_next_or_default<T: Route>(
36        &self,
37        url_builder: &UrlBuilder,
38        default: &T,
39    ) -> axum::response::Redirect {
40        self.post_auth_action.as_ref().map_or_else(
41            || url_builder.redirect(default),
42            |action| action.go_next(url_builder),
43        )
44    }
45
46    pub fn go_next(&self, url_builder: &UrlBuilder) -> axum::response::Redirect {
47        self.go_next_or_default(url_builder, &mas_router::Index)
48    }
49
50    pub async fn load_context<'a>(
51        &'a self,
52        repo: &'a mut impl RepositoryAccess,
53    ) -> anyhow::Result<Option<PostAuthContext>> {
54        let Some(action) = self.post_auth_action.clone() else {
55            return Ok(None);
56        };
57        let ctx = match action {
58            PostAuthAction::ContinueAuthorizationGrant { id } => {
59                let Some(grant) = repo.oauth2_authorization_grant().lookup(id).await? else {
60                    warn!(%id, "Failed to load authorization grant, it was likely deleted or is an invalid ID");
61                    return Ok(None);
62                };
63                let grant = Box::new(grant);
64                PostAuthContextInner::ContinueAuthorizationGrant { grant }
65            }
66
67            PostAuthAction::ContinueDeviceCodeGrant { id } => {
68                let Some(grant) = repo.oauth2_device_code_grant().lookup(id).await? else {
69                    warn!(%id, "Failed to load device code grant, it was likely deleted or is an invalid ID");
70                    return Ok(None);
71                };
72                let grant = Box::new(grant);
73                PostAuthContextInner::ContinueDeviceCodeGrant { grant }
74            }
75
76            PostAuthAction::ContinueCompatSsoLogin { id } => {
77                let Some(login) = repo.compat_sso_login().lookup(id).await? else {
78                    warn!(%id, "Failed to load compat SSO login, it was likely deleted or is an invalid ID");
79                    return Ok(None);
80                };
81                let login = Box::new(login);
82                PostAuthContextInner::ContinueCompatSsoLogin { login }
83            }
84
85            PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
86
87            PostAuthAction::LinkUpstream { id } => {
88                let Some(link) = repo.upstream_oauth_link().lookup(id).await? else {
89                    warn!(%id, "Failed to load upstream OAuth 2.0 link, it was likely deleted or is an invalid ID");
90                    return Ok(None);
91                };
92
93                let provider = repo
94                    .upstream_oauth_provider()
95                    .lookup(link.provider_id)
96                    .await?
97                    .context("Failed to load upstream OAuth 2.0 provider")?;
98
99                let provider = Box::new(provider);
100                let link = Box::new(link);
101                PostAuthContextInner::LinkUpstream { provider, link }
102            }
103
104            PostAuthAction::ManageAccount { .. } => PostAuthContextInner::ManageAccount,
105        };
106
107        Ok(Some(PostAuthContext {
108            params: action.clone(),
109            ctx,
110        }))
111    }
112}
113
114pub enum LoginHint<'a> {
115    Mxid(&'a UserId),
116    Email(lettre::Address),
117    None,
118}
119
120#[derive(Debug, Deserialize)]
121pub(crate) struct QueryLoginHint {
122    login_hint: Option<String>,
123}
124
125impl QueryLoginHint {
126    /// Parse a `login_hint`
127    ///
128    /// Returns `LoginHint::MXID` for valid mxid 'mxid:@john.doe:example.com'
129    ///
130    /// Returns `LoginHint::Email` for valid email 'john.doe@example.com'
131    ///
132    /// Otherwise returns `LoginHint::None`
133    pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
134        let Some(login_hint) = &self.login_hint else {
135            return LoginHint::None;
136        };
137
138        if let Some(value) = login_hint.strip_prefix("mxid:")
139            && let Ok(mxid) = <&UserId>::try_from(value)
140            && mxid.server_name() == homeserver
141        {
142            LoginHint::Mxid(mxid)
143        } else if let Ok(email) = lettre::Address::from_str(login_hint) {
144            LoginHint::Email(email)
145        } else {
146            LoginHint::None
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_login_hint() {
        let query_login_hint = QueryLoginHint { login_hint: None };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("mxid:@example-user:example.com")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::Mxid(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn valid_login_hint_with_email() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("example@user")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::Email(email) if email.to_string() == "example@user"));
    }

    #[test]
    fn invalid_login_hint() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("example-user")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("mxid:@example-user:matrix.org")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("something:anything")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn mxid_login_hint_without_sigil() {
        // A Matrix user ID must start with `@`; without it the hint is neither
        // an MXID nor an email address.
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::from("mxid:example-user:example.com")),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn empty_login_hint() {
        let query_login_hint = QueryLoginHint {
            login_hint: Some(String::new()),
        };

        let hint = query_login_hint.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}