mas_handlers/upstream_oauth2/
cookie.rs1use chrono::{DateTime, Duration, Utc};
10use mas_axum_utils::cookies::CookieJar;
11use mas_router::PostAuthAction;
12use mas_storage::Clock;
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15use ulid::Ulid;
16
17static COOKIE_NAME: &str = "upstream-oauth2-sessions";
19
20static SESSION_MAX_TIME: Duration = Duration::microseconds(10 * 60 * 1000 * 1000);
22
23#[derive(Serialize, Deserialize, Debug)]
24pub struct Payload {
25 session: Ulid,
26 provider: Ulid,
27 state: String,
28 link: Option<Ulid>,
29 post_auth_action: Option<PostAuthAction>,
30}
31
32impl Payload {
33 fn expired(&self, now: DateTime<Utc>) -> bool {
34 let Ok(ts) = self.session.timestamp_ms().try_into() else {
35 return true;
36 };
37 let Some(when) = DateTime::from_timestamp_millis(ts) else {
38 return true;
39 };
40 now - when > SESSION_MAX_TIME
41 }
42}
43
44#[derive(Serialize, Deserialize, Default, Debug)]
45pub struct UpstreamSessions(Vec<Payload>);
46
47#[derive(Debug, Error, PartialEq, Eq)]
48#[error("upstream session not found")]
49pub struct UpstreamSessionNotFound;
50
51impl UpstreamSessions {
52 pub fn load(cookie_jar: &CookieJar) -> Self {
54 match cookie_jar.load(COOKIE_NAME) {
55 Ok(Some(sessions)) => sessions,
56 Ok(None) => Self::default(),
57 Err(e) => {
58 tracing::warn!("Invalid upstream sessions cookie: {}", e);
59 Self::default()
60 }
61 }
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.0.is_empty()
67 }
68
69 pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
71 where
72 C: Clock,
73 {
74 let this = self.expire(clock.now());
75 cookie_jar.save(COOKIE_NAME, &this, false)
76 }
77
78 fn expire(mut self, now: DateTime<Utc>) -> Self {
79 self.0.retain(|p| !p.expired(now));
80 self
81 }
82
83 pub fn add(
85 mut self,
86 session: Ulid,
87 provider: Ulid,
88 state: String,
89 post_auth_action: Option<PostAuthAction>,
90 ) -> Self {
91 self.0.push(Payload {
92 session,
93 provider,
94 state,
95 link: None,
96 post_auth_action,
97 });
98 self
99 }
100
101 pub fn find_session(
103 &self,
104 provider: Ulid,
105 state: &str,
106 ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
107 self.0
108 .iter()
109 .find(|p| p.provider == provider && p.state == state && p.link.is_none())
110 .map(|p| (p.session, p.post_auth_action.as_ref()))
111 .ok_or(UpstreamSessionNotFound)
112 }
113
114 pub fn add_link_to_session(
116 mut self,
117 session: Ulid,
118 link: Ulid,
119 ) -> Result<Self, UpstreamSessionNotFound> {
120 let payload = self
121 .0
122 .iter_mut()
123 .find(|p| p.session == session && p.link.is_none())
124 .ok_or(UpstreamSessionNotFound)?;
125
126 payload.link = Some(link);
127 Ok(self)
128 }
129
130 pub fn lookup_link(
132 &self,
133 link_id: Ulid,
134 ) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
135 self.0
136 .iter()
137 .filter(|p| p.link == Some(link_id))
138 .reduce(|a, b| if a.session > b.session { a } else { b })
140 .map(|p| (p.session, p.post_auth_action.as_ref()))
141 .ok_or(UpstreamSessionNotFound)
142 }
143
144 pub fn consume_link(mut self, link_id: Ulid) -> Result<Self, UpstreamSessionNotFound> {
146 let pos = self
147 .0
148 .iter()
149 .position(|p| p.link == Some(link_id))
150 .ok_or(UpstreamSessionNotFound)?;
151
152 self.0.remove(pos);
153
154 Ok(self)
155 }
156}
157
#[cfg(test)]
mod tests {
    use chrono::TimeZone;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;

    use super::*;

    /// Looks a session up by provider and state, keeping only its ID.
    fn session_id(
        sessions: &UpstreamSessions,
        provider: Ulid,
        state: &str,
    ) -> Result<Ulid, UpstreamSessionNotFound> {
        sessions.find_session(provider, state).map(|(id, _)| id)
    }

    /// Walks two sessions through creation, lookup, expiry, linking and link
    /// consumption.
    #[test]
    fn test_session_cookie() {
        let start = Utc.with_ymd_and_hms(2018, 1, 18, 1, 30, 22).unwrap();
        let mut rng = ChaChaRng::seed_from_u64(42);
        let mut ulid_at =
            |at: DateTime<Utc>| Ulid::from_datetime_with_source(at.into(), &mut rng);

        let provider_a = ulid_at(start);
        let provider_b = ulid_at(start);

        // First session, created at the start, for provider A.
        let first_state = "first-state";
        let first_session = ulid_at(start);
        let sessions = UpstreamSessions::default().add(
            first_session,
            provider_a,
            first_state.into(),
            None,
        );

        // Second session, created five minutes later, for provider B.
        let five_minutes_in = start + Duration::microseconds(5 * 60 * 1000 * 1000);
        let second_state = "second-state";
        let second_session = ulid_at(five_minutes_in);
        let sessions = sessions.add(second_session, provider_b, second_state.into(), None);

        // Five minutes in, both sessions are still alive, and each one is only
        // found with its own provider/state pair.
        let sessions = sessions.expire(five_minutes_in);
        assert_eq!(
            session_id(&sessions, provider_a, first_state),
            Ok(first_session)
        );
        assert_eq!(
            session_id(&sessions, provider_b, second_state),
            Ok(second_session)
        );
        assert_eq!(
            session_id(&sessions, provider_b, first_state),
            Err(UpstreamSessionNotFound)
        );
        assert_eq!(
            session_id(&sessions, provider_a, second_state),
            Err(UpstreamSessionNotFound)
        );

        // Six more minutes later the first session is past the limit, the
        // second one is not.
        let eleven_minutes_in = five_minutes_in + Duration::microseconds(6 * 60 * 1000 * 1000);
        let sessions = sessions.expire(eleven_minutes_in);
        assert_eq!(
            session_id(&sessions, provider_a, first_state),
            Err(UpstreamSessionNotFound)
        );
        assert_eq!(
            session_id(&sessions, provider_b, second_state),
            Ok(second_session)
        );

        // Attach a link to the second session.
        let second_link = ulid_at(eleven_minutes_in);
        let sessions = sessions
            .add_link_to_session(second_session, second_link)
            .unwrap();

        // A linked session is no longer found by provider and state...
        assert_eq!(
            session_id(&sessions, provider_b, second_state),
            Err(UpstreamSessionNotFound)
        );

        // ...but it is found through its link.
        assert_eq!(
            sessions.lookup_link(second_link).map(|(id, _)| id),
            Ok(second_session)
        );

        // The link can be consumed once, and only once.
        let sessions = sessions.consume_link(second_link).unwrap();
        assert!(sessions.consume_link(second_link).is_err());
    }
}