mas_handlers/upstream_oauth2/
cookie.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7// TODO: move that to a standalone cookie manager
8
9use chrono::{DateTime, Duration, Utc};
10use mas_axum_utils::cookies::CookieJar;
11use mas_router::PostAuthAction;
12use mas_storage::Clock;
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15use ulid::Ulid;
16
17/// Name of the cookie
18static COOKIE_NAME: &str = "upstream-oauth2-sessions";
19
20/// Sessions expire after 10 minutes
21static SESSION_MAX_TIME: Duration = Duration::microseconds(10 * 60 * 1000 * 1000);
22
23#[derive(Serialize, Deserialize, Debug)]
24pub struct Payload {
25    session: Ulid,
26    provider: Ulid,
27    state: String,
28    link: Option<Ulid>,
29    post_auth_action: Option<PostAuthAction>,
30}
31
32impl Payload {
33    fn expired(&self, now: DateTime<Utc>) -> bool {
34        let Ok(ts) = self.session.timestamp_ms().try_into() else {
35            return true;
36        };
37        let Some(when) = DateTime::from_timestamp_millis(ts) else {
38            return true;
39        };
40        now - when > SESSION_MAX_TIME
41    }
42}
43
44#[derive(Serialize, Deserialize, Default, Debug)]
45pub struct UpstreamSessions(Vec<Payload>);
46
47#[derive(Debug, Error, PartialEq, Eq)]
48#[error("upstream session not found")]
49pub struct UpstreamSessionNotFound;
50
51impl UpstreamSessions {
52    /// Load the upstreams sessions cookie
53    pub fn load(cookie_jar: &CookieJar) -> Self {
54        match cookie_jar.load(COOKIE_NAME) {
55            Ok(Some(sessions)) => sessions,
56            Ok(None) => Self::default(),
57            Err(e) => {
58                tracing::warn!("Invalid upstream sessions cookie: {}", e);
59                Self::default()
60            }
61        }
62    }
63
64    /// Returns true if the cookie is empty
65    pub fn is_empty(&self) -> bool {
66        self.0.is_empty()
67    }
68
69    /// Save the upstreams sessions to the cookie jar
70    pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
71    where
72        C: Clock,
73    {
74        let this = self.expire(clock.now());
75        cookie_jar.save(COOKIE_NAME, &this, false)
76    }
77
78    fn expire(mut self, now: DateTime<Utc>) -> Self {
79        self.0.retain(|p| !p.expired(now));
80        self
81    }
82
83    /// Add a new session, for a provider and a random state
84    pub fn add(
85        mut self,
86        session: Ulid,
87        provider: Ulid,
88        state: String,
89        post_auth_action: Option<PostAuthAction>,
90    ) -> Self {
91        self.0.push(Payload {
92            session,
93            provider,
94            state,
95            link: None,
96            post_auth_action,
97        });
98        self
99    }
100
101    // Find a session ID from the provider and the state
102    pub fn find_session(
103        &self,
104        provider: Ulid,
105        state: &str,
106    ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
107        self.0
108            .iter()
109            .find(|p| p.provider == provider && p.state == state && p.link.is_none())
110            .map(|p| (p.session, p.post_auth_action.as_ref()))
111            .ok_or(UpstreamSessionNotFound)
112    }
113
114    /// Save the link generated by a session
115    pub fn add_link_to_session(
116        mut self,
117        session: Ulid,
118        link: Ulid,
119    ) -> Result<Self, UpstreamSessionNotFound> {
120        let payload = self
121            .0
122            .iter_mut()
123            .find(|p| p.session == session && p.link.is_none())
124            .ok_or(UpstreamSessionNotFound)?;
125
126        payload.link = Some(link);
127        Ok(self)
128    }
129
130    /// Find a session from its link
131    pub fn lookup_link(
132        &self,
133        link_id: Ulid,
134    ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
135        self.0
136            .iter()
137            .filter(|p| p.link == Some(link_id))
138            // Find the session with the highest ID, aka. the most recent one
139            .reduce(|a, b| if a.session > b.session { a } else { b })
140            .map(|p| (p.session, p.post_auth_action.as_ref()))
141            .ok_or(UpstreamSessionNotFound)
142    }
143
144    /// Mark a link as consumed to avoid replay
145    pub fn consume_link(mut self, link_id: Ulid) -> Result<Self, UpstreamSessionNotFound> {
146        let pos = self
147            .0
148            .iter()
149            .position(|p| p.link == Some(link_id))
150            .ok_or(UpstreamSessionNotFound)?;
151
152        self.0.remove(pos);
153
154        Ok(self)
155    }
156}
157
#[cfg(test)]
mod tests {
    use chrono::TimeZone;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;

    use super::*;

    /// Shorthand for a duration of `count` minutes
    fn minutes(count: i64) -> Duration {
        Duration::microseconds(count * 60 * 1000 * 1000)
    }

    /// Look up a session by provider and state, keeping only its ID
    fn session_id(
        jar: &UpstreamSessions,
        provider: Ulid,
        state: &str,
    ) -> Result<Ulid, UpstreamSessionNotFound> {
        jar.find_session(provider, state).map(|(id, _)| id)
    }

    /// Walks two sessions through lookup, expiry, linking and consumption
    #[test]
    fn test_session_cookie() {
        let start = chrono::Utc
            .with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
            .unwrap();
        let mut rng = ChaChaRng::seed_from_u64(42);

        let jar = UpstreamSessions::default();

        let provider_a = Ulid::from_datetime_with_source(start.into(), &mut rng);
        let provider_b = Ulid::from_datetime_with_source(start.into(), &mut rng);

        // First session, created at the start time
        let first_session = Ulid::from_datetime_with_source(start.into(), &mut rng);
        let first_state = "first-state";
        let jar = jar.add(first_session, provider_a, first_state.into(), None);

        // Second session, created five minutes later
        let later = start + minutes(5);
        let second_session = Ulid::from_datetime_with_source(later.into(), &mut rng);
        let second_state = "second-state";
        let jar = jar.add(second_session, provider_b, second_state.into(), None);

        // Nothing is old enough to expire yet: both sessions are found, but
        // only with the matching provider/state pair
        let jar = jar.expire(later);
        assert_eq!(
            session_id(&jar, provider_a, first_state).unwrap(),
            first_session,
        );
        assert_eq!(
            session_id(&jar, provider_b, second_state).unwrap(),
            second_session,
        );
        assert!(session_id(&jar, provider_b, first_state).is_err());
        assert!(session_id(&jar, provider_a, second_state).is_err());

        // Make the first session expire
        let even_later = later + minutes(6);
        let jar = jar.expire(even_later);
        assert!(session_id(&jar, provider_a, first_state).is_err());
        assert_eq!(
            session_id(&jar, provider_b, second_state).unwrap(),
            second_session,
        );

        // Associate a link with the second
        let second_link = Ulid::from_datetime_with_source(even_later.into(), &mut rng);
        let jar = jar
            .add_link_to_session(second_session, second_link)
            .unwrap();

        // Now the session can't be found with its state
        assert!(session_id(&jar, provider_b, second_state).is_err());

        // But it can be looked up by its link
        let (linked_session, _) = jar.lookup_link(second_link).unwrap();
        assert_eq!(linked_session, second_session);

        // And it can be consumed
        let jar = jar.consume_link(second_link).unwrap();

        // But only once
        assert!(jar.consume_link(second_link).is_err());
    }
}
225}