mas_handlers/graphql/query/upstream_oauth.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_graphql::{
8    Context, ID, Object,
9    connection::{Connection, Edge, OpaqueCursor, query},
10};
11use mas_storage::{Pagination, RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderFilter};
12
13use crate::graphql::{
14    model::{
15        Cursor, NodeCursor, NodeType, PreloadedTotalCount, UpstreamOAuth2Link,
16        UpstreamOAuth2Provider,
17    },
18    state::ContextExt,
19};
20
/// GraphQL query fields for looking up upstream OAuth 2.0 links and providers.
///
/// This is a stateless unit struct; all resolvers live on its `#[Object]` impl.
/// NOTE(review): presumably merged into the root query type elsewhere — confirm.
#[derive(Default)]
pub struct UpstreamOAuthQuery;
23
24#[Object]
25impl UpstreamOAuthQuery {
26    /// Fetch an upstream OAuth 2.0 link by its ID.
27    pub async fn upstream_oauth2_link(
28        &self,
29        ctx: &Context<'_>,
30        id: ID,
31    ) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
32        let state = ctx.state();
33        let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
34        let requester = ctx.requester();
35
36        let mut repo = state.repository().await?;
37        let link = repo.upstream_oauth_link().lookup(id).await?;
38        repo.cancel().await?;
39
40        let Some(link) = link else {
41            return Ok(None);
42        };
43
44        if !requester.is_owner_or_admin(&link) {
45            return Ok(None);
46        }
47
48        Ok(Some(UpstreamOAuth2Link::new(link)))
49    }
50
51    /// Fetch an upstream OAuth 2.0 provider by its ID.
52    pub async fn upstream_oauth2_provider(
53        &self,
54        ctx: &Context<'_>,
55        id: ID,
56    ) -> Result<Option<UpstreamOAuth2Provider>, async_graphql::Error> {
57        let state = ctx.state();
58        let id = NodeType::UpstreamOAuth2Provider.extract_ulid(&id)?;
59
60        let mut repo = state.repository().await?;
61        let provider = repo.upstream_oauth_provider().lookup(id).await?;
62        repo.cancel().await?;
63
64        let Some(provider) = provider else {
65            return Ok(None);
66        };
67
68        // We only allow enabled providers to be fetched
69        if !provider.enabled() {
70            return Ok(None);
71        }
72
73        Ok(Some(UpstreamOAuth2Provider::new(provider)))
74    }
75
76    /// Get a list of upstream OAuth 2.0 providers.
77    async fn upstream_oauth2_providers(
78        &self,
79        ctx: &Context<'_>,
80
81        #[graphql(desc = "Returns the elements in the list that come after the cursor.")]
82        after: Option<String>,
83        #[graphql(desc = "Returns the elements in the list that come before the cursor.")]
84        before: Option<String>,
85        #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
86        #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
87    ) -> Result<Connection<Cursor, UpstreamOAuth2Provider, PreloadedTotalCount>, async_graphql::Error>
88    {
89        let state = ctx.state();
90        let mut repo = state.repository().await?;
91
92        query(
93            after,
94            before,
95            first,
96            last,
97            async |after, before, first, last| {
98                let after_id = after
99                    .map(|x: OpaqueCursor<NodeCursor>| {
100                        x.extract_for_type(NodeType::UpstreamOAuth2Provider)
101                    })
102                    .transpose()?;
103                let before_id = before
104                    .map(|x: OpaqueCursor<NodeCursor>| {
105                        x.extract_for_type(NodeType::UpstreamOAuth2Provider)
106                    })
107                    .transpose()?;
108                let pagination = Pagination::try_new(before_id, after_id, first, last)?;
109
110                // We only want enabled providers
111                // XXX: we may want to let admins see disabled providers
112                let filter = UpstreamOAuthProviderFilter::new().enabled_only();
113
114                let page = repo
115                    .upstream_oauth_provider()
116                    .list(filter, pagination)
117                    .await?;
118
119                // Preload the total count if requested
120                let count = if ctx.look_ahead().field("totalCount").exists() {
121                    Some(repo.upstream_oauth_provider().count(filter).await?)
122                } else {
123                    None
124                };
125
126                repo.cancel().await?;
127
128                let mut connection = Connection::with_additional_fields(
129                    page.has_previous_page,
130                    page.has_next_page,
131                    PreloadedTotalCount(count),
132                );
133                connection.edges.extend(page.edges.into_iter().map(|p| {
134                    Edge::new(
135                        OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)),
136                        UpstreamOAuth2Provider::new(p),
137                    )
138                }));
139
140                Ok::<_, async_graphql::Error>(connection)
141            },
142        )
143        .await
144    }
145}