1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11 UpstreamOAuthProvider,
12};
13use mas_storage::{Clock, upstream_oauth2::UpstreamOAuthSessionRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
20
21pub struct PgUpstreamOAuthSessionRepository<'c> {
24 conn: &'c mut PgConnection,
25}
26
27impl<'c> PgUpstreamOAuthSessionRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
/// Raw row of the `upstream_oauth_authorization_sessions` table, as selected
/// by the `lookup` query.
///
/// It is turned into an [`UpstreamOAuthAuthorizationSession`] through its
/// `TryFrom` implementation, which derives the session state from which of
/// the optional columns are populated.
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Must be `None` for a pending session and `Some` once it is completed.
    upstream_oauth_link_id: Option<Uuid>,
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: String,
    // Callback payload; must be `None` while the session is pending.
    id_token: Option<String>,
    // Callback payload; must be `None` while the session is pending.
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    // Set when the session is completed with a link.
    completed_at: Option<DateTime<Utc>>,
    // Set when a completed session is consumed.
    consumed_at: Option<DateTime<Utc>>,
    // Callback payload; must be `None` while the session is pending.
    extra_callback_parameters: Option<serde_json::Value>,
}
49
50impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
51 type Error = DatabaseInconsistencyError;
52
53 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
54 let id = value.upstream_oauth_authorization_session_id.into();
55 let state = match (
56 value.upstream_oauth_link_id,
57 value.id_token,
58 value.extra_callback_parameters,
59 value.userinfo,
60 value.completed_at,
61 value.consumed_at,
62 ) {
63 (None, None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
64 (
65 Some(link_id),
66 id_token,
67 extra_callback_parameters,
68 userinfo,
69 Some(completed_at),
70 None,
71 ) => UpstreamOAuthAuthorizationSessionState::Completed {
72 completed_at,
73 link_id: link_id.into(),
74 id_token,
75 extra_callback_parameters,
76 userinfo,
77 },
78 (
79 Some(link_id),
80 id_token,
81 extra_callback_parameters,
82 userinfo,
83 Some(completed_at),
84 Some(consumed_at),
85 ) => UpstreamOAuthAuthorizationSessionState::Consumed {
86 completed_at,
87 link_id: link_id.into(),
88 id_token,
89 extra_callback_parameters,
90 userinfo,
91 consumed_at,
92 },
93 _ => {
94 return Err(DatabaseInconsistencyError::on(
95 "upstream_oauth_authorization_sessions",
96 )
97 .row(id));
98 }
99 };
100
101 Ok(Self {
102 id,
103 provider_id: value.upstream_oauth_provider_id.into(),
104 state_str: value.state,
105 nonce: value.nonce,
106 code_challenge_verifier: value.code_challenge_verifier,
107 created_at: value.created_at,
108 state,
109 })
110 }
111}
112
113#[async_trait]
114impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
115 type Error = DatabaseError;
116
117 #[tracing::instrument(
118 name = "db.upstream_oauth_authorization_session.lookup",
119 skip_all,
120 fields(
121 db.query.text,
122 upstream_oauth_provider.id = %id,
123 ),
124 err,
125 )]
126 async fn lookup(
127 &mut self,
128 id: Ulid,
129 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
130 let res = sqlx::query_as!(
131 SessionLookup,
132 r#"
133 SELECT
134 upstream_oauth_authorization_session_id,
135 upstream_oauth_provider_id,
136 upstream_oauth_link_id,
137 state,
138 code_challenge_verifier,
139 nonce,
140 id_token,
141 extra_callback_parameters,
142 userinfo,
143 created_at,
144 completed_at,
145 consumed_at
146 FROM upstream_oauth_authorization_sessions
147 WHERE upstream_oauth_authorization_session_id = $1
148 "#,
149 Uuid::from(id),
150 )
151 .traced()
152 .fetch_optional(&mut *self.conn)
153 .await?;
154
155 let Some(res) = res else { return Ok(None) };
156
157 Ok(Some(res.try_into()?))
158 }
159
160 #[tracing::instrument(
161 name = "db.upstream_oauth_authorization_session.add",
162 skip_all,
163 fields(
164 db.query.text,
165 %upstream_oauth_provider.id,
166 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
167 %upstream_oauth_provider.client_id,
168 upstream_oauth_authorization_session.id,
169 ),
170 err,
171 )]
172 async fn add(
173 &mut self,
174 rng: &mut (dyn RngCore + Send),
175 clock: &dyn Clock,
176 upstream_oauth_provider: &UpstreamOAuthProvider,
177 state_str: String,
178 code_challenge_verifier: Option<String>,
179 nonce: String,
180 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
181 let created_at = clock.now();
182 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
183 tracing::Span::current().record(
184 "upstream_oauth_authorization_session.id",
185 tracing::field::display(id),
186 );
187
188 sqlx::query!(
189 r#"
190 INSERT INTO upstream_oauth_authorization_sessions (
191 upstream_oauth_authorization_session_id,
192 upstream_oauth_provider_id,
193 state,
194 code_challenge_verifier,
195 nonce,
196 created_at,
197 completed_at,
198 consumed_at,
199 id_token,
200 userinfo
201 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
202 "#,
203 Uuid::from(id),
204 Uuid::from(upstream_oauth_provider.id),
205 &state_str,
206 code_challenge_verifier.as_deref(),
207 nonce,
208 created_at,
209 )
210 .traced()
211 .execute(&mut *self.conn)
212 .await?;
213
214 Ok(UpstreamOAuthAuthorizationSession {
215 id,
216 state: UpstreamOAuthAuthorizationSessionState::default(),
217 provider_id: upstream_oauth_provider.id,
218 state_str,
219 code_challenge_verifier,
220 nonce,
221 created_at,
222 })
223 }
224
225 #[tracing::instrument(
226 name = "db.upstream_oauth_authorization_session.complete_with_link",
227 skip_all,
228 fields(
229 db.query.text,
230 %upstream_oauth_authorization_session.id,
231 %upstream_oauth_link.id,
232 ),
233 err,
234 )]
235 async fn complete_with_link(
236 &mut self,
237 clock: &dyn Clock,
238 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
239 upstream_oauth_link: &UpstreamOAuthLink,
240 id_token: Option<String>,
241 extra_callback_parameters: Option<serde_json::Value>,
242 userinfo: Option<serde_json::Value>,
243 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
244 let completed_at = clock.now();
245
246 sqlx::query!(
247 r#"
248 UPDATE upstream_oauth_authorization_sessions
249 SET upstream_oauth_link_id = $1,
250 completed_at = $2,
251 id_token = $3,
252 extra_callback_parameters = $4,
253 userinfo = $5
254 WHERE upstream_oauth_authorization_session_id = $6
255 "#,
256 Uuid::from(upstream_oauth_link.id),
257 completed_at,
258 id_token,
259 extra_callback_parameters,
260 userinfo,
261 Uuid::from(upstream_oauth_authorization_session.id),
262 )
263 .traced()
264 .execute(&mut *self.conn)
265 .await?;
266
267 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
268 .complete(
269 completed_at,
270 upstream_oauth_link,
271 id_token,
272 extra_callback_parameters,
273 userinfo,
274 )
275 .map_err(DatabaseError::to_invalid_operation)?;
276
277 Ok(upstream_oauth_authorization_session)
278 }
279
280 #[tracing::instrument(
282 name = "db.upstream_oauth_authorization_session.consume",
283 skip_all,
284 fields(
285 db.query.text,
286 %upstream_oauth_authorization_session.id,
287 ),
288 err,
289 )]
290 async fn consume(
291 &mut self,
292 clock: &dyn Clock,
293 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
294 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
295 let consumed_at = clock.now();
296 sqlx::query!(
297 r#"
298 UPDATE upstream_oauth_authorization_sessions
299 SET consumed_at = $1
300 WHERE upstream_oauth_authorization_session_id = $2
301 "#,
302 consumed_at,
303 Uuid::from(upstream_oauth_authorization_session.id),
304 )
305 .traced()
306 .execute(&mut *self.conn)
307 .await?;
308
309 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
310 .consume(consumed_at)
311 .map_err(DatabaseError::to_invalid_operation)?;
312
313 Ok(upstream_oauth_authorization_session)
314 }
315}