1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Duration, Utc};
11use mas_data_model::{UserAgent, UserEmail, UserRecoverySession, UserRecoveryTicket};
12use mas_storage::{Clock, user::UserRecoveryRepository};
13use rand::RngCore;
14use sqlx::PgConnection;
15use ulid::Ulid;
16use uuid::Uuid;
17
18use crate::{DatabaseError, ExecuteExt};
19
20pub struct PgUserRecoveryRepository<'c> {
22 conn: &'c mut PgConnection,
23}
24
25impl<'c> PgUserRecoveryRepository<'c> {
26 pub fn new(conn: &'c mut PgConnection) -> Self {
29 Self { conn }
30 }
31}
32
33struct UserRecoverySessionRow {
34 user_recovery_session_id: Uuid,
35 email: String,
36 user_agent: String,
37 ip_address: Option<IpAddr>,
38 locale: String,
39 created_at: DateTime<Utc>,
40 consumed_at: Option<DateTime<Utc>>,
41}
42
43impl From<UserRecoverySessionRow> for UserRecoverySession {
44 fn from(row: UserRecoverySessionRow) -> Self {
45 UserRecoverySession {
46 id: row.user_recovery_session_id.into(),
47 email: row.email,
48 user_agent: UserAgent::parse(row.user_agent),
49 ip_address: row.ip_address,
50 locale: row.locale,
51 created_at: row.created_at,
52 consumed_at: row.consumed_at,
53 }
54 }
55}
56
57struct UserRecoveryTicketRow {
58 user_recovery_ticket_id: Uuid,
59 user_recovery_session_id: Uuid,
60 user_email_id: Uuid,
61 ticket: String,
62 created_at: DateTime<Utc>,
63 expires_at: DateTime<Utc>,
64}
65
66impl From<UserRecoveryTicketRow> for UserRecoveryTicket {
67 fn from(row: UserRecoveryTicketRow) -> Self {
68 Self {
69 id: row.user_recovery_ticket_id.into(),
70 user_recovery_session_id: row.user_recovery_session_id.into(),
71 user_email_id: row.user_email_id.into(),
72 ticket: row.ticket,
73 created_at: row.created_at,
74 expires_at: row.expires_at,
75 }
76 }
77}
78
79#[async_trait]
80impl UserRecoveryRepository for PgUserRecoveryRepository<'_> {
81 type Error = DatabaseError;
82
83 #[tracing::instrument(
84 name = "db.user_recovery.lookup_session",
85 skip_all,
86 fields(
87 db.query.text,
88 user_recovery_session.id = %id,
89 ),
90 err,
91 )]
92 async fn lookup_session(
93 &mut self,
94 id: Ulid,
95 ) -> Result<Option<UserRecoverySession>, Self::Error> {
96 let row = sqlx::query_as!(
97 UserRecoverySessionRow,
98 r#"
99 SELECT
100 user_recovery_session_id
101 , email
102 , user_agent
103 , ip_address as "ip_address: IpAddr"
104 , locale
105 , created_at
106 , consumed_at
107 FROM user_recovery_sessions
108 WHERE user_recovery_session_id = $1
109 "#,
110 Uuid::from(id),
111 )
112 .traced()
113 .fetch_optional(&mut *self.conn)
114 .await?;
115
116 let Some(row) = row else {
117 return Ok(None);
118 };
119
120 Ok(Some(row.into()))
121 }
122
123 #[tracing::instrument(
124 name = "db.user_recovery.add_session",
125 skip_all,
126 fields(
127 db.query.text,
128 user_recovery_session.id,
129 user_recovery_session.email = email,
130 user_recovery_session.user_agent = &*user_agent,
131 user_recovery_session.ip_address = ip_address.map(|ip| ip.to_string()),
132 )
133 )]
134 async fn add_session(
135 &mut self,
136 rng: &mut (dyn RngCore + Send),
137 clock: &dyn Clock,
138 email: String,
139 user_agent: UserAgent,
140 ip_address: Option<IpAddr>,
141 locale: String,
142 ) -> Result<UserRecoverySession, Self::Error> {
143 let created_at = clock.now();
144 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
145 tracing::Span::current().record("user_recovery_session.id", tracing::field::display(id));
146 sqlx::query!(
147 r#"
148 INSERT INTO user_recovery_sessions (
149 user_recovery_session_id
150 , email
151 , user_agent
152 , ip_address
153 , locale
154 , created_at
155 )
156 VALUES ($1, $2, $3, $4, $5, $6)
157 "#,
158 Uuid::from(id),
159 &email,
160 &*user_agent,
161 ip_address as Option<IpAddr>,
162 &locale,
163 created_at,
164 )
165 .traced()
166 .execute(&mut *self.conn)
167 .await?;
168
169 let user_recovery_session = UserRecoverySession {
170 id,
171 email,
172 user_agent,
173 ip_address,
174 locale,
175 created_at,
176 consumed_at: None,
177 };
178
179 Ok(user_recovery_session)
180 }
181
182 #[tracing::instrument(
183 name = "db.user_recovery.find_ticket",
184 skip_all,
185 fields(
186 db.query.text,
187 user_recovery_ticket.id = ticket,
188 ),
189 err,
190 )]
191 async fn find_ticket(
192 &mut self,
193 ticket: &str,
194 ) -> Result<Option<UserRecoveryTicket>, Self::Error> {
195 let row = sqlx::query_as!(
196 UserRecoveryTicketRow,
197 r#"
198 SELECT
199 user_recovery_ticket_id
200 , user_recovery_session_id
201 , user_email_id
202 , ticket
203 , created_at
204 , expires_at
205 FROM user_recovery_tickets
206 WHERE ticket = $1
207 "#,
208 ticket,
209 )
210 .traced()
211 .fetch_optional(&mut *self.conn)
212 .await?;
213
214 let Some(row) = row else {
215 return Ok(None);
216 };
217
218 Ok(Some(row.into()))
219 }
220
221 #[tracing::instrument(
222 name = "db.user_recovery.add_ticket",
223 skip_all,
224 fields(
225 db.query.text,
226 user_recovery_ticket.id,
227 user_recovery_ticket.id = ticket,
228 %user_recovery_session.id,
229 %user_email.id,
230 )
231 )]
232 async fn add_ticket(
233 &mut self,
234 rng: &mut (dyn RngCore + Send),
235 clock: &dyn Clock,
236 user_recovery_session: &UserRecoverySession,
237 user_email: &UserEmail,
238 ticket: String,
239 ) -> Result<UserRecoveryTicket, Self::Error> {
240 let created_at = clock.now();
241 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
242 tracing::Span::current().record("user_recovery_ticket.id", tracing::field::display(id));
243
244 let expires_at = created_at + Duration::minutes(10);
246
247 sqlx::query!(
248 r#"
249 INSERT INTO user_recovery_tickets (
250 user_recovery_ticket_id
251 , user_recovery_session_id
252 , user_email_id
253 , ticket
254 , created_at
255 , expires_at
256 )
257 VALUES ($1, $2, $3, $4, $5, $6)
258 "#,
259 Uuid::from(id),
260 Uuid::from(user_recovery_session.id),
261 Uuid::from(user_email.id),
262 &ticket,
263 created_at,
264 expires_at,
265 )
266 .traced()
267 .execute(&mut *self.conn)
268 .await?;
269
270 let ticket = UserRecoveryTicket {
271 id,
272 user_recovery_session_id: user_recovery_session.id,
273 user_email_id: user_email.id,
274 ticket,
275 created_at,
276 expires_at,
277 };
278
279 Ok(ticket)
280 }
281
282 #[tracing::instrument(
283 name = "db.user_recovery.consume_ticket",
284 skip_all,
285 fields(
286 db.query.text,
287 %user_recovery_ticket.id,
288 user_email.id = %user_recovery_ticket.user_email_id,
289 %user_recovery_session.id,
290 %user_recovery_session.email,
291 ),
292 err,
293 )]
294 async fn consume_ticket(
295 &mut self,
296 clock: &dyn Clock,
297 user_recovery_ticket: UserRecoveryTicket,
298 mut user_recovery_session: UserRecoverySession,
299 ) -> Result<UserRecoverySession, Self::Error> {
300 let _ = user_recovery_ticket;
302
303 if user_recovery_session.consumed_at.is_some() {
305 return Err(DatabaseError::invalid_operation());
306 }
307
308 let consumed_at = clock.now();
309
310 let res = sqlx::query!(
311 r#"
312 UPDATE user_recovery_sessions
313 SET consumed_at = $1
314 WHERE user_recovery_session_id = $2
315 "#,
316 consumed_at,
317 Uuid::from(user_recovery_session.id),
318 )
319 .traced()
320 .execute(&mut *self.conn)
321 .await?;
322
323 user_recovery_session.consumed_at = Some(consumed_at);
324
325 DatabaseError::ensure_affected_rows(&res, 1)?;
326
327 Ok(user_recovery_session)
328 }
329}