1use std::collections::HashSet;
8
9use anyhow::Context;
10use async_trait::async_trait;
11use mas_data_model::Device;
12use mas_matrix::ProvisionRequest;
13use mas_storage::{
14 Pagination, RepositoryAccess,
15 compat::CompatSessionFilter,
16 oauth2::OAuth2SessionFilter,
17 queue::{
18 DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
19 SyncDevicesJob,
20 },
21 user::{UserEmailRepository, UserRepository},
22};
23use tracing::info;
24
25use crate::{
26 State,
27 new_queue::{JobContext, JobError, RunnableJob},
28};
29
30#[async_trait]
34impl RunnableJob for ProvisionUserJob {
35 #[tracing::instrument(
36 name = "job.provision_user"
37 fields(user.id = %self.user_id()),
38 skip_all,
39 )]
40 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
41 let matrix = state.matrix_connection();
42 let mut repo = state.repository().await.map_err(JobError::retry)?;
43 let mut rng = state.rng();
44 let clock = state.clock();
45
46 let user = repo
47 .user()
48 .lookup(self.user_id())
49 .await
50 .map_err(JobError::retry)?
51 .context("User not found")
52 .map_err(JobError::fail)?;
53
54 let emails = repo
55 .user_email()
56 .all(&user)
57 .await
58 .map_err(JobError::retry)?
59 .into_iter()
60 .map(|email| email.email)
61 .collect();
62 let mut request =
63 ProvisionRequest::new(user.username.clone(), user.sub.clone()).set_emails(emails);
64
65 if let Some(display_name) = self.display_name_to_set() {
66 request = request.set_displayname(display_name.to_owned());
67 }
68
69 let created = matrix
70 .provision_user(&request)
71 .await
72 .map_err(JobError::retry)?;
73
74 let mxid = matrix.mxid(&user.username);
75 if created {
76 info!(%user.id, %mxid, "User created");
77 } else {
78 info!(%user.id, %mxid, "User updated");
79 }
80
81 let sync_device_job = SyncDevicesJob::new(&user);
83 repo.queue_job()
84 .schedule_job(&mut rng, clock, sync_device_job)
85 .await
86 .map_err(JobError::retry)?;
87
88 repo.save().await.map_err(JobError::retry)?;
89
90 Ok(())
91 }
92}
93
94#[async_trait]
98impl RunnableJob for ProvisionDeviceJob {
99 #[tracing::instrument(
100 name = "job.provision_device"
101 fields(
102 user.id = %self.user_id(),
103 device.id = %self.device_id(),
104 ),
105 skip_all,
106 )]
107 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
108 let mut repo = state.repository().await.map_err(JobError::retry)?;
109 let mut rng = state.rng();
110 let clock = state.clock();
111
112 let user = repo
113 .user()
114 .lookup(self.user_id())
115 .await
116 .map_err(JobError::retry)?
117 .context("User not found")
118 .map_err(JobError::fail)?;
119
120 repo.queue_job()
122 .schedule_job(&mut rng, clock, SyncDevicesJob::new(&user))
123 .await
124 .map_err(JobError::retry)?;
125
126 Ok(())
127 }
128}
129
130#[async_trait]
134impl RunnableJob for DeleteDeviceJob {
135 #[tracing::instrument(
136 name = "job.delete_device"
137 fields(
138 user.id = %self.user_id(),
139 device.id = %self.device_id(),
140 ),
141 skip_all,
142 )]
143 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
144 let mut rng = state.rng();
145 let clock = state.clock();
146 let mut repo = state.repository().await.map_err(JobError::retry)?;
147
148 let user = repo
149 .user()
150 .lookup(self.user_id())
151 .await
152 .map_err(JobError::retry)?
153 .context("User not found")
154 .map_err(JobError::fail)?;
155
156 repo.queue_job()
158 .schedule_job(&mut rng, clock, SyncDevicesJob::new(&user))
159 .await
160 .map_err(JobError::retry)?;
161
162 Ok(())
163 }
164}
165
166#[async_trait]
168impl RunnableJob for SyncDevicesJob {
169 #[tracing::instrument(
170 name = "job.sync_devices",
171 fields(user.id = %self.user_id()),
172 skip_all,
173 )]
174 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
175 let matrix = state.matrix_connection();
176 let mut repo = state.repository().await.map_err(JobError::retry)?;
177
178 let user = repo
179 .user()
180 .lookup(self.user_id())
181 .await
182 .map_err(JobError::retry)?
183 .context("User not found")
184 .map_err(JobError::fail)?;
185
186 repo.user()
188 .acquire_lock_for_sync(&user)
189 .await
190 .map_err(JobError::retry)?;
191
192 let mut devices = HashSet::new();
193
194 let mut cursor = Pagination::first(5000);
196 loop {
197 let page = repo
198 .compat_session()
199 .list(
200 CompatSessionFilter::new().for_user(&user).active_only(),
201 cursor,
202 )
203 .await
204 .map_err(JobError::retry)?;
205
206 for (compat_session, _) in page.edges {
207 if let Some(ref device) = compat_session.device {
208 devices.insert(device.as_str().to_owned());
209 }
210 cursor = cursor.after(compat_session.id);
211 }
212
213 if !page.has_next_page {
214 break;
215 }
216 }
217
218 let mut cursor = Pagination::first(5000);
220 loop {
221 let page = repo
222 .oauth2_session()
223 .list(
224 OAuth2SessionFilter::new().for_user(&user).active_only(),
225 cursor,
226 )
227 .await
228 .map_err(JobError::retry)?;
229
230 for oauth2_session in page.edges {
231 for scope in &*oauth2_session.scope {
232 if let Some(device) = Device::from_scope_token(scope) {
233 devices.insert(device.as_str().to_owned());
234 }
235 }
236
237 cursor = cursor.after(oauth2_session.id);
238 }
239
240 if !page.has_next_page {
241 break;
242 }
243 }
244
245 matrix
246 .sync_devices(&user.username, devices)
247 .await
248 .map_err(JobError::retry)?;
249
250 repo.save().await.map_err(JobError::retry)?;
253
254 Ok(())
255 }
256}