mas_handlers/views/register/cookie.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6// TODO: move that to a standalone cookie manager
7
8use std::collections::BTreeSet;
9
10use chrono::{DateTime, Duration, Utc};
11use mas_axum_utils::cookies::CookieJar;
12use mas_data_model::{Clock, UserRegistration};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15use ulid::Ulid;
16
/// Name of the cookie holding the IDs of the user registrations tracked by
/// this cookie jar.
///
/// This is a `const` rather than a `static`: it is a plain string that needs no
/// fixed memory address.
const COOKIE_NAME: &str = "user-registration-sessions";
19
20/// Sessions expire after an hour
21static SESSION_MAX_TIME: Duration = Duration::hours(1);
22
23/// The content of the cookie, which stores a list of user registration IDs
24#[derive(Serialize, Deserialize, Default, Debug)]
25pub struct UserRegistrationSessions(BTreeSet<Ulid>);
26
27#[derive(Debug, Error, PartialEq, Eq)]
28#[error("user registration session not found")]
29pub struct UserRegistrationSessionNotFound;
30
31impl UserRegistrationSessions {
32    /// Load the user registration sessions cookie
33    pub fn load(cookie_jar: &CookieJar) -> Self {
34        match cookie_jar.load(COOKIE_NAME) {
35            Ok(Some(sessions)) => sessions,
36            Ok(None) => Self::default(),
37            Err(e) => {
38                tracing::warn!(
39                    error = &e as &dyn std::error::Error,
40                    "Invalid upstream sessions cookie"
41                );
42                Self::default()
43            }
44        }
45    }
46
47    /// Returns true if the cookie is empty
48    pub fn is_empty(&self) -> bool {
49        self.0.is_empty()
50    }
51
52    /// Save the user registration sessions to the cookie jar
53    pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar
54    where
55        C: Clock,
56    {
57        let this = self.expire(clock.now());
58
59        if this.is_empty() {
60            cookie_jar.remove(COOKIE_NAME)
61        } else {
62            cookie_jar.save(COOKIE_NAME, &this, false)
63        }
64    }
65
66    fn expire(mut self, now: DateTime<Utc>) -> Self {
67        self.0.retain(|id| {
68            let Ok(ts) = id.timestamp_ms().try_into() else {
69                return false;
70            };
71            let Some(when) = DateTime::from_timestamp_millis(ts) else {
72                return false;
73            };
74            now - when < SESSION_MAX_TIME
75        });
76
77        self
78    }
79
80    /// Add a new session, for a provider and a random state
81    pub fn add(mut self, user_registration: &UserRegistration) -> Self {
82        self.0.insert(user_registration.id);
83        self
84    }
85
86    /// Check if the session is in the list
87    pub fn contains(&self, user_registration: &UserRegistration) -> bool {
88        self.0.contains(&user_registration.id)
89    }
90
91    /// Mark a link as consumed to avoid replay
92    pub fn consume_session(
93        mut self,
94        user_registration: &UserRegistration,
95    ) -> Result<Self, UserRegistrationSessionNotFound> {
96        if !self.0.remove(&user_registration.id) {
97            return Err(UserRegistrationSessionNotFound);
98        }
99
100        Ok(self)
101    }
102}