mas_handlers/graphql/mutations/
oauth2_session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use anyhow::Context as _;
8use async_graphql::{Context, Description, Enum, ID, InputObject, Object};
9use chrono::Duration;
10use mas_data_model::{Device, TokenType};
11use mas_storage::{
12    RepositoryAccess,
13    oauth2::{
14        OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository,
15        OAuth2SessionRepository,
16    },
17    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
18    user::UserRepository,
19};
20use oauth2_types::scope::Scope;
21
22use crate::graphql::{
23    model::{NodeType, OAuth2Session},
24    state::ContextExt,
25};
26
// Field-less handle on which the OAuth 2.0 session mutations of this file are
// implemented. The private unit field prevents struct-literal construction
// from outside this module, so other modules obtain it through `Default`.
#[derive(Default)]
pub struct OAuth2SessionMutations {
    _private: (),
}
31
32/// The input of the `createOauth2Session` mutation.
33#[derive(InputObject)]
34pub struct CreateOAuth2SessionInput {
35    /// The scope of the session
36    scope: String,
37
38    /// The ID of the user for which to create the session
39    user_id: ID,
40
41    /// Whether the session should issue a never-expiring access token
42    permanent: Option<bool>,
43}
44
45/// The payload of the `createOauth2Session` mutation.
46#[derive(Description)]
47pub struct CreateOAuth2SessionPayload {
48    access_token: String,
49    refresh_token: Option<String>,
50    session: mas_data_model::Session,
51}
52
53#[Object(use_type_description)]
54impl CreateOAuth2SessionPayload {
55    /// Access token for this session
56    pub async fn access_token(&self) -> &str {
57        &self.access_token
58    }
59
60    /// Refresh token for this session, if it is not a permanent session
61    pub async fn refresh_token(&self) -> Option<&str> {
62        self.refresh_token.as_deref()
63    }
64
65    /// The OAuth 2.0 session which was just created
66    pub async fn oauth2_session(&self) -> OAuth2Session {
67        OAuth2Session(self.session.clone())
68    }
69}
70
71/// The input of the `endOauth2Session` mutation.
72#[derive(InputObject)]
73pub struct EndOAuth2SessionInput {
74    /// The ID of the session to end.
75    oauth2_session_id: ID,
76}
77
78/// The payload of the `endOauth2Session` mutation.
79pub enum EndOAuth2SessionPayload {
80    NotFound,
81    Ended(mas_data_model::Session),
82}
83
84/// The status of the `endOauth2Session` mutation.
85#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
86enum EndOAuth2SessionStatus {
87    /// The session was ended.
88    Ended,
89
90    /// The session was not found.
91    NotFound,
92}
93
94#[Object]
95impl EndOAuth2SessionPayload {
96    /// The status of the mutation.
97    async fn status(&self) -> EndOAuth2SessionStatus {
98        match self {
99            Self::Ended(_) => EndOAuth2SessionStatus::Ended,
100            Self::NotFound => EndOAuth2SessionStatus::NotFound,
101        }
102    }
103
104    /// Returns the ended session.
105    async fn oauth2_session(&self) -> Option<OAuth2Session> {
106        match self {
107            Self::Ended(session) => Some(OAuth2Session(session.clone())),
108            Self::NotFound => None,
109        }
110    }
111}
112
113#[Object]
114impl OAuth2SessionMutations {
115    /// Create a new arbitrary OAuth 2.0 Session.
116    ///
117    /// Only available for administrators.
118    async fn create_oauth2_session(
119        &self,
120        ctx: &Context<'_>,
121        input: CreateOAuth2SessionInput,
122    ) -> Result<CreateOAuth2SessionPayload, async_graphql::Error> {
123        let state = ctx.state();
124        let homeserver = state.homeserver_connection();
125        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
126        let scope: Scope = input.scope.parse().context("Invalid scope")?;
127        let permanent = input.permanent.unwrap_or(false);
128        let requester = ctx.requester();
129
130        if !requester.is_admin() {
131            return Err(async_graphql::Error::new("Unauthorized"));
132        }
133
134        let session = requester
135            .oauth2_session()
136            .context("Requester should be a OAuth 2.0 client")?;
137
138        let mut repo = state.repository().await?;
139        let clock = state.clock();
140        let mut rng = state.rng();
141
142        let client = repo
143            .oauth2_client()
144            .lookup(session.client_id)
145            .await?
146            .context("Client not found")?;
147
148        let user = repo
149            .user()
150            .lookup(user_id)
151            .await?
152            .context("User not found")?;
153
154        // Generate a new access token
155        let access_token = TokenType::AccessToken.generate(&mut rng);
156
157        // Create the OAuth 2.0 Session
158        let session = repo
159            .oauth2_session()
160            .add(&mut rng, &clock, &client, Some(&user), None, scope)
161            .await?;
162
163        // Lock the user sync to make sure we don't get into a race condition
164        repo.user().acquire_lock_for_sync(&user).await?;
165
166        // Look for devices to provision
167        let mxid = homeserver.mxid(&user.username);
168        for scope in &*session.scope {
169            if let Some(device) = Device::from_scope_token(scope) {
170                homeserver
171                    .create_device(&mxid, device.as_str())
172                    .await
173                    .context("Failed to provision device")?;
174            }
175        }
176
177        let ttl = if permanent {
178            None
179        } else {
180            Some(Duration::microseconds(5 * 60 * 1000 * 1000))
181        };
182        let access_token = repo
183            .oauth2_access_token()
184            .add(&mut rng, &clock, &session, access_token, ttl)
185            .await?;
186
187        let refresh_token = if permanent {
188            None
189        } else {
190            let refresh_token = TokenType::RefreshToken.generate(&mut rng);
191
192            let refresh_token = repo
193                .oauth2_refresh_token()
194                .add(&mut rng, &clock, &session, &access_token, refresh_token)
195                .await?;
196
197            Some(refresh_token)
198        };
199
200        repo.save().await?;
201
202        Ok(CreateOAuth2SessionPayload {
203            session,
204            access_token: access_token.access_token,
205            refresh_token: refresh_token.map(|t| t.refresh_token),
206        })
207    }
208
209    async fn end_oauth2_session(
210        &self,
211        ctx: &Context<'_>,
212        input: EndOAuth2SessionInput,
213    ) -> Result<EndOAuth2SessionPayload, async_graphql::Error> {
214        let state = ctx.state();
215        let oauth2_session_id = NodeType::OAuth2Session.extract_ulid(&input.oauth2_session_id)?;
216        let requester = ctx.requester();
217
218        let mut repo = state.repository().await?;
219        let clock = state.clock();
220        let mut rng = state.rng();
221
222        let session = repo.oauth2_session().lookup(oauth2_session_id).await?;
223        let Some(session) = session else {
224            return Ok(EndOAuth2SessionPayload::NotFound);
225        };
226
227        if !requester.is_owner_or_admin(&session) {
228            return Ok(EndOAuth2SessionPayload::NotFound);
229        }
230
231        if let Some(user_id) = session.user_id {
232            let user = repo
233                .user()
234                .lookup(user_id)
235                .await?
236                .context("Could not load user")?;
237
238            // Schedule a job to sync the devices of the user with the homeserver
239            repo.queue_job()
240                .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
241                .await?;
242        }
243
244        let session = repo.oauth2_session().finish(&clock, session).await?;
245
246        repo.save().await?;
247
248        Ok(EndOAuth2SessionPayload::Ended(session))
249    }
250}