mas_handlers/admin/v1/oauth2_sessions/list.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
7use std::str::FromStr;
8
9use aide::{OperationIo, transform::TransformOperation};
10use axum::{Json, response::IntoResponse};
11use axum_extra::extract::{Query, QueryRejection};
12use axum_macros::FromRequestParts;
13use hyper::StatusCode;
14use mas_axum_utils::record_error;
15use mas_storage::{Page, oauth2::OAuth2SessionFilter};
16use oauth2_types::scope::{Scope, ScopeToken};
17use schemars::JsonSchema;
18use serde::Deserialize;
19use ulid::Ulid;
20
21use crate::{
22    admin::{
23        call_context::CallContext,
24        model::{OAuth2Session, Resource},
25        params::{IncludeCount, Pagination},
26        response::{ErrorResponse, PaginatedResponse},
27    },
28    impl_from_error_for_route,
29};
30
31#[derive(Deserialize, JsonSchema, Clone, Copy)]
32#[serde(rename_all = "snake_case")]
33enum OAuth2SessionStatus {
34    Active,
35    Finished,
36}
37
38impl std::fmt::Display for OAuth2SessionStatus {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Self::Active => write!(f, "active"),
42            Self::Finished => write!(f, "finished"),
43        }
44    }
45}
46
47#[derive(Deserialize, JsonSchema, Clone, Copy)]
48#[serde(rename_all = "snake_case")]
49enum OAuth2ClientKind {
50    Dynamic,
51    Static,
52}
53
54impl std::fmt::Display for OAuth2ClientKind {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        match self {
57            Self::Dynamic => write!(f, "dynamic"),
58            Self::Static => write!(f, "static"),
59        }
60    }
61}
62
// Query-string filters accepted by the OAuth 2.0 session list endpoint.
//
// The struct is extracted directly from the request parts via `Query`. A
// deserialization failure is turned into a `RouteError` (see
// `rejection(RouteError)` and the `#[from] QueryRejection` variant).
//
// The `///` comments on the fields are used by the `JsonSchema` derive as
// field descriptions, so they are part of the published API documentation.
// Keep reviewer notes in `//` comments like this one.
#[derive(FromRequestParts, Deserialize, JsonSchema, OperationIo)]
// NOTE(review): presumably renamed so the schema is published under the name
// `OAuth2SessionFilter` — confirm against the generated spec.
#[serde(rename = "OAuth2SessionFilter")]
#[aide(input_with = "Query<FilterParams>")]
#[from_request(via(Query), rejection(RouteError))]
pub struct FilterParams {
    /// Retrieve the items for the given user
    #[serde(rename = "filter[user]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    user: Option<Ulid>,

    /// Retrieve the items for the given client
    #[serde(rename = "filter[client]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    client: Option<Ulid>,

    /// Retrieve the items only for a specific client kind
    #[serde(rename = "filter[client-kind]")]
    client_kind: Option<OAuth2ClientKind>,

    /// Retrieve the items started from the given browser session
    #[serde(rename = "filter[user-session]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    user_session: Option<Ulid>,

    // Raw strings: may be given several times (empty when absent, thanks to
    // `default`). Each one is parsed as a `ScopeToken` in `handler`.
    /// Retrieve the items with the given scope
    #[serde(default, rename = "filter[scope]")]
    scope: Vec<String>,

    /// Retrieve the items with the given status
    ///
    /// Defaults to retrieve all sessions, including finished ones.
    ///
    /// * `active`: Only retrieve active sessions
    ///
    /// * `finished`: Only retrieve finished sessions
    #[serde(rename = "filter[status]")]
    status: Option<OAuth2SessionStatus>,
}
101
102impl std::fmt::Display for FilterParams {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        let mut sep = '?';
105
106        if let Some(user) = self.user {
107            write!(f, "{sep}filter[user]={user}")?;
108            sep = '&';
109        }
110
111        if let Some(client) = self.client {
112            write!(f, "{sep}filter[client]={client}")?;
113            sep = '&';
114        }
115
116        if let Some(client_kind) = self.client_kind {
117            write!(f, "{sep}filter[client-kind]={client_kind}")?;
118            sep = '&';
119        }
120
121        if let Some(user_session) = self.user_session {
122            write!(f, "{sep}filter[user-session]={user_session}")?;
123            sep = '&';
124        }
125
126        for scope in &self.scope {
127            write!(f, "{sep}filter[scope]={scope}")?;
128            sep = '&';
129        }
130
131        if let Some(status) = self.status {
132            write!(f, "{sep}filter[status]={status}")?;
133            sep = '&';
134        }
135
136        let _ = sep;
137        Ok(())
138    }
139}
140
141#[derive(Debug, thiserror::Error, OperationIo)]
142#[aide(output_with = "Json<ErrorResponse>")]
143pub enum RouteError {
144    #[error(transparent)]
145    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
146
147    #[error("User ID {0} not found")]
148    UserNotFound(Ulid),
149
150    #[error("Client ID {0} not found")]
151    ClientNotFound(Ulid),
152
153    #[error("User session ID {0} not found")]
154    UserSessionNotFound(Ulid),
155
156    #[error("Invalid filter parameters")]
157    InvalidFilter(#[from] QueryRejection),
158
159    #[error("Invalid scope {0:?} in filter parameters")]
160    InvalidScope(String),
161}
162
// Presumably generates `From<mas_storage::RepositoryError> for RouteError`,
// which is what lets the repository calls in `handler` use `?`.
// NOTE(review): the macro is defined elsewhere in the crate — confirm it maps
// the error to `RouteError::Internal`.
impl_from_error_for_route!(mas_storage::RepositoryError);
164
165impl IntoResponse for RouteError {
166    fn into_response(self) -> axum::response::Response {
167        let error = ErrorResponse::from_error(&self);
168        let sentry_event_id = record_error!(self, RouteError::Internal(_));
169        let status = match self {
170            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
171            Self::UserNotFound(_) | Self::ClientNotFound(_) | Self::UserSessionNotFound(_) => {
172                StatusCode::NOT_FOUND
173            }
174            Self::InvalidScope(_) | Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
175        };
176        (status, sentry_event_id, Json(error)).into_response()
177    }
178}
179
180pub fn doc(operation: TransformOperation) -> TransformOperation {
181    operation
182        .id("listOAuth2Sessions")
183        .summary("List OAuth 2.0 sessions")
184        .description("Retrieve a list of OAuth 2.0 sessions.
185Note that by default, all sessions, including finished ones are returned, with the oldest first.
186Use the `filter[status]` parameter to filter the sessions by their status and `page[last]` parameter to retrieve the last N sessions.")
187        .tag("oauth2-session")
188        .response_with::<200, Json<PaginatedResponse<OAuth2Session>>, _>(|t| {
189            let sessions = OAuth2Session::samples();
190            let pagination = mas_storage::Pagination::first(sessions.len());
191            let page = Page {
192                edges: sessions
193                    .into_iter()
194                    .map(|node| mas_storage::pagination::Edge {
195                        cursor: node.id(),
196                        node,
197                    })
198                    .collect(),
199                has_next_page: true,
200                has_previous_page: false,
201            };
202
203            t.description("Paginated response of OAuth 2.0 sessions")
204                .example(PaginatedResponse::for_page(
205                    page,
206                    pagination,
207                    Some(42),
208                    OAuth2Session::PATH,
209                ))
210        })
211        .response_with::<404, RouteError, _>(|t| {
212            let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
213            t.description("User was not found").example(response)
214        })
215        .response_with::<400, RouteError, _>(|t| {
216            let response = ErrorResponse::from_error(&RouteError::InvalidScope("not a valid scope".to_owned()));
217            t.description("Invalid scope").example(response)
218        })
219}
220
221#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all)]
222pub async fn handler(
223    CallContext { mut repo, .. }: CallContext,
224    Pagination(pagination, include_count): Pagination,
225    params: FilterParams,
226) -> Result<Json<PaginatedResponse<OAuth2Session>>, RouteError> {
227    let base = format!("{path}{params}", path = OAuth2Session::PATH);
228    let base = include_count.add_to_base(&base);
229    let filter = OAuth2SessionFilter::default();
230
231    // Load the user from the filter
232    let user = if let Some(user_id) = params.user {
233        let user = repo
234            .user()
235            .lookup(user_id)
236            .await?
237            .ok_or(RouteError::UserNotFound(user_id))?;
238
239        Some(user)
240    } else {
241        None
242    };
243
244    let filter = match &user {
245        Some(user) => filter.for_user(user),
246        None => filter,
247    };
248
249    let client = if let Some(client_id) = params.client {
250        let client = repo
251            .oauth2_client()
252            .lookup(client_id)
253            .await?
254            .ok_or(RouteError::ClientNotFound(client_id))?;
255
256        Some(client)
257    } else {
258        None
259    };
260
261    let filter = match &client {
262        Some(client) => filter.for_client(client),
263        None => filter,
264    };
265
266    let filter = match params.client_kind {
267        Some(OAuth2ClientKind::Dynamic) => filter.only_dynamic_clients(),
268        Some(OAuth2ClientKind::Static) => filter.only_static_clients(),
269        None => filter,
270    };
271
272    let user_session = if let Some(user_session_id) = params.user_session {
273        let user_session = repo
274            .browser_session()
275            .lookup(user_session_id)
276            .await?
277            .ok_or(RouteError::UserSessionNotFound(user_session_id))?;
278
279        Some(user_session)
280    } else {
281        None
282    };
283
284    let filter = match &user_session {
285        Some(user_session) => filter.for_browser_session(user_session),
286        None => filter,
287    };
288
289    let scope: Scope = params
290        .scope
291        .into_iter()
292        .map(|s| ScopeToken::from_str(&s).map_err(|_| RouteError::InvalidScope(s)))
293        .collect::<Result<_, _>>()?;
294
295    let filter = if scope.is_empty() {
296        filter
297    } else {
298        filter.with_scope(&scope)
299    };
300
301    let filter = match params.status {
302        Some(OAuth2SessionStatus::Active) => filter.active_only(),
303        Some(OAuth2SessionStatus::Finished) => filter.finished_only(),
304        None => filter,
305    };
306
307    let response = match include_count {
308        IncludeCount::True => {
309            let page = repo
310                .oauth2_session()
311                .list(filter, pagination)
312                .await?
313                .map(OAuth2Session::from);
314            let count = repo.oauth2_session().count(filter).await?;
315            PaginatedResponse::for_page(page, pagination, Some(count), &base)
316        }
317        IncludeCount::False => {
318            let page = repo
319                .oauth2_session()
320                .list(filter, pagination)
321                .await?
322                .map(OAuth2Session::from);
323            PaginatedResponse::for_page(page, pagination, None, &base)
324        }
325        IncludeCount::Only => {
326            let count = repo.oauth2_session().count(filter).await?;
327            PaginatedResponse::for_count_only(count, &base)
328        }
329    };
330
331    Ok(Json(response))
332}
333
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// Performs an authenticated `GET` on `uri`, asserts a 200 and returns the
    /// decoded JSON body.
    async fn fetch_json(state: &mut TestState, token: &str, uri: &str) -> serde_json::Value {
        let request = Request::get(uri).bearer(token).empty();
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        response.json()
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_oauth2_simple_session_list(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        // The admin token above is itself backed by a session, so the list is
        // not empty.
        let body = fetch_json(&mut state, &token, "/api/admin/v1/oauth2-sessions").await;
        insta::assert_json_snapshot!(body, @r#"
        {
          "meta": {
            "count": 1
          },
          "data": [
            {
              "type": "oauth2-session",
              "id": "01FSHN9AG0MKGTBNZ16RDR3PVY",
              "attributes": {
                "created_at": "2022-01-16T14:40:00Z",
                "finished_at": null,
                "user_id": null,
                "user_session_id": null,
                "client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
                "scope": "urn:mas:admin",
                "user_agent": null,
                "last_active_at": null,
                "last_active_ip": null,
                "human_name": null
              },
              "links": {
                "self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY"
              },
              "meta": {
                "page": {
                  "cursor": "01FSHN9AG0MKGTBNZ16RDR3PVY"
                }
              }
            }
          ],
          "links": {
            "self": "/api/admin/v1/oauth2-sessions?page[first]=10",
            "first": "/api/admin/v1/oauth2-sessions?page[first]=10",
            "last": "/api/admin/v1/oauth2-sessions?page[last]=10"
          }
        }
        "#);

        // count=false: the page without the total
        let body = fetch_json(
            &mut state,
            &token,
            "/api/admin/v1/oauth2-sessions?count=false",
        )
        .await;
        insta::assert_json_snapshot!(body, @r#"
        {
          "data": [
            {
              "type": "oauth2-session",
              "id": "01FSHN9AG0MKGTBNZ16RDR3PVY",
              "attributes": {
                "created_at": "2022-01-16T14:40:00Z",
                "finished_at": null,
                "user_id": null,
                "user_session_id": null,
                "client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
                "scope": "urn:mas:admin",
                "user_agent": null,
                "last_active_at": null,
                "last_active_ip": null,
                "human_name": null
              },
              "links": {
                "self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY"
              },
              "meta": {
                "page": {
                  "cursor": "01FSHN9AG0MKGTBNZ16RDR3PVY"
                }
              }
            }
          ],
          "links": {
            "self": "/api/admin/v1/oauth2-sessions?count=false&page[first]=10",
            "first": "/api/admin/v1/oauth2-sessions?count=false&page[first]=10",
            "last": "/api/admin/v1/oauth2-sessions?count=false&page[last]=10"
          }
        }
        "#);

        // count=only: the total without any page
        let body = fetch_json(
            &mut state,
            &token,
            "/api/admin/v1/oauth2-sessions?count=only",
        )
        .await;
        insta::assert_json_snapshot!(body, @r#"
        {
          "meta": {
            "count": 1
          },
          "links": {
            "self": "/api/admin/v1/oauth2-sessions?count=only"
          }
        }
        "#);
    }
}