mas_handlers/compat/
refresh.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use chrono::Duration;
9use hyper::StatusCode;
10use mas_axum_utils::sentry::SentryEventID;
11use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
12use mas_storage::{
13    BoxClock, BoxRepository, BoxRng, Clock,
14    compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
15};
16use serde::{Deserialize, Serialize};
17use serde_with::{DurationMilliSeconds, serde_as};
18use thiserror::Error;
19
20use super::MatrixError;
21use crate::{BoundActivityTracker, impl_from_error_for_route};
22
23#[derive(Debug, Deserialize)]
24pub struct RequestBody {
25    refresh_token: String,
26}
27
28#[derive(Debug, Error)]
29pub enum RouteError {
30    #[error(transparent)]
31    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
32
33    #[error("invalid token")]
34    InvalidToken,
35
36    #[error("refresh token already consumed")]
37    RefreshTokenConsumed,
38
39    #[error("invalid session")]
40    InvalidSession,
41
42    #[error("unknown session")]
43    UnknownSession,
44}
45
46impl IntoResponse for RouteError {
47    fn into_response(self) -> axum::response::Response {
48        let event_id = sentry::capture_error(&self);
49        let response = match self {
50            Self::Internal(_) | Self::UnknownSession => MatrixError {
51                errcode: "M_UNKNOWN",
52                error: "Internal error",
53                status: StatusCode::INTERNAL_SERVER_ERROR,
54            },
55            Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
56                errcode: "M_UNKNOWN_TOKEN",
57                error: "Invalid refresh token",
58                status: StatusCode::UNAUTHORIZED,
59            },
60        };
61
62        (SentryEventID::from(event_id), response).into_response()
63    }
64}
65
66impl_from_error_for_route!(mas_storage::RepositoryError);
67
68impl From<TokenFormatError> for RouteError {
69    fn from(_e: TokenFormatError) -> Self {
70        Self::InvalidToken
71    }
72}
73
74#[serde_as]
75#[derive(Debug, Serialize)]
76pub struct ResponseBody {
77    access_token: String,
78    refresh_token: String,
79    #[serde_as(as = "DurationMilliSeconds<i64>")]
80    expires_in_ms: Duration,
81}
82
83#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all, err)]
84pub(crate) async fn post(
85    mut rng: BoxRng,
86    clock: BoxClock,
87    mut repo: BoxRepository,
88    activity_tracker: BoundActivityTracker,
89    State(site_config): State<SiteConfig>,
90    Json(input): Json<RequestBody>,
91) -> Result<impl IntoResponse, RouteError> {
92    let token_type = TokenType::check(&input.refresh_token)?;
93
94    if token_type != TokenType::CompatRefreshToken {
95        return Err(RouteError::InvalidToken);
96    }
97
98    let refresh_token = repo
99        .compat_refresh_token()
100        .find_by_token(&input.refresh_token)
101        .await?
102        .ok_or(RouteError::InvalidToken)?;
103
104    if !refresh_token.is_valid() {
105        return Err(RouteError::RefreshTokenConsumed);
106    }
107
108    let session = repo
109        .compat_session()
110        .lookup(refresh_token.session_id)
111        .await?
112        .ok_or(RouteError::UnknownSession)?;
113
114    if !session.is_valid() {
115        return Err(RouteError::InvalidSession);
116    }
117
118    activity_tracker
119        .record_compat_session(&clock, &session)
120        .await;
121
122    let access_token = repo
123        .compat_access_token()
124        .lookup(refresh_token.access_token_id)
125        .await?
126        .filter(|t| t.is_valid(clock.now()));
127
128    let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
129    let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
130
131    let expires_in = site_config.compat_token_ttl;
132    let new_access_token = repo
133        .compat_access_token()
134        .add(
135            &mut rng,
136            &clock,
137            &session,
138            new_access_token_str,
139            Some(expires_in),
140        )
141        .await?;
142    let new_refresh_token = repo
143        .compat_refresh_token()
144        .add(
145            &mut rng,
146            &clock,
147            &session,
148            &new_access_token,
149            new_refresh_token_str,
150        )
151        .await?;
152
153    repo.compat_refresh_token()
154        .consume(&clock, refresh_token)
155        .await?;
156
157    if let Some(access_token) = access_token {
158        repo.compat_access_token()
159            .expire(&clock, access_token)
160            .await?;
161    }
162
163    repo.save().await?;
164
165    Ok(Json(ResponseBody {
166        access_token: new_access_token.token,
167        refresh_token: new_refresh_token.token,
168        expires_in_ms: expires_in,
169    }))
170}