1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
11};
12use mas_storage::{Clock, compat::CompatRefreshTokenRepository};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, tracing::ExecuteExt};
19
/// An implementation of [`CompatRefreshTokenRepository`] backed by a
/// PostgreSQL connection.
pub struct PgCompatRefreshTokenRepository<'c> {
    // Mutably borrowed for the repository's lifetime; every query in the
    // trait implementation is executed on this connection.
    conn: &'c mut PgConnection,
}
25
26impl<'c> PgCompatRefreshTokenRepository<'c> {
27 pub fn new(conn: &'c mut PgConnection) -> Self {
30 Self { conn }
31 }
32}
33
/// Raw row shape returned by the `SELECT` queries on `compat_refresh_tokens`.
///
/// Field names intentionally match the selected column names, because
/// `sqlx::query_as!` maps columns onto fields by name. Convert into the
/// domain model with `CompatRefreshToken::from`.
struct CompatRefreshTokenLookup {
    compat_refresh_token_id: Uuid,
    refresh_token: String,
    created_at: DateTime<Utc>,
    // `None` means the token has not been consumed yet (see the `From` impl).
    consumed_at: Option<DateTime<Utc>>,
    compat_access_token_id: Uuid,
    compat_session_id: Uuid,
}
42
43impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
44 fn from(value: CompatRefreshTokenLookup) -> Self {
45 let state = match value.consumed_at {
46 Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
47 None => CompatRefreshTokenState::Valid,
48 };
49
50 Self {
51 id: value.compat_refresh_token_id.into(),
52 state,
53 session_id: value.compat_session_id.into(),
54 token: value.refresh_token,
55 created_at: value.created_at,
56 access_token_id: value.compat_access_token_id.into(),
57 }
58 }
59}
60
61#[async_trait]
62impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
63 type Error = DatabaseError;
64
65 #[tracing::instrument(
66 name = "db.compat_refresh_token.lookup",
67 skip_all,
68 fields(
69 db.query.text,
70 compat_refresh_token.id = %id,
71 ),
72 err,
73 )]
74 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
75 let res = sqlx::query_as!(
76 CompatRefreshTokenLookup,
77 r#"
78 SELECT compat_refresh_token_id
79 , refresh_token
80 , created_at
81 , consumed_at
82 , compat_session_id
83 , compat_access_token_id
84
85 FROM compat_refresh_tokens
86
87 WHERE compat_refresh_token_id = $1
88 "#,
89 Uuid::from(id),
90 )
91 .traced()
92 .fetch_optional(&mut *self.conn)
93 .await?;
94
95 let Some(res) = res else { return Ok(None) };
96
97 Ok(Some(res.into()))
98 }
99
100 #[tracing::instrument(
101 name = "db.compat_refresh_token.find_by_token",
102 skip_all,
103 fields(
104 db.query.text,
105 ),
106 err,
107 )]
108 async fn find_by_token(
109 &mut self,
110 refresh_token: &str,
111 ) -> Result<Option<CompatRefreshToken>, Self::Error> {
112 let res = sqlx::query_as!(
113 CompatRefreshTokenLookup,
114 r#"
115 SELECT compat_refresh_token_id
116 , refresh_token
117 , created_at
118 , consumed_at
119 , compat_session_id
120 , compat_access_token_id
121
122 FROM compat_refresh_tokens
123
124 WHERE refresh_token = $1
125 "#,
126 refresh_token,
127 )
128 .traced()
129 .fetch_optional(&mut *self.conn)
130 .await?;
131
132 let Some(res) = res else { return Ok(None) };
133
134 Ok(Some(res.into()))
135 }
136
137 #[tracing::instrument(
138 name = "db.compat_refresh_token.add",
139 skip_all,
140 fields(
141 db.query.text,
142 compat_refresh_token.id,
143 %compat_session.id,
144 user.id = %compat_session.user_id,
145 ),
146 err,
147 )]
148 async fn add(
149 &mut self,
150 rng: &mut (dyn RngCore + Send),
151 clock: &dyn Clock,
152 compat_session: &CompatSession,
153 compat_access_token: &CompatAccessToken,
154 token: String,
155 ) -> Result<CompatRefreshToken, Self::Error> {
156 let created_at = clock.now();
157 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
158 tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
159
160 sqlx::query!(
161 r#"
162 INSERT INTO compat_refresh_tokens
163 (compat_refresh_token_id, compat_session_id,
164 compat_access_token_id, refresh_token, created_at)
165 VALUES ($1, $2, $3, $4, $5)
166 "#,
167 Uuid::from(id),
168 Uuid::from(compat_session.id),
169 Uuid::from(compat_access_token.id),
170 token,
171 created_at,
172 )
173 .traced()
174 .execute(&mut *self.conn)
175 .await?;
176
177 Ok(CompatRefreshToken {
178 id,
179 state: CompatRefreshTokenState::default(),
180 session_id: compat_session.id,
181 access_token_id: compat_access_token.id,
182 token,
183 created_at,
184 })
185 }
186
187 #[tracing::instrument(
188 name = "db.compat_refresh_token.consume",
189 skip_all,
190 fields(
191 db.query.text,
192 %compat_refresh_token.id,
193 compat_session.id = %compat_refresh_token.session_id,
194 ),
195 err,
196 )]
197 async fn consume(
198 &mut self,
199 clock: &dyn Clock,
200 compat_refresh_token: CompatRefreshToken,
201 ) -> Result<CompatRefreshToken, Self::Error> {
202 let consumed_at = clock.now();
203 let res = sqlx::query!(
204 r#"
205 UPDATE compat_refresh_tokens
206 SET consumed_at = $2
207 WHERE compat_session_id = $1
208 AND consumed_at IS NULL
209 "#,
210 Uuid::from(compat_refresh_token.session_id),
211 consumed_at,
212 )
213 .traced()
214 .execute(&mut *self.conn)
215 .await?;
216
217 if res.rows_affected() == 0 {
221 return Err(DatabaseError::RowsAffected {
222 expected: 1,
223 actual: 0,
224 });
225 }
226
227 let compat_refresh_token = compat_refresh_token
228 .consume(consumed_at)
229 .map_err(DatabaseError::to_invalid_operation)?;
230
231 Ok(compat_refresh_token)
232 }
233}