mas_handlers/views/
shared.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use anyhow::Context;
8use mas_router::{PostAuthAction, Route, UrlBuilder};
9use mas_storage::{
10    RepositoryAccess,
11    compat::CompatSsoLoginRepository,
12    oauth2::OAuth2AuthorizationGrantRepository,
13    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository},
14};
15use mas_templates::{PostAuthContext, PostAuthContextInner};
16use serde::{Deserialize, Serialize};
17
18#[derive(Serialize, Deserialize, Default, Debug, Clone)]
19pub(crate) struct OptionalPostAuthAction {
20    #[serde(flatten)]
21    pub post_auth_action: Option<PostAuthAction>,
22}
23
24impl From<Option<PostAuthAction>> for OptionalPostAuthAction {
25    fn from(post_auth_action: Option<PostAuthAction>) -> Self {
26        Self { post_auth_action }
27    }
28}
29
30impl OptionalPostAuthAction {
31    pub fn go_next_or_default<T: Route>(
32        &self,
33        url_builder: &UrlBuilder,
34        default: &T,
35    ) -> axum::response::Redirect {
36        self.post_auth_action.as_ref().map_or_else(
37            || url_builder.redirect(default),
38            |action| action.go_next(url_builder),
39        )
40    }
41
42    pub fn go_next(&self, url_builder: &UrlBuilder) -> axum::response::Redirect {
43        self.go_next_or_default(url_builder, &mas_router::Index)
44    }
45
46    pub async fn load_context<'a>(
47        &'a self,
48        repo: &'a mut impl RepositoryAccess,
49    ) -> anyhow::Result<Option<PostAuthContext>> {
50        let Some(action) = self.post_auth_action.clone() else {
51            return Ok(None);
52        };
53        let ctx = match action {
54            PostAuthAction::ContinueAuthorizationGrant { id } => {
55                let grant = repo
56                    .oauth2_authorization_grant()
57                    .lookup(id)
58                    .await?
59                    .context("Failed to load authorization grant")?;
60                let grant = Box::new(grant);
61                PostAuthContextInner::ContinueAuthorizationGrant { grant }
62            }
63
64            PostAuthAction::ContinueDeviceCodeGrant { id } => {
65                let grant = repo
66                    .oauth2_device_code_grant()
67                    .lookup(id)
68                    .await?
69                    .context("Failed to load device code grant")?;
70                let grant = Box::new(grant);
71                PostAuthContextInner::ContinueDeviceCodeGrant { grant }
72            }
73
74            PostAuthAction::ContinueCompatSsoLogin { id } => {
75                let login = repo
76                    .compat_sso_login()
77                    .lookup(id)
78                    .await?
79                    .context("Failed to load compat SSO login")?;
80                let login = Box::new(login);
81                PostAuthContextInner::ContinueCompatSsoLogin { login }
82            }
83
84            PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
85
86            PostAuthAction::LinkUpstream { id } => {
87                let link = repo
88                    .upstream_oauth_link()
89                    .lookup(id)
90                    .await?
91                    .context("Failed to load upstream OAuth 2.0 link")?;
92
93                let provider = repo
94                    .upstream_oauth_provider()
95                    .lookup(link.provider_id)
96                    .await?
97                    .context("Failed to load upstream OAuth 2.0 provider")?;
98
99                let provider = Box::new(provider);
100                let link = Box::new(link);
101                PostAuthContextInner::LinkUpstream { provider, link }
102            }
103
104            PostAuthAction::ManageAccount { .. } => PostAuthContextInner::ManageAccount,
105        };
106
107        Ok(Some(PostAuthContext {
108            params: action.clone(),
109            ctx,
110        }))
111    }
112}