mas_handlers/activity_tracker/
worker.rs1use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{RepositoryAccess, user::BrowserSessionRepository};
11use opentelemetry::{
12 Key, KeyValue,
13 metrics::{Counter, Histogram},
14};
15use sqlx::PgPool;
16use tokio_util::sync::CancellationToken;
17use ulid::Ulid;
18
19use crate::{
20 METER,
21 activity_tracker::{Message, SessionKind},
22};
23
/// Upper bound on the number of distinct `(session kind, session id)` entries
/// buffered between two flushes; also used as the buffer's initial capacity.
///
/// When a record for a new session arrives and this many entries are already
/// pending, the worker flushes early; if that flush does not free any space,
/// the record is dropped. Declared `const` rather than `static`: it is a plain
/// compile-time value and nothing needs a fixed address for it.
const MAX_PENDING_RECORDS: usize = 10_000;

31const TYPE: Key = Key::from_static_str("type");
32const SESSION_KIND: Key = Key::from_static_str("session_kind");
33const RESULT: Key = Key::from_static_str("result");
34
35#[derive(Clone, Copy, Debug)]
36struct ActivityRecord {
37 #[allow(dead_code)]
39 start_time: DateTime<Utc>,
40 end_time: DateTime<Utc>,
41 ip: Option<IpAddr>,
42}
43
44pub struct Worker {
46 pool: PgPool,
47 pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
48 message_counter: Counter<u64>,
49 flush_time_histogram: Histogram<u64>,
50}
51
52impl Worker {
53 pub(crate) fn new(pool: PgPool) -> Self {
54 let message_counter = METER
55 .u64_counter("mas.activity_tracker.messages")
56 .with_description("The number of messages received by the activity tracker")
57 .with_unit("{messages}")
58 .build();
59
60 for kind in &[
62 SessionKind::OAuth2,
63 SessionKind::Compat,
64 SessionKind::Browser,
65 ] {
66 message_counter.add(
67 0,
68 &[
69 KeyValue::new(TYPE, "record"),
70 KeyValue::new(SESSION_KIND, kind.as_str()),
71 ],
72 );
73 }
74 message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
75 message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
76
77 let flush_time_histogram = METER
78 .u64_histogram("mas.activity_tracker.flush_time")
79 .with_description("The time it took to flush the activity tracker")
80 .with_unit("ms")
81 .build();
82
83 Self {
84 pool,
85 pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
86 message_counter,
87 flush_time_histogram,
88 }
89 }
90
91 pub(super) async fn run(
92 mut self,
93 mut receiver: tokio::sync::mpsc::Receiver<Message>,
94 cancellation_token: CancellationToken,
95 ) {
96 let _guard = cancellation_token.clone().drop_guard();
99
100 loop {
101 let message = tokio::select! {
102 () = cancellation_token.cancelled(), if !receiver.is_closed() => {
105 receiver.close();
108 tracing::debug!("Shutting down activity tracker");
109 continue;
110 },
111
112 message = receiver.recv() => {
113 let Some(message) = message else { break };
115 message
116 }
117 };
118
119 match message {
120 Message::Record {
121 kind,
122 id,
123 date_time,
124 ip,
125 } => {
126 if self.pending_records.len() >= MAX_PENDING_RECORDS {
127 tracing::warn!("Too many pending activity records, flushing");
128 self.flush().await;
129 }
130
131 if self.pending_records.len() >= MAX_PENDING_RECORDS {
132 tracing::error!(
133 kind = kind.as_str(),
134 %id,
135 %date_time,
136 "Still too many pending activity records, dropping"
137 );
138 continue;
139 }
140
141 self.message_counter.add(
142 1,
143 &[
144 KeyValue::new(TYPE, "record"),
145 KeyValue::new(SESSION_KIND, kind.as_str()),
146 ],
147 );
148
149 let record =
150 self.pending_records
151 .entry((kind, id))
152 .or_insert_with(|| ActivityRecord {
153 start_time: date_time,
154 end_time: date_time,
155 ip,
156 });
157
158 record.end_time = date_time.max(record.end_time);
159 }
160
161 Message::Flush(tx) => {
162 self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
163
164 self.flush().await;
165 let _ = tx.send(());
166 }
167 }
168 }
169
170 self.flush().await;
172 }
173
174 async fn flush(&mut self) {
176 if self.pending_records.is_empty() {
178 return;
179 }
180
181 let start = std::time::Instant::now();
182 let res = self.try_flush().await;
183
184 let duration = start.elapsed();
186 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
187
188 match res {
189 Ok(()) => {
190 self.flush_time_histogram
191 .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
192 }
193 Err(e) => {
194 self.flush_time_histogram
195 .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
196 tracing::error!("Failed to flush activity tracker: {}", e);
197 }
198 }
199 }
200
201 #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
203 async fn try_flush(&mut self) -> Result<(), anyhow::Error> {
204 let pending_records = &self.pending_records;
205
206 let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
207 .await?
208 .boxed();
209
210 let mut browser_sessions = Vec::new();
211 let mut oauth2_sessions = Vec::new();
212 let mut compat_sessions = Vec::new();
213
214 for ((kind, id), record) in pending_records {
215 match kind {
216 SessionKind::Browser => {
217 browser_sessions.push((*id, record.end_time, record.ip));
218 }
219 SessionKind::OAuth2 => {
220 oauth2_sessions.push((*id, record.end_time, record.ip));
221 }
222 SessionKind::Compat => {
223 compat_sessions.push((*id, record.end_time, record.ip));
224 }
225 }
226 }
227
228 tracing::info!(
229 "Flushing {} activity records to the database",
230 pending_records.len()
231 );
232
233 repo.browser_session()
234 .record_batch_activity(browser_sessions)
235 .await?;
236 repo.oauth2_session()
237 .record_batch_activity(oauth2_sessions)
238 .await?;
239 repo.compat_session()
240 .record_batch_activity(compat_sessions)
241 .await?;
242
243 repo.save().await?;
244 self.pending_records.clear();
245
246 Ok(())
247 }
248}