mas_handlers/views/
shared.rs1use std::str::FromStr as _;
8
9use anyhow::Context;
10use mas_router::{PostAuthAction, Route, UrlBuilder};
11use mas_storage::{
12 RepositoryAccess,
13 compat::CompatSsoLoginRepository,
14 oauth2::OAuth2AuthorizationGrantRepository,
15 upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
16};
17use mas_templates::{PostAuthContext, PostAuthContextInner};
18use ruma_common::UserId;
19use serde::{Deserialize, Serialize};
20use tracing::warn;
21
22#[derive(Serialize, Deserialize, Default, Debug, Clone)]
23pub(crate) struct OptionalPostAuthAction {
24 #[serde(flatten)]
25 pub post_auth_action: Option<PostAuthAction>,
26}
27
28impl From<Option<PostAuthAction>> for OptionalPostAuthAction {
29 fn from(post_auth_action: Option<PostAuthAction>) -> Self {
30 Self { post_auth_action }
31 }
32}
33
34impl OptionalPostAuthAction {
35 pub fn go_next_or_default<T: Route>(
36 &self,
37 url_builder: &UrlBuilder,
38 default: &T,
39 ) -> axum::response::Redirect {
40 self.post_auth_action.as_ref().map_or_else(
41 || url_builder.redirect(default),
42 |action| action.go_next(url_builder),
43 )
44 }
45
46 pub fn go_next(&self, url_builder: &UrlBuilder) -> axum::response::Redirect {
47 self.go_next_or_default(url_builder, &mas_router::Index)
48 }
49
50 pub async fn load_context<'a>(
51 &'a self,
52 repo: &'a mut impl RepositoryAccess,
53 ) -> anyhow::Result<Option<PostAuthContext>> {
54 let Some(action) = self.post_auth_action.clone() else {
55 return Ok(None);
56 };
57 let ctx = match action {
58 PostAuthAction::ContinueAuthorizationGrant { id } => {
59 let Some(grant) = repo.oauth2_authorization_grant().lookup(id).await? else {
60 warn!(%id, "Failed to load authorization grant, it was likely deleted or is an invalid ID");
61 return Ok(None);
62 };
63 let grant = Box::new(grant);
64 PostAuthContextInner::ContinueAuthorizationGrant { grant }
65 }
66
67 PostAuthAction::ContinueDeviceCodeGrant { id } => {
68 let Some(grant) = repo.oauth2_device_code_grant().lookup(id).await? else {
69 warn!(%id, "Failed to load device code grant, it was likely deleted or is an invalid ID");
70 return Ok(None);
71 };
72 let grant = Box::new(grant);
73 PostAuthContextInner::ContinueDeviceCodeGrant { grant }
74 }
75
76 PostAuthAction::ContinueCompatSsoLogin { id } => {
77 let Some(login) = repo.compat_sso_login().lookup(id).await? else {
78 warn!(%id, "Failed to load compat SSO login, it was likely deleted or is an invalid ID");
79 return Ok(None);
80 };
81 let login = Box::new(login);
82 PostAuthContextInner::ContinueCompatSsoLogin { login }
83 }
84
85 PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
86
87 PostAuthAction::LinkUpstream { id } => {
88 let Some(link) = repo.upstream_oauth_link().lookup(id).await? else {
89 warn!(%id, "Failed to load upstream OAuth 2.0 link, it was likely deleted or is an invalid ID");
90 return Ok(None);
91 };
92
93 let provider = repo
94 .upstream_oauth_provider()
95 .lookup(link.provider_id)
96 .await?
97 .context("Failed to load upstream OAuth 2.0 provider")?;
98
99 let provider = Box::new(provider);
100 let link = Box::new(link);
101 PostAuthContextInner::LinkUpstream { provider, link }
102 }
103
104 PostAuthAction::ManageAccount { .. } => PostAuthContextInner::ManageAccount,
105 };
106
107 Ok(Some(PostAuthContext {
108 params: action.clone(),
109 ctx,
110 }))
111 }
112}
113
114pub enum LoginHint<'a> {
115 Mxid(&'a UserId),
116 Email(lettre::Address),
117 None,
118}
119
120#[derive(Debug, Deserialize)]
121pub(crate) struct QueryLoginHint {
122 login_hint: Option<String>,
123}
124
125impl QueryLoginHint {
126 pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint<'_> {
134 let Some(login_hint) = &self.login_hint else {
135 return LoginHint::None;
136 };
137
138 if let Some(value) = login_hint.strip_prefix("mxid:")
139 && let Ok(mxid) = <&UserId>::try_from(value)
140 && mxid.server_name() == homeserver
141 {
142 LoginHint::Mxid(mxid)
143 } else if let Ok(email) = lettre::Address::from_str(login_hint) {
144 LoginHint::Email(email)
145 } else {
146 LoginHint::None
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Homeserver name every test parses hints against.
    const HOMESERVER: &str = "example.com";

    /// Build a `QueryLoginHint` holding the given raw hint.
    fn query_with(raw: Option<&str>) -> QueryLoginHint {
        QueryLoginHint {
            login_hint: raw.map(String::from),
        }
    }

    #[test]
    fn no_login_hint() {
        let query = query_with(None);

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let query = query_with(Some("mxid:@example-user:example.com"));

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::Mxid(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn valid_login_hint_with_email() {
        let query = query_with(Some("example@user"));

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::Email(email) if email.to_string() == "example@user"));
    }

    #[test]
    fn invalid_login_hint() {
        let query = query_with(Some("example-user"));

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let query = query_with(Some("mxid:@example-user:matrix.org"));

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let query = query_with(Some("something:anything"));

        let parsed = query.parse_login_hint(HOMESERVER);

        assert!(matches!(parsed, LoginHint::None));
    }
}