// mas_handlers/admin/v1/oauth2_sessions/list.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use std::str::FromStr;
8
9use aide::{OperationIo, transform::TransformOperation};
10use axum::{
11    Json,
12    extract::{Query, rejection::QueryRejection},
13    response::IntoResponse,
14};
15use axum_macros::FromRequestParts;
16use hyper::StatusCode;
17use mas_storage::{Page, oauth2::OAuth2SessionFilter};
18use oauth2_types::scope::{Scope, ScopeToken};
19use schemars::JsonSchema;
20use serde::Deserialize;
21use ulid::Ulid;
22
23use crate::{
24    admin::{
25        call_context::CallContext,
26        model::{OAuth2Session, Resource},
27        params::Pagination,
28        response::{ErrorResponse, PaginatedResponse},
29    },
30    impl_from_error_for_route,
31};
32
33#[derive(Deserialize, JsonSchema, Clone, Copy)]
34#[serde(rename_all = "snake_case")]
35enum OAuth2SessionStatus {
36    Active,
37    Finished,
38}
39
40impl std::fmt::Display for OAuth2SessionStatus {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            Self::Active => write!(f, "active"),
44            Self::Finished => write!(f, "finished"),
45        }
46    }
47}
48
49#[derive(Deserialize, JsonSchema, Clone, Copy)]
50#[serde(rename_all = "snake_case")]
51enum OAuth2ClientKind {
52    Dynamic,
53    Static,
54}
55
56impl std::fmt::Display for OAuth2ClientKind {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            Self::Dynamic => write!(f, "dynamic"),
60            Self::Static => write!(f, "static"),
61        }
62    }
63}
64
// Query-string filters accepted by the OAuth 2.0 session listing endpoint.
//
// The `///` comments on the fields are user-facing API documentation: they go
// through the `JsonSchema` derive into the generated schema, so internal notes
// belong in `//` comments like this one. The struct is extracted from the query
// string via axum's `Query`; extraction failures are turned into a `RouteError`
// (through its `From<QueryRejection>` conversion, i.e. `InvalidFilter`).
#[derive(FromRequestParts, Deserialize, JsonSchema, OperationIo)]
#[serde(rename = "OAuth2SessionFilter")]
#[aide(input_with = "Query<FilterParams>")]
#[from_request(via(Query), rejection(RouteError))]
pub struct FilterParams {
    /// Retrieve the items for the given user
    #[serde(rename = "filter[user]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    user: Option<Ulid>,

    /// Retrieve the items for the given client
    #[serde(rename = "filter[client]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    client: Option<Ulid>,

    /// Retrieve the items only for a specific client kind
    #[serde(rename = "filter[client-kind]")]
    client_kind: Option<OAuth2ClientKind>,

    /// Retrieve the items started from the given browser session
    #[serde(rename = "filter[user-session]")]
    #[schemars(with = "Option<crate::admin::schema::Ulid>")]
    user_session: Option<Ulid>,

    // Raw strings: each one is only validated as a `ScopeToken` in `handler`.
    // Missing parameter => empty list (`serde(default)`), meaning "no scope
    // filter". Multiple values are echoed back as repeated `filter[scope]`
    // parameters by the `Display` impl.
    /// Retrieve the items with the given scope
    #[serde(default, rename = "filter[scope]")]
    scope: Vec<String>,

    /// Retrieve the items with the given status
    ///
    /// Defaults to retrieve all sessions, including finished ones.
    ///
    /// * `active`: Only retrieve active sessions
    ///
    /// * `finished`: Only retrieve finished sessions
    #[serde(rename = "filter[status]")]
    status: Option<OAuth2SessionStatus>,
}
103
104impl std::fmt::Display for FilterParams {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        let mut sep = '?';
107
108        if let Some(user) = self.user {
109            write!(f, "{sep}filter[user]={user}")?;
110            sep = '&';
111        }
112
113        if let Some(client) = self.client {
114            write!(f, "{sep}filter[client]={client}")?;
115            sep = '&';
116        }
117
118        if let Some(client_kind) = self.client_kind {
119            write!(f, "{sep}filter[client-kind]={client_kind}")?;
120            sep = '&';
121        }
122
123        if let Some(user_session) = self.user_session {
124            write!(f, "{sep}filter[user-session]={user_session}")?;
125            sep = '&';
126        }
127
128        for scope in &self.scope {
129            write!(f, "{sep}filter[scope]={scope}")?;
130            sep = '&';
131        }
132
133        if let Some(status) = self.status {
134            write!(f, "{sep}filter[status]={status}")?;
135            sep = '&';
136        }
137
138        let _ = sep;
139        Ok(())
140    }
141}
142
143#[derive(Debug, thiserror::Error, OperationIo)]
144#[aide(output_with = "Json<ErrorResponse>")]
145pub enum RouteError {
146    #[error(transparent)]
147    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
148
149    #[error("User ID {0} not found")]
150    UserNotFound(Ulid),
151
152    #[error("Client ID {0} not found")]
153    ClientNotFound(Ulid),
154
155    #[error("User session ID {0} not found")]
156    UserSessionNotFound(Ulid),
157
158    #[error("Invalid filter parameters")]
159    InvalidFilter(#[from] QueryRejection),
160
161    #[error("Invalid scope {0:?} in filter parameters")]
162    InvalidScope(String),
163}
164
// Provides the conversion that lets `handler` use `?` on repository calls
// returning `mas_storage::RepositoryError`. The macro is imported from the
// crate root; presumably it wraps the error in `RouteError::Internal` — confirm
// against its definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
166
167impl IntoResponse for RouteError {
168    fn into_response(self) -> axum::response::Response {
169        let error = ErrorResponse::from_error(&self);
170        let status = match self {
171            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
172            Self::UserNotFound(_) | Self::ClientNotFound(_) | Self::UserSessionNotFound(_) => {
173                StatusCode::NOT_FOUND
174            }
175            Self::InvalidScope(_) | Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
176        };
177        (status, Json(error)).into_response()
178    }
179}
180
181pub fn doc(operation: TransformOperation) -> TransformOperation {
182    operation
183        .id("listOAuth2Sessions")
184        .summary("List OAuth 2.0 sessions")
185        .description("Retrieve a list of OAuth 2.0 sessions.
186Note that by default, all sessions, including finished ones are returned, with the oldest first.
187Use the `filter[status]` parameter to filter the sessions by their status and `page[last]` parameter to retrieve the last N sessions.")
188        .tag("oauth2-session")
189        .response_with::<200, Json<PaginatedResponse<OAuth2Session>>, _>(|t| {
190            let sessions = OAuth2Session::samples();
191            let pagination = mas_storage::Pagination::first(sessions.len());
192            let page = Page {
193                edges: sessions.into(),
194                has_next_page: true,
195                has_previous_page: false,
196            };
197
198            t.description("Paginated response of OAuth 2.0 sessions")
199                .example(PaginatedResponse::new(
200                    page,
201                    pagination,
202                    42,
203                    OAuth2Session::PATH,
204                ))
205        })
206        .response_with::<404, RouteError, _>(|t| {
207            let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
208            t.description("User was not found").example(response)
209        })
210        .response_with::<400, RouteError, _>(|t| {
211            let response = ErrorResponse::from_error(&RouteError::InvalidScope("not a valid scope".to_owned()));
212            t.description("Invalid scope").example(response)
213        })
214}
215
216#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all, err)]
217pub async fn handler(
218    CallContext { mut repo, .. }: CallContext,
219    Pagination(pagination): Pagination,
220    params: FilterParams,
221) -> Result<Json<PaginatedResponse<OAuth2Session>>, RouteError> {
222    let base = format!("{path}{params}", path = OAuth2Session::PATH);
223    let filter = OAuth2SessionFilter::default();
224
225    // Load the user from the filter
226    let user = if let Some(user_id) = params.user {
227        let user = repo
228            .user()
229            .lookup(user_id)
230            .await?
231            .ok_or(RouteError::UserNotFound(user_id))?;
232
233        Some(user)
234    } else {
235        None
236    };
237
238    let filter = match &user {
239        Some(user) => filter.for_user(user),
240        None => filter,
241    };
242
243    let client = if let Some(client_id) = params.client {
244        let client = repo
245            .oauth2_client()
246            .lookup(client_id)
247            .await?
248            .ok_or(RouteError::ClientNotFound(client_id))?;
249
250        Some(client)
251    } else {
252        None
253    };
254
255    let filter = match &client {
256        Some(client) => filter.for_client(client),
257        None => filter,
258    };
259
260    let filter = match params.client_kind {
261        Some(OAuth2ClientKind::Dynamic) => filter.only_dynamic_clients(),
262        Some(OAuth2ClientKind::Static) => filter.only_static_clients(),
263        None => filter,
264    };
265
266    let user_session = if let Some(user_session_id) = params.user_session {
267        let user_session = repo
268            .browser_session()
269            .lookup(user_session_id)
270            .await?
271            .ok_or(RouteError::UserSessionNotFound(user_session_id))?;
272
273        Some(user_session)
274    } else {
275        None
276    };
277
278    let filter = match &user_session {
279        Some(user_session) => filter.for_browser_session(user_session),
280        None => filter,
281    };
282
283    let scope: Scope = params
284        .scope
285        .into_iter()
286        .map(|s| ScopeToken::from_str(&s).map_err(|_| RouteError::InvalidScope(s)))
287        .collect::<Result<_, _>>()?;
288
289    let filter = if scope.is_empty() {
290        filter
291    } else {
292        filter.with_scope(&scope)
293    };
294
295    let filter = match params.status {
296        Some(OAuth2SessionStatus::Active) => filter.active_only(),
297        Some(OAuth2SessionStatus::Finished) => filter.finished_only(),
298        None => filter,
299    };
300
301    let page = repo.oauth2_session().list(filter, pagination).await?;
302    let count = repo.oauth2_session().count(filter).await?;
303
304    Ok(Json(PaginatedResponse::new(
305        page.map(OAuth2Session::from),
306        pagination,
307        count,
308        &base,
309    )))
310}
311
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_oauth2_simple_session_list(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let admin_token = state.token_with_scope("urn:mas:admin").await;

        // Issuing the admin token above already created one OAuth 2.0 session,
        // which is the only entry expected in the listing.
        let response = state
            .request(
                Request::get("/api/admin/v1/oauth2-sessions")
                    .bearer(&admin_token)
                    .empty(),
            )
            .await;
        response.assert_status(StatusCode::OK);

        let body: serde_json::Value = response.json();
        insta::assert_json_snapshot!(body, @r###"
        {
          "meta": {
            "count": 1
          },
          "data": [
            {
              "type": "oauth2-session",
              "id": "01FSHN9AG0MKGTBNZ16RDR3PVY",
              "attributes": {
                "created_at": "2022-01-16T14:40:00Z",
                "finished_at": null,
                "user_id": null,
                "user_session_id": null,
                "client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
                "scope": "urn:mas:admin",
                "user_agent": null,
                "last_active_at": null,
                "last_active_ip": null
              },
              "links": {
                "self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY"
              }
            }
          ],
          "links": {
            "self": "/api/admin/v1/oauth2-sessions?page[first]=10",
            "first": "/api/admin/v1/oauth2-sessions?page[first]=10",
            "last": "/api/admin/v1/oauth2-sessions?page[last]=10"
          }
        }
        "###);
    }
}