// mas_handlers/compat/refresh.rs
use axum::{Json, extract::State, response::IntoResponse};
8use chrono::Duration;
9use hyper::StatusCode;
10use mas_axum_utils::sentry::SentryEventID;
11use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
12use mas_storage::{
13 BoxClock, BoxRepository, BoxRng, Clock,
14 compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
15};
16use serde::{Deserialize, Serialize};
17use serde_with::{DurationMilliSeconds, serde_as};
18use thiserror::Error;
19
20use super::MatrixError;
21use crate::{BoundActivityTracker, impl_from_error_for_route};
22
23#[derive(Debug, Deserialize)]
24pub struct RequestBody {
25 refresh_token: String,
26}
27
28#[derive(Debug, Error)]
29pub enum RouteError {
30 #[error(transparent)]
31 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
32
33 #[error("invalid token")]
34 InvalidToken,
35
36 #[error("refresh token already consumed")]
37 RefreshTokenConsumed,
38
39 #[error("invalid session")]
40 InvalidSession,
41
42 #[error("unknown session")]
43 UnknownSession,
44}
45
46impl IntoResponse for RouteError {
47 fn into_response(self) -> axum::response::Response {
48 let event_id = sentry::capture_error(&self);
49 let response = match self {
50 Self::Internal(_) | Self::UnknownSession => MatrixError {
51 errcode: "M_UNKNOWN",
52 error: "Internal error",
53 status: StatusCode::INTERNAL_SERVER_ERROR,
54 },
55 Self::InvalidToken | Self::InvalidSession | Self::RefreshTokenConsumed => MatrixError {
56 errcode: "M_UNKNOWN_TOKEN",
57 error: "Invalid refresh token",
58 status: StatusCode::UNAUTHORIZED,
59 },
60 };
61
62 (SentryEventID::from(event_id), response).into_response()
63 }
64}
65
66impl_from_error_for_route!(mas_storage::RepositoryError);
67
68impl From<TokenFormatError> for RouteError {
69 fn from(_e: TokenFormatError) -> Self {
70 Self::InvalidToken
71 }
72}
73
74#[serde_as]
75#[derive(Debug, Serialize)]
76pub struct ResponseBody {
77 access_token: String,
78 refresh_token: String,
79 #[serde_as(as = "DurationMilliSeconds<i64>")]
80 expires_in_ms: Duration,
81}
82
83#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all, err)]
84pub(crate) async fn post(
85 mut rng: BoxRng,
86 clock: BoxClock,
87 mut repo: BoxRepository,
88 activity_tracker: BoundActivityTracker,
89 State(site_config): State<SiteConfig>,
90 Json(input): Json<RequestBody>,
91) -> Result<impl IntoResponse, RouteError> {
92 let token_type = TokenType::check(&input.refresh_token)?;
93
94 if token_type != TokenType::CompatRefreshToken {
95 return Err(RouteError::InvalidToken);
96 }
97
98 let refresh_token = repo
99 .compat_refresh_token()
100 .find_by_token(&input.refresh_token)
101 .await?
102 .ok_or(RouteError::InvalidToken)?;
103
104 if !refresh_token.is_valid() {
105 return Err(RouteError::RefreshTokenConsumed);
106 }
107
108 let session = repo
109 .compat_session()
110 .lookup(refresh_token.session_id)
111 .await?
112 .ok_or(RouteError::UnknownSession)?;
113
114 if !session.is_valid() {
115 return Err(RouteError::InvalidSession);
116 }
117
118 activity_tracker
119 .record_compat_session(&clock, &session)
120 .await;
121
122 let access_token = repo
123 .compat_access_token()
124 .lookup(refresh_token.access_token_id)
125 .await?
126 .filter(|t| t.is_valid(clock.now()));
127
128 let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
129 let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
130
131 let expires_in = site_config.compat_token_ttl;
132 let new_access_token = repo
133 .compat_access_token()
134 .add(
135 &mut rng,
136 &clock,
137 &session,
138 new_access_token_str,
139 Some(expires_in),
140 )
141 .await?;
142 let new_refresh_token = repo
143 .compat_refresh_token()
144 .add(
145 &mut rng,
146 &clock,
147 &session,
148 &new_access_token,
149 new_refresh_token_str,
150 )
151 .await?;
152
153 repo.compat_refresh_token()
154 .consume(&clock, refresh_token)
155 .await?;
156
157 if let Some(access_token) = access_token {
158 repo.compat_access_token()
159 .expire(&clock, access_token)
160 .await?;
161 }
162
163 repo.save().await?;
164
165 Ok(Json(ResponseBody {
166 access_token: new_access_token.token,
167 refresh_token: new_refresh_token.token,
168 expires_in_ms: expires_in,
169 }))
170}