mas_handlers/views/register/
cookie.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
// TODO: move this cookie handling into a standalone cookie manager
7
8use std::collections::BTreeSet;
9
10use chrono::{DateTime, Duration, Utc};
11use mas_axum_utils::cookies::CookieJar;
12use mas_data_model::UserRegistration;
13use mas_storage::Clock;
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use ulid::Ulid;
17
/// Name of the cookie holding the user registration sessions.
const COOKIE_NAME: &str = "user-registration-sessions";
20
21/// Sessions expire after an hour
22static SESSION_MAX_TIME: Duration = Duration::hours(1);
23
24/// The content of the cookie, which stores a list of user registration IDs
25#[derive(Serialize, Deserialize, Default, Debug)]
26pub struct UserRegistrationSessions(BTreeSet<Ulid>);
27
28#[derive(Debug, Error, PartialEq, Eq)]
29#[error("user registration session not found")]
30pub struct UserRegistrationSessionNotFound;
31
32impl UserRegistrationSessions {
33    /// Load the user registration sessions cookie
34    pub fn load(cookie_jar: &CookieJar) -> Self {
35        match cookie_jar.load(COOKIE_NAME) {
36            Ok(Some(sessions)) => sessions,
37            Ok(None) => Self::default(),
38            Err(e) => {
39                tracing::warn!(
40                    error = &e as &dyn std::error::Error,
41                    "Invalid upstream sessions cookie"
42                );
43                Self::default()
44            }
45        }
46    }
47
48    /// Returns true if the cookie is empty
49    pub fn is_empty(&self) -> bool {
50        self.0.is_empty()
51    }
52
53    /// Save the user registration sessions to the cookie jar
54    pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
55    where
56        C: Clock,
57    {
58        let this = self.expire(clock.now());
59
60        if this.is_empty() {
61            cookie_jar.remove(COOKIE_NAME)
62        } else {
63            cookie_jar.save(COOKIE_NAME, &this, false)
64        }
65    }
66
67    fn expire(mut self, now: DateTime<Utc>) -> Self {
68        self.0.retain(|id| {
69            let Ok(ts) = id.timestamp_ms().try_into() else {
70                return false;
71            };
72            let Some(when) = DateTime::from_timestamp_millis(ts) else {
73                return false;
74            };
75            now - when < SESSION_MAX_TIME
76        });
77
78        self
79    }
80
81    /// Add a new session, for a provider and a random state
82    pub fn add(mut self, user_registration: &UserRegistration) -> Self {
83        self.0.insert(user_registration.id);
84        self
85    }
86
87    /// Check if the session is in the list
88    pub fn contains(&self, user_registration: &UserRegistration) -> bool {
89        self.0.contains(&user_registration.id)
90    }
91
92    /// Mark a link as consumed to avoid replay
93    pub fn consume_session(
94        mut self,
95        user_registration: &UserRegistration,
96    ) -> Result<Self, UserRegistrationSessionNotFound> {
97        if !self.0.remove(&user_registration.id) {
98            return Err(UserRegistrationSessionNotFound);
99        }
100
101        Ok(self)
102    }
103}