mas_storage_pg/upstream_oauth2/
session.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11    UpstreamOAuthProvider,
12};
13use mas_storage::{Clock, upstream_oauth2::UpstreamOAuthSessionRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
20
21/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
22/// connection
23pub struct PgUpstreamOAuthSessionRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgUpstreamOAuthSessionRepository<'c> {
28    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
29    /// PostgreSQL connection
30    pub fn new(conn: &'c mut PgConnection) -> Self {
31        Self { conn }
32    }
33}
34
35struct SessionLookup {
36    upstream_oauth_authorization_session_id: Uuid,
37    upstream_oauth_provider_id: Uuid,
38    upstream_oauth_link_id: Option<Uuid>,
39    state: String,
40    code_challenge_verifier: Option<String>,
41    nonce: String,
42    id_token: Option<String>,
43    userinfo: Option<serde_json::Value>,
44    created_at: DateTime<Utc>,
45    completed_at: Option<DateTime<Utc>>,
46    consumed_at: Option<DateTime<Utc>>,
47    extra_callback_parameters: Option<serde_json::Value>,
48}
49
50impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
51    type Error = DatabaseInconsistencyError;
52
53    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
54        let id = value.upstream_oauth_authorization_session_id.into();
55        let state = match (
56            value.upstream_oauth_link_id,
57            value.id_token,
58            value.extra_callback_parameters,
59            value.userinfo,
60            value.completed_at,
61            value.consumed_at,
62        ) {
63            (None, None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
64            (
65                Some(link_id),
66                id_token,
67                extra_callback_parameters,
68                userinfo,
69                Some(completed_at),
70                None,
71            ) => UpstreamOAuthAuthorizationSessionState::Completed {
72                completed_at,
73                link_id: link_id.into(),
74                id_token,
75                extra_callback_parameters,
76                userinfo,
77            },
78            (
79                Some(link_id),
80                id_token,
81                extra_callback_parameters,
82                userinfo,
83                Some(completed_at),
84                Some(consumed_at),
85            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
86                completed_at,
87                link_id: link_id.into(),
88                id_token,
89                extra_callback_parameters,
90                userinfo,
91                consumed_at,
92            },
93            _ => {
94                return Err(DatabaseInconsistencyError::on(
95                    "upstream_oauth_authorization_sessions",
96                )
97                .row(id));
98            }
99        };
100
101        Ok(Self {
102            id,
103            provider_id: value.upstream_oauth_provider_id.into(),
104            state_str: value.state,
105            nonce: value.nonce,
106            code_challenge_verifier: value.code_challenge_verifier,
107            created_at: value.created_at,
108            state,
109        })
110    }
111}
112
113#[async_trait]
114impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
115    type Error = DatabaseError;
116
117    #[tracing::instrument(
118        name = "db.upstream_oauth_authorization_session.lookup",
119        skip_all,
120        fields(
121            db.query.text,
122            upstream_oauth_provider.id = %id,
123        ),
124        err,
125    )]
126    async fn lookup(
127        &mut self,
128        id: Ulid,
129    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
130        let res = sqlx::query_as!(
131            SessionLookup,
132            r#"
133                SELECT
134                    upstream_oauth_authorization_session_id,
135                    upstream_oauth_provider_id,
136                    upstream_oauth_link_id,
137                    state,
138                    code_challenge_verifier,
139                    nonce,
140                    id_token,
141                    extra_callback_parameters,
142                    userinfo,
143                    created_at,
144                    completed_at,
145                    consumed_at
146                FROM upstream_oauth_authorization_sessions
147                WHERE upstream_oauth_authorization_session_id = $1
148            "#,
149            Uuid::from(id),
150        )
151        .traced()
152        .fetch_optional(&mut *self.conn)
153        .await?;
154
155        let Some(res) = res else { return Ok(None) };
156
157        Ok(Some(res.try_into()?))
158    }
159
160    #[tracing::instrument(
161        name = "db.upstream_oauth_authorization_session.add",
162        skip_all,
163        fields(
164            db.query.text,
165            %upstream_oauth_provider.id,
166            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
167            %upstream_oauth_provider.client_id,
168            upstream_oauth_authorization_session.id,
169        ),
170        err,
171    )]
172    async fn add(
173        &mut self,
174        rng: &mut (dyn RngCore + Send),
175        clock: &dyn Clock,
176        upstream_oauth_provider: &UpstreamOAuthProvider,
177        state_str: String,
178        code_challenge_verifier: Option<String>,
179        nonce: String,
180    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
181        let created_at = clock.now();
182        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
183        tracing::Span::current().record(
184            "upstream_oauth_authorization_session.id",
185            tracing::field::display(id),
186        );
187
188        sqlx::query!(
189            r#"
190                INSERT INTO upstream_oauth_authorization_sessions (
191                    upstream_oauth_authorization_session_id,
192                    upstream_oauth_provider_id,
193                    state,
194                    code_challenge_verifier,
195                    nonce,
196                    created_at,
197                    completed_at,
198                    consumed_at,
199                    id_token,
200                    userinfo
201                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
202            "#,
203            Uuid::from(id),
204            Uuid::from(upstream_oauth_provider.id),
205            &state_str,
206            code_challenge_verifier.as_deref(),
207            nonce,
208            created_at,
209        )
210        .traced()
211        .execute(&mut *self.conn)
212        .await?;
213
214        Ok(UpstreamOAuthAuthorizationSession {
215            id,
216            state: UpstreamOAuthAuthorizationSessionState::default(),
217            provider_id: upstream_oauth_provider.id,
218            state_str,
219            code_challenge_verifier,
220            nonce,
221            created_at,
222        })
223    }
224
225    #[tracing::instrument(
226        name = "db.upstream_oauth_authorization_session.complete_with_link",
227        skip_all,
228        fields(
229            db.query.text,
230            %upstream_oauth_authorization_session.id,
231            %upstream_oauth_link.id,
232        ),
233        err,
234    )]
235    async fn complete_with_link(
236        &mut self,
237        clock: &dyn Clock,
238        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
239        upstream_oauth_link: &UpstreamOAuthLink,
240        id_token: Option<String>,
241        extra_callback_parameters: Option<serde_json::Value>,
242        userinfo: Option<serde_json::Value>,
243    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
244        let completed_at = clock.now();
245
246        sqlx::query!(
247            r#"
248                UPDATE upstream_oauth_authorization_sessions
249                SET upstream_oauth_link_id = $1,
250                    completed_at = $2,
251                    id_token = $3,
252                    extra_callback_parameters = $4,
253                    userinfo = $5
254                WHERE upstream_oauth_authorization_session_id = $6
255            "#,
256            Uuid::from(upstream_oauth_link.id),
257            completed_at,
258            id_token,
259            extra_callback_parameters,
260            userinfo,
261            Uuid::from(upstream_oauth_authorization_session.id),
262        )
263        .traced()
264        .execute(&mut *self.conn)
265        .await?;
266
267        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
268            .complete(
269                completed_at,
270                upstream_oauth_link,
271                id_token,
272                extra_callback_parameters,
273                userinfo,
274            )
275            .map_err(DatabaseError::to_invalid_operation)?;
276
277        Ok(upstream_oauth_authorization_session)
278    }
279
280    /// Mark a session as consumed
281    #[tracing::instrument(
282        name = "db.upstream_oauth_authorization_session.consume",
283        skip_all,
284        fields(
285            db.query.text,
286            %upstream_oauth_authorization_session.id,
287        ),
288        err,
289    )]
290    async fn consume(
291        &mut self,
292        clock: &dyn Clock,
293        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
294    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
295        let consumed_at = clock.now();
296        sqlx::query!(
297            r#"
298                UPDATE upstream_oauth_authorization_sessions
299                SET consumed_at = $1
300                WHERE upstream_oauth_authorization_session_id = $2
301            "#,
302            consumed_at,
303            Uuid::from(upstream_oauth_authorization_session.id),
304        )
305        .traced()
306        .execute(&mut *self.conn)
307        .await?;
308
309        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
310            .consume(consumed_at)
311            .map_err(DatabaseError::to_invalid_operation)?;
312
313        Ok(upstream_oauth_authorization_session)
314    }
315}