1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User, UserAgent};
12use mas_storage::{
13 Clock, Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError, DatabaseInconsistencyError,
26 filter::{Filter, StatementExt},
27 iden::{OAuth2Clients, OAuth2Sessions},
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgOAuth2SessionRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
/// Raw row shape of the `oauth2_sessions` table, as read by `lookup` and
/// `list`, before conversion into a [`Session`] via `TryFrom`.
///
/// Field names are load-bearing: `sqlx::FromRow` maps result columns by
/// name, and `#[enum_def]` generates `OAuthSessionLookupIden`, whose
/// variants are used as column aliases when building the `list` query.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    /// Primary key; converted to a `Ulid` for the domain model.
    oauth2_session_id: Uuid,
    /// Owning user, if any (`NULL` is filterable via `any_user`).
    user_id: Option<Uuid>,
    /// Browser session the OAuth 2.0 session was created from, if any.
    user_session_id: Option<Uuid>,
    /// Client the session was issued to.
    oauth2_client_id: Uuid,
    /// Granted scope tokens, parsed into a `Scope` on conversion.
    scope_list: Vec<String>,
    /// Creation timestamp.
    created_at: DateTime<Utc>,
    /// When set, the session is in the `Finished` state.
    finished_at: Option<DateTime<Utc>>,
    /// Raw user agent string, parsed with `UserAgent::parse` on conversion.
    user_agent: Option<String>,
    /// Most recent recorded activity, updated by `record_batch_activity`.
    last_active_at: Option<DateTime<Utc>>,
    /// Most recent recorded client IP (stored as `inet`).
    last_active_ip: Option<IpAddr>,
}
59
60impl TryFrom<OAuthSessionLookup> for Session {
61 type Error = DatabaseInconsistencyError;
62
63 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
64 let id = Ulid::from(value.oauth2_session_id);
65 let scope: Result<Scope, _> = value
66 .scope_list
67 .iter()
68 .map(|s| s.parse::<ScopeToken>())
69 .collect();
70 let scope = scope.map_err(|e| {
71 DatabaseInconsistencyError::on("oauth2_sessions")
72 .column("scope")
73 .row(id)
74 .source(e)
75 })?;
76
77 let state = match value.finished_at {
78 None => SessionState::Valid,
79 Some(finished_at) => SessionState::Finished { finished_at },
80 };
81
82 Ok(Session {
83 id,
84 state,
85 created_at: value.created_at,
86 client_id: value.oauth2_client_id.into(),
87 user_id: value.user_id.map(Ulid::from),
88 user_session_id: value.user_session_id.map(Ulid::from),
89 scope,
90 user_agent: value.user_agent.map(UserAgent::parse),
91 last_active_at: value.last_active_at,
92 last_active_ip: value.last_active_ip,
93 })
94 }
95}
96
97impl Filter for OAuth2SessionFilter<'_> {
98 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99 sea_query::Condition::all()
100 .add_option(self.user().map(|user| {
101 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
102 }))
103 .add_option(self.client().map(|client| {
104 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
105 .eq(Uuid::from(client.id))
106 }))
107 .add_option(self.client_kind().map(|client_kind| {
108 let static_clients = Query::select()
112 .expr(Expr::col((
113 OAuth2Clients::Table,
114 OAuth2Clients::OAuth2ClientId,
115 )))
116 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
117 .from(OAuth2Clients::Table)
118 .take();
119 if client_kind.is_static() {
120 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
121 .eq(Expr::any(static_clients))
122 } else {
123 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
124 .ne(Expr::all(static_clients))
125 }
126 }))
127 .add_option(self.device().map(|device| {
128 if let Ok(scope_token) = device.to_scope_token() {
129 Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
130 OAuth2Sessions::Table,
131 OAuth2Sessions::ScopeList,
132 ))))
133 } else {
134 Expr::val(false).into()
136 }
137 }))
138 .add_option(self.browser_session().map(|browser_session| {
139 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
140 .eq(Uuid::from(browser_session.id))
141 }))
142 .add_option(self.state().map(|state| {
143 if state.is_active() {
144 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
145 } else {
146 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
147 }
148 }))
149 .add_option(self.scope().map(|scope| {
150 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
151 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
152 }))
153 .add_option(self.any_user().map(|any_user| {
154 if any_user {
155 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
156 } else {
157 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
158 }
159 }))
160 .add_option(self.last_active_after().map(|last_active_after| {
161 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
162 .gt(last_active_after)
163 }))
164 .add_option(self.last_active_before().map(|last_active_before| {
165 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
166 .lt(last_active_before)
167 }))
168 }
169}
170
171#[async_trait]
172impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
173 type Error = DatabaseError;
174
175 #[tracing::instrument(
176 name = "db.oauth2_session.lookup",
177 skip_all,
178 fields(
179 db.query.text,
180 session.id = %id,
181 ),
182 err,
183 )]
184 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
185 let res = sqlx::query_as!(
186 OAuthSessionLookup,
187 r#"
188 SELECT oauth2_session_id
189 , user_id
190 , user_session_id
191 , oauth2_client_id
192 , scope_list
193 , created_at
194 , finished_at
195 , user_agent
196 , last_active_at
197 , last_active_ip as "last_active_ip: IpAddr"
198 FROM oauth2_sessions
199
200 WHERE oauth2_session_id = $1
201 "#,
202 Uuid::from(id),
203 )
204 .traced()
205 .fetch_optional(&mut *self.conn)
206 .await?;
207
208 let Some(session) = res else { return Ok(None) };
209
210 Ok(Some(session.try_into()?))
211 }
212
213 #[tracing::instrument(
214 name = "db.oauth2_session.add",
215 skip_all,
216 fields(
217 db.query.text,
218 %client.id,
219 session.id,
220 session.scope = %scope,
221 ),
222 err,
223 )]
224 async fn add(
225 &mut self,
226 rng: &mut (dyn RngCore + Send),
227 clock: &dyn Clock,
228 client: &Client,
229 user: Option<&User>,
230 user_session: Option<&BrowserSession>,
231 scope: Scope,
232 ) -> Result<Session, Self::Error> {
233 let created_at = clock.now();
234 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
235 tracing::Span::current().record("session.id", tracing::field::display(id));
236
237 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
238
239 sqlx::query!(
240 r#"
241 INSERT INTO oauth2_sessions
242 ( oauth2_session_id
243 , user_id
244 , user_session_id
245 , oauth2_client_id
246 , scope_list
247 , created_at
248 )
249 VALUES ($1, $2, $3, $4, $5, $6)
250 "#,
251 Uuid::from(id),
252 user.map(|u| Uuid::from(u.id)),
253 user_session.map(|s| Uuid::from(s.id)),
254 Uuid::from(client.id),
255 &scope_list,
256 created_at,
257 )
258 .traced()
259 .execute(&mut *self.conn)
260 .await?;
261
262 Ok(Session {
263 id,
264 state: SessionState::Valid,
265 created_at,
266 user_id: user.map(|u| u.id),
267 user_session_id: user_session.map(|s| s.id),
268 client_id: client.id,
269 scope,
270 user_agent: None,
271 last_active_at: None,
272 last_active_ip: None,
273 })
274 }
275
276 #[tracing::instrument(
277 name = "db.oauth2_session.finish_bulk",
278 skip_all,
279 fields(
280 db.query.text,
281 ),
282 err,
283 )]
284 async fn finish_bulk(
285 &mut self,
286 clock: &dyn Clock,
287 filter: OAuth2SessionFilter<'_>,
288 ) -> Result<usize, Self::Error> {
289 let finished_at = clock.now();
290 let (sql, arguments) = Query::update()
291 .table(OAuth2Sessions::Table)
292 .value(OAuth2Sessions::FinishedAt, finished_at)
293 .apply_filter(filter)
294 .build_sqlx(PostgresQueryBuilder);
295
296 let res = sqlx::query_with(&sql, arguments)
297 .traced()
298 .execute(&mut *self.conn)
299 .await?;
300
301 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
302 }
303
304 #[tracing::instrument(
305 name = "db.oauth2_session.finish",
306 skip_all,
307 fields(
308 db.query.text,
309 %session.id,
310 %session.scope,
311 client.id = %session.client_id,
312 ),
313 err,
314 )]
315 async fn finish(
316 &mut self,
317 clock: &dyn Clock,
318 session: Session,
319 ) -> Result<Session, Self::Error> {
320 let finished_at = clock.now();
321 let res = sqlx::query!(
322 r#"
323 UPDATE oauth2_sessions
324 SET finished_at = $2
325 WHERE oauth2_session_id = $1
326 "#,
327 Uuid::from(session.id),
328 finished_at,
329 )
330 .traced()
331 .execute(&mut *self.conn)
332 .await?;
333
334 DatabaseError::ensure_affected_rows(&res, 1)?;
335
336 session
337 .finish(finished_at)
338 .map_err(DatabaseError::to_invalid_operation)
339 }
340
341 #[tracing::instrument(
342 name = "db.oauth2_session.list",
343 skip_all,
344 fields(
345 db.query.text,
346 ),
347 err,
348 )]
349 async fn list(
350 &mut self,
351 filter: OAuth2SessionFilter<'_>,
352 pagination: Pagination,
353 ) -> Result<Page<Session>, Self::Error> {
354 let (sql, arguments) = Query::select()
355 .expr_as(
356 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
357 OAuthSessionLookupIden::Oauth2SessionId,
358 )
359 .expr_as(
360 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
361 OAuthSessionLookupIden::UserId,
362 )
363 .expr_as(
364 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
365 OAuthSessionLookupIden::UserSessionId,
366 )
367 .expr_as(
368 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
369 OAuthSessionLookupIden::Oauth2ClientId,
370 )
371 .expr_as(
372 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
373 OAuthSessionLookupIden::ScopeList,
374 )
375 .expr_as(
376 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
377 OAuthSessionLookupIden::CreatedAt,
378 )
379 .expr_as(
380 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
381 OAuthSessionLookupIden::FinishedAt,
382 )
383 .expr_as(
384 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
385 OAuthSessionLookupIden::UserAgent,
386 )
387 .expr_as(
388 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
389 OAuthSessionLookupIden::LastActiveAt,
390 )
391 .expr_as(
392 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
393 OAuthSessionLookupIden::LastActiveIp,
394 )
395 .from(OAuth2Sessions::Table)
396 .apply_filter(filter)
397 .generate_pagination(
398 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
399 pagination,
400 )
401 .build_sqlx(PostgresQueryBuilder);
402
403 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
404 .traced()
405 .fetch_all(&mut *self.conn)
406 .await?;
407
408 let page = pagination.process(edges).try_map(Session::try_from)?;
409
410 Ok(page)
411 }
412
413 #[tracing::instrument(
414 name = "db.oauth2_session.count",
415 skip_all,
416 fields(
417 db.query.text,
418 ),
419 err,
420 )]
421 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
422 let (sql, arguments) = Query::select()
423 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
424 .from(OAuth2Sessions::Table)
425 .apply_filter(filter)
426 .build_sqlx(PostgresQueryBuilder);
427
428 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
429 .traced()
430 .fetch_one(&mut *self.conn)
431 .await?;
432
433 count
434 .try_into()
435 .map_err(DatabaseError::to_invalid_operation)
436 }
437
438 #[tracing::instrument(
439 name = "db.oauth2_session.record_batch_activity",
440 skip_all,
441 fields(
442 db.query.text,
443 ),
444 err,
445 )]
446 async fn record_batch_activity(
447 &mut self,
448 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
449 ) -> Result<(), Self::Error> {
450 let mut ids = Vec::with_capacity(activity.len());
451 let mut last_activities = Vec::with_capacity(activity.len());
452 let mut ips = Vec::with_capacity(activity.len());
453
454 for (id, last_activity, ip) in activity {
455 ids.push(Uuid::from(id));
456 last_activities.push(last_activity);
457 ips.push(ip);
458 }
459
460 let res = sqlx::query!(
461 r#"
462 UPDATE oauth2_sessions
463 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
464 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
465 FROM (
466 SELECT *
467 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
468 AS t(oauth2_session_id, last_active_at, last_active_ip)
469 ) AS t
470 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
471 "#,
472 &ids,
473 &last_activities,
474 &ips as &[Option<IpAddr>],
475 )
476 .traced()
477 .execute(&mut *self.conn)
478 .await?;
479
480 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
481
482 Ok(())
483 }
484
485 #[tracing::instrument(
486 name = "db.oauth2_session.record_user_agent",
487 skip_all,
488 fields(
489 db.query.text,
490 %session.id,
491 %session.scope,
492 client.id = %session.client_id,
493 session.user_agent = %user_agent.raw,
494 ),
495 err,
496 )]
497 async fn record_user_agent(
498 &mut self,
499 mut session: Session,
500 user_agent: UserAgent,
501 ) -> Result<Session, Self::Error> {
502 let res = sqlx::query!(
503 r#"
504 UPDATE oauth2_sessions
505 SET user_agent = $2
506 WHERE oauth2_session_id = $1
507 "#,
508 Uuid::from(session.id),
509 &*user_agent,
510 )
511 .traced()
512 .execute(&mut *self.conn)
513 .await?;
514
515 session.user_agent = Some(user_agent);
516
517 DatabaseError::ensure_affected_rows(&res, 1)?;
518
519 Ok(session)
520 }
521}