mas_handlers/admin/v1/upstream_oauth_links/add.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use aide::{NoApi, OperationIo, transform::TransformOperation};
7use axum::{Json, response::IntoResponse};
8use hyper::StatusCode;
9use mas_storage::BoxRng;
10use schemars::JsonSchema;
11use serde::Deserialize;
12use ulid::Ulid;
13
14use crate::{
15    admin::{
16        call_context::CallContext,
17        model::{Resource, UpstreamOAuthLink},
18        response::{ErrorResponse, SingleResponse},
19    },
20    impl_from_error_for_route,
21};
22
23#[derive(Debug, thiserror::Error, OperationIo)]
24#[aide(output_with = "Json<ErrorResponse>")]
25pub enum RouteError {
26    #[error(transparent)]
27    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
28
29    #[error("Upstream Oauth 2.0 Provider ID {0} with subject {1} is already linked to a user")]
30    LinkAlreadyExists(Ulid, String),
31
32    #[error("User ID {0} not found")]
33    UserNotFound(Ulid),
34
35    #[error("Upstream OAuth 2.0 Provider ID {0} not found")]
36    ProviderNotFound(Ulid),
37}
38
39impl_from_error_for_route!(mas_storage::RepositoryError);
40
41impl IntoResponse for RouteError {
42    fn into_response(self) -> axum::response::Response {
43        let error = ErrorResponse::from_error(&self);
44        let status = match self {
45            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
46            Self::LinkAlreadyExists(_, _) => StatusCode::CONFLICT,
47            Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
48        };
49        (status, Json(error)).into_response()
50    }
51}
52
53/// # JSON payload for the `POST /api/admin/v1/upstream-oauth-links`
54#[derive(Deserialize, JsonSchema)]
55#[serde(rename = "AddUpstreamOauthLinkRequest")]
56pub struct Request {
57    /// The ID of the user to which the link should be added.
58    #[schemars(with = "crate::admin::schema::Ulid")]
59    user_id: Ulid,
60
61    /// The ID of the upstream provider to which the link is for.
62    #[schemars(with = "crate::admin::schema::Ulid")]
63    provider_id: Ulid,
64
65    /// The subject (sub) claim of the user on the provider.
66    subject: String,
67
68    /// A human readable account name.
69    human_account_name: Option<String>,
70}
71
72pub fn doc(operation: TransformOperation) -> TransformOperation {
73    operation
74        .id("addUpstreamOAuthLink")
75        .summary("Add an upstream OAuth 2.0 link")
76        .tag("upstream-oauth-link")
77        .response_with::<200, Json<SingleResponse<UpstreamOAuthLink>>, _>(|t| {
78            let [sample, ..] = UpstreamOAuthLink::samples();
79            let response = SingleResponse::new_canonical(sample);
80            t.description("An existing Upstream OAuth 2.0 link was associated to a user")
81                .example(response)
82        })
83        .response_with::<201, Json<SingleResponse<UpstreamOAuthLink>>, _>(|t| {
84            let [sample, ..] = UpstreamOAuthLink::samples();
85            let response = SingleResponse::new_canonical(sample);
86            t.description("A new Upstream OAuth 2.0 link was created")
87                .example(response)
88        })
89        .response_with::<409, RouteError, _>(|t| {
90            let [provider_sample, ..] = UpstreamOAuthLink::samples();
91            let response = ErrorResponse::from_error(&RouteError::LinkAlreadyExists(
92                provider_sample.id(),
93                String::from("subject1"),
94            ));
95            t.description("The subject from the provider is already linked to another user")
96                .example(response)
97        })
98        .response_with::<404, RouteError, _>(|t| {
99            let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
100            t.description("User or provider was not found")
101                .example(response)
102        })
103}
104
105#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all, err)]
106pub async fn handler(
107    CallContext {
108        mut repo, clock, ..
109    }: CallContext,
110    NoApi(mut rng): NoApi<BoxRng>,
111    Json(params): Json<Request>,
112) -> Result<(StatusCode, Json<SingleResponse<UpstreamOAuthLink>>), RouteError> {
113    // Find the user
114    let user = repo
115        .user()
116        .lookup(params.user_id)
117        .await?
118        .ok_or(RouteError::UserNotFound(params.user_id))?;
119
120    // Find the provider
121    let provider = repo
122        .upstream_oauth_provider()
123        .lookup(params.provider_id)
124        .await?
125        .ok_or(RouteError::ProviderNotFound(params.provider_id))?;
126
127    let maybe_link = repo
128        .upstream_oauth_link()
129        .find_by_subject(&provider, &params.subject)
130        .await?;
131    if let Some(mut link) = maybe_link {
132        if link.user_id.is_some() {
133            return Err(RouteError::LinkAlreadyExists(
134                link.provider_id,
135                link.subject,
136            ));
137        }
138
139        repo.upstream_oauth_link()
140            .associate_to_user(&link, &user)
141            .await?;
142        link.user_id = Some(user.id);
143
144        repo.save().await?;
145
146        return Ok((
147            StatusCode::OK,
148            Json(SingleResponse::new_canonical(link.into())),
149        ));
150    }
151
152    let mut link = repo
153        .upstream_oauth_link()
154        .add(
155            &mut rng,
156            &clock,
157            &provider,
158            params.subject,
159            params.human_account_name,
160        )
161        .await?;
162
163    repo.upstream_oauth_link()
164        .associate_to_user(&link, &user)
165        .await?;
166    link.user_id = Some(user.id);
167
168    repo.save().await?;
169
170    Ok((
171        StatusCode::CREATED,
172        Json(SingleResponse::new_canonical(link.into())),
173    ))
174}
175
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use insta::assert_json_snapshot;
    use sqlx::PgPool;
    use ulid::Ulid;

    use super::super::test_utils;
    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    // A subject with no existing link gets a freshly created one (201).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        // Fixtures: one user and one provider, nothing linked yet.
        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let upstream = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        let req = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": upstream.id,
                "subject": "subject1"
            }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::CREATED);
        let payload: serde_json::Value = res.json();
        assert_json_snapshot!(payload, @r###"
        {
          "data": {
            "type": "upstream-oauth-link",
            "id": "01FSHN9AG07HNEZXNQM2KNBNF6",
            "attributes": {
              "created_at": "2022-01-16T14:40:00Z",
              "provider_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
              "subject": "subject1",
              "user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
              "human_account_name": null
            },
            "links": {
              "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
            }
          },
          "links": {
            "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
          }
        }
        "###);
    }

    // An existing link with no user is attached to the requested user (200).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_association(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let upstream = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();

        // An unfinished link: known subject, but not yet attached to any user.
        repo.upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &upstream,
                String::from("subject1"),
                None,
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        let req = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": upstream.id,
                "subject": "subject1"
            }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::OK);
        let payload: serde_json::Value = res.json();
        assert_json_snapshot!(payload, @r###"
        {
          "data": {
            "type": "upstream-oauth-link",
            "id": "01FSHN9AG09NMZYX8MFYH578R9",
            "attributes": {
              "created_at": "2022-01-16T14:40:00Z",
              "provider_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
              "subject": "subject1",
              "user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
              "human_account_name": null
            },
            "links": {
              "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG09NMZYX8MFYH578R9"
            }
          },
          "links": {
            "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG09NMZYX8MFYH578R9"
          }
        }
        "###);
    }

    // A subject already attached to alice cannot be claimed for bob (409).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_link_already_exists(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        // Creation order matters: the snapshot below depends on the generated IDs.
        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let bob = repo
            .user()
            .add(&mut rng, &state.clock, "bob".to_owned())
            .await
            .unwrap();
        let upstream = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();
        let existing = repo
            .upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &upstream,
                String::from("subject1"),
                None,
            )
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&existing, &alice)
            .await
            .unwrap();
        repo.save().await.unwrap();

        let req = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": bob.id,
                "provider_id": upstream.id,
                "subject": "subject1"
            }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::CONFLICT);
        let payload: serde_json::Value = res.json();
        assert_json_snapshot!(payload, @r###"
        {
          "errors": [
            {
              "title": "Upstream Oauth 2.0 Provider ID 01FSHN9AG09NMZYX8MFYH578R9 with subject subject1 is already linked to a user"
            }
          ]
        }
        "###);
    }

    // An unknown user ID is rejected with 404.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_user_not_found(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let upstream = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        let req = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": Ulid::nil(),
                "provider_id": upstream.id,
                "subject": "subject1"
            }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::NOT_FOUND);
        let payload: serde_json::Value = res.json();
        assert_json_snapshot!(payload, @r###"
        {
          "errors": [
            {
              "title": "User ID 00000000000000000000000000 not found"
            }
          ]
        }
        "###);
    }

    // An unknown provider ID is rejected with 404.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_provider_not_found(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();
        repo.save().await.unwrap();

        let req = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": Ulid::nil(),
                "subject": "subject1"
            }));
        let res = state.request(req).await;
        res.assert_status(StatusCode::NOT_FOUND);
        let payload: serde_json::Value = res.json();
        assert_json_snapshot!(payload, @r###"
        {
          "errors": [
            {
              "title": "Upstream OAuth 2.0 Provider ID 00000000000000000000000000 not found"
            }
          ]
        }
        "###);
    }
}