1use tokio::sync::mpsc;
24use tower_lsp_server::Client;
25use tower_lsp_server::jsonrpc::Result;
26use tower_lsp_server::ls_types::{
27 ProgressParams, ProgressParamsValue, ProgressToken, WorkDoneProgress, WorkDoneProgressBegin,
28 WorkDoneProgressEnd, WorkDoneProgressReport,
29};
30
31const PROGRESS_CHANNEL_CAPACITY: usize = 8;
34
35#[derive(Clone)]
40pub struct ProgressSender {
41 tx: mpsc::Sender<ProgressUpdate>,
42 total: usize,
43}
44
/// One progress sample: `fetched` out of `total` dependencies are done.
#[derive(Debug, Clone, Copy)]
struct ProgressUpdate {
    /// Number of dependencies fetched so far.
    fetched: usize,
    /// Total number of dependencies being fetched.
    total: usize,
}
49
50impl ProgressSender {
51 pub fn send(&self, fetched: usize) {
57 let _ = self.tx.try_send(ProgressUpdate {
58 fetched,
59 total: self.total,
60 });
61 }
62}
63
64pub struct RegistryProgress {
69 client: Client,
70 token: ProgressToken,
71 active: bool,
72 _consumer_handle: tokio::task::JoinHandle<()>,
75}
76
77impl RegistryProgress {
78 pub async fn start(
83 client: Client,
84 uri: &str,
85 total_deps: usize,
86 ) -> Result<(Self, ProgressSender)> {
87 let token = ProgressToken::String(format!("deps-fetch-{}", uri));
88
89 client
91 .send_request::<tower_lsp_server::ls_types::request::WorkDoneProgressCreate>(
92 tower_lsp_server::ls_types::WorkDoneProgressCreateParams {
93 token: token.clone(),
94 },
95 )
96 .await?;
97
98 client
100 .send_notification::<tower_lsp_server::ls_types::notification::Progress>(
101 ProgressParams {
102 token: token.clone(),
103 value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(
104 WorkDoneProgressBegin {
105 title: "Fetching package versions".to_string(),
106 message: Some(format!("Loading {} dependencies...", total_deps)),
107 cancellable: Some(false),
108 percentage: Some(0),
109 },
110 )),
111 },
112 )
113 .await;
114
115 let (tx, rx) = mpsc::channel(PROGRESS_CHANNEL_CAPACITY);
116
117 let consumer_client = client.clone();
119 let consumer_token = token.clone();
120 let consumer_handle = tokio::spawn(async move {
121 consume_progress_updates(rx, consumer_client, consumer_token).await;
122 });
123
124 let sender = ProgressSender {
125 tx,
126 total: total_deps,
127 };
128
129 Ok((
130 Self {
131 client,
132 token,
133 active: true,
134 _consumer_handle: consumer_handle,
135 },
136 sender,
137 ))
138 }
139
140 pub async fn end(mut self, success: bool) {
142 if !self.active {
143 return;
144 }
145
146 self.active = false;
147
148 self._consumer_handle.abort();
150
151 let message = if success {
152 "Package versions loaded"
153 } else {
154 "Failed to fetch some versions"
155 };
156
157 self.client
158 .send_notification::<tower_lsp_server::ls_types::notification::Progress>(
159 ProgressParams {
160 token: self.token.clone(),
161 value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(
162 WorkDoneProgressEnd {
163 message: Some(message.to_string()),
164 },
165 )),
166 },
167 )
168 .await;
169 }
170}
171
172async fn consume_progress_updates(
174 mut rx: mpsc::Receiver<ProgressUpdate>,
175 client: Client,
176 token: ProgressToken,
177) {
178 while let Some(update) = rx.recv().await {
179 let percentage = if update.total > 0 {
180 ((update.fetched as f64 / update.total as f64) * 100.0) as u32
181 } else {
182 0
183 };
184
185 client
186 .send_notification::<tower_lsp_server::ls_types::notification::Progress>(
187 ProgressParams {
188 token: token.clone(),
189 value: ProgressParamsValue::WorkDone(WorkDoneProgress::Report(
190 WorkDoneProgressReport {
191 message: Some(format!(
192 "Fetched {}/{} packages",
193 update.fetched, update.total
194 )),
195 percentage: Some(percentage),
196 cancellable: Some(false),
197 },
198 )),
199 },
200 )
201 .await;
202 }
203}
204
205impl Drop for RegistryProgress {
207 fn drop(&mut self) {
208 if self.active {
209 tracing::warn!(
210 token = ?self.token,
211 "RegistryProgress dropped without explicit end() - spawning cleanup"
212 );
213 self._consumer_handle.abort();
214 let client = self.client.clone();
215 let token = self.token.clone();
216 tokio::spawn(async move {
217 client
218 .send_notification::<tower_lsp_server::ls_types::notification::Progress>(
219 ProgressParams {
220 token,
221 value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(
222 WorkDoneProgressEnd { message: None },
223 )),
224 },
225 )
226 .await;
227 });
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_token_format() {
        let uri = "file:///test/Cargo.toml";
        let token = format!("deps-fetch-{}", uri);
        assert_eq!(token, "deps-fetch-file:///test/Cargo.toml");
    }

    #[test]
    fn test_percentage_calculation() {
        let calculate = |fetched: usize, total: usize| -> u32 {
            if total == 0 {
                return 0;
            }
            ((fetched as f64 / total as f64) * 100.0) as u32
        };

        assert_eq!(calculate(0, 10), 0);
        assert_eq!(calculate(5, 10), 50);
        assert_eq!(calculate(10, 10), 100);
        assert_eq!(calculate(7, 10), 70);
        assert_eq!(calculate(0, 0), 0);
    }

    #[test]
    fn test_progress_message_format() {
        let format_message = |fetched: usize, total: usize| -> String {
            format!("Fetched {}/{} packages", fetched, total)
        };

        assert_eq!(format_message(5, 10), "Fetched 5/10 packages");
        assert_eq!(format_message(0, 15), "Fetched 0/15 packages");
        assert_eq!(format_message(20, 20), "Fetched 20/20 packages");
    }

    /// Sending after the consumer is gone must neither panic nor block.
    #[tokio::test]
    async fn test_progress_sender_try_send_on_closed_channel() {
        let (tx, rx) = mpsc::channel(1);
        let sender = ProgressSender { tx, total: 10 };

        drop(rx);
        assert!(sender.tx.is_closed());

        sender.send(5);
    }

    /// With capacity 1, only the first update is queued and the rest are dropped
    /// instead of blocking the caller.
    #[tokio::test]
    async fn test_progress_sender_try_send_on_full_channel() {
        let (tx, mut rx) = mpsc::channel(1);
        let sender = ProgressSender { tx, total: 10 };

        sender.send(1);
        sender.send(2);
        sender.send(3);

        let queued = rx.try_recv().expect("first update should be queued");
        assert_eq!(queued.fetched, 1);
        assert_eq!(queued.total, 10);
        assert!(matches!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));
    }
}