mas_handlers/activity_tracker/
worker.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{RepositoryAccess, user::BrowserSessionRepository};
11use opentelemetry::{
12    Key, KeyValue,
13    metrics::{Counter, Histogram},
14};
15use sqlx::PgPool;
16use tokio_util::sync::CancellationToken;
17use ulid::Ulid;
18
19use crate::{
20    METER,
21    activity_tracker::{Message, SessionKind},
22};
23
/// The maximum number of pending activity records before we flush them to the
/// database automatically.
///
/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less
/// than 100 bytes, so this should allocate around a megabyte of memory. The
/// pending-records map is pre-allocated with this capacity, and once it is
/// reached and a flush cannot free up room, new records may be dropped.
const MAX_PENDING_RECORDS: usize = 10_000;
30
31const TYPE: Key = Key::from_static_str("type");
32const SESSION_KIND: Key = Key::from_static_str("session_kind");
33const RESULT: Key = Key::from_static_str("result");
34
/// Activity accumulated in memory for a single session, until the next
/// successful flush writes it to the database.
#[derive(Clone, Copy, Debug)]
struct ActivityRecord {
    /// Start of the activity window for this session in the current batch.
    /// It is not written to the database.
    // XXX: We don't actually use the start time for now
    #[allow(dead_code)]
    start_time: DateTime<Utc>,
    /// Latest activity time seen for this session; this is the timestamp
    /// written to the database on flush.
    end_time: DateTime<Utc>,
    /// IP address recorded for this session, if any; written to the database
    /// alongside `end_time`.
    ip: Option<IpAddr>,
}
43
/// Handles writing activity records to the database.
///
/// Activity messages are accumulated in memory, one record per session, and
/// written to the database in batches.
pub struct Worker {
    /// Database pool used to open a repository for each flush.
    pool: PgPool,
    /// Records waiting to be written, keyed by session kind and session ID.
    pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
    /// Counts the messages handled by the worker, labelled by message type.
    message_counter: Counter<u64>,
    /// Records how long each flush took, in milliseconds.
    flush_time_histogram: Histogram<u64>,
}
51
52impl Worker {
53    pub(crate) fn new(pool: PgPool) -> Self {
54        let message_counter = METER
55            .u64_counter("mas.activity_tracker.messages")
56            .with_description("The number of messages received by the activity tracker")
57            .with_unit("{messages}")
58            .build();
59
60        // Record stuff on the counter so that the metrics are initialized
61        for kind in &[
62            SessionKind::OAuth2,
63            SessionKind::Compat,
64            SessionKind::Browser,
65        ] {
66            message_counter.add(
67                0,
68                &[
69                    KeyValue::new(TYPE, "record"),
70                    KeyValue::new(SESSION_KIND, kind.as_str()),
71                ],
72            );
73        }
74        message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
75        message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
76
77        let flush_time_histogram = METER
78            .u64_histogram("mas.activity_tracker.flush_time")
79            .with_description("The time it took to flush the activity tracker")
80            .with_unit("ms")
81            .build();
82
83        Self {
84            pool,
85            pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
86            message_counter,
87            flush_time_histogram,
88        }
89    }
90
91    pub(super) async fn run(
92        mut self,
93        mut receiver: tokio::sync::mpsc::Receiver<Message>,
94        cancellation_token: CancellationToken,
95    ) {
96        // This guard on the shutdown token is to ensure that if this task crashes for
97        // any reason, the server will shut down
98        let _guard = cancellation_token.clone().drop_guard();
99
100        loop {
101            let message = tokio::select! {
102                // Because we want the cancellation token to trigger only once,
103                // we looked whether we closed the channel or not
104                () = cancellation_token.cancelled(), if !receiver.is_closed() => {
105                    // We only close the channel, which will make it flush all
106                    // the pending messages
107                    receiver.close();
108                    tracing::debug!("Shutting down activity tracker");
109                    continue;
110                },
111
112                message = receiver.recv()  => {
113                    // We consumed all the messages, break out of the loop
114                    let Some(message) = message else { break };
115                    message
116                }
117            };
118
119            match message {
120                Message::Record {
121                    kind,
122                    id,
123                    date_time,
124                    ip,
125                } => {
126                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
127                        tracing::warn!("Too many pending activity records, flushing");
128                        self.flush().await;
129                    }
130
131                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
132                        tracing::error!(
133                            kind = kind.as_str(),
134                            %id,
135                            %date_time,
136                            "Still too many pending activity records, dropping"
137                        );
138                        continue;
139                    }
140
141                    self.message_counter.add(
142                        1,
143                        &[
144                            KeyValue::new(TYPE, "record"),
145                            KeyValue::new(SESSION_KIND, kind.as_str()),
146                        ],
147                    );
148
149                    let record =
150                        self.pending_records
151                            .entry((kind, id))
152                            .or_insert_with(|| ActivityRecord {
153                                start_time: date_time,
154                                end_time: date_time,
155                                ip,
156                            });
157
158                    record.end_time = date_time.max(record.end_time);
159                }
160
161                Message::Flush(tx) => {
162                    self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
163
164                    self.flush().await;
165                    let _ = tx.send(());
166                }
167            }
168        }
169
170        // Flush one last time
171        self.flush().await;
172    }
173
174    /// Flush the activity tracker.
175    async fn flush(&mut self) {
176        // Short path: if there are no pending records, we don't need to flush
177        if self.pending_records.is_empty() {
178            return;
179        }
180
181        let start = std::time::Instant::now();
182        let res = self.try_flush().await;
183
184        // Measure the time it took to flush the activity tracker
185        let duration = start.elapsed();
186        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
187
188        match res {
189            Ok(()) => {
190                self.flush_time_histogram
191                    .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
192            }
193            Err(e) => {
194                self.flush_time_histogram
195                    .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
196                tracing::error!("Failed to flush activity tracker: {}", e);
197            }
198        }
199    }
200
201    /// Fallible part of [`Self::flush`].
202    #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
203    async fn try_flush(&mut self) -> Result<(), anyhow::Error> {
204        let pending_records = &self.pending_records;
205
206        let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
207            .await?
208            .boxed();
209
210        let mut browser_sessions = Vec::new();
211        let mut oauth2_sessions = Vec::new();
212        let mut compat_sessions = Vec::new();
213
214        for ((kind, id), record) in pending_records {
215            match kind {
216                SessionKind::Browser => {
217                    browser_sessions.push((*id, record.end_time, record.ip));
218                }
219                SessionKind::OAuth2 => {
220                    oauth2_sessions.push((*id, record.end_time, record.ip));
221                }
222                SessionKind::Compat => {
223                    compat_sessions.push((*id, record.end_time, record.ip));
224                }
225            }
226        }
227
228        tracing::info!(
229            "Flushing {} activity records to the database",
230            pending_records.len()
231        );
232
233        repo.browser_session()
234            .record_batch_activity(browser_sessions)
235            .await?;
236        repo.oauth2_session()
237            .record_batch_activity(oauth2_sessions)
238            .await?;
239        repo.compat_session()
240            .record_batch_activity(compat_sessions)
241            .await?;
242
243        repo.save().await?;
244        self.pending_records.clear();
245
246        Ok(())
247    }
248}