1use crate::error::{DepsError, Result};
2use bytes::Bytes;
3use dashmap::DashMap;
4use reqwest::{Client, StatusCode, header};
5use std::time::Instant;
6
/// Entry count at which `HttpCache::get_cached` runs an eviction pass before serving a request.
const MAX_CACHE_ENTRIES: usize = 1000;

/// Timeout, in seconds, applied to the HTTP client built in `HttpCache::new`.
const HTTP_TIMEOUT_SECS: u64 = 30;

/// Controls how many entries a single eviction pass removes (see `HttpCache::evict_entries`).
/// With the current constants one pass targets 100 entries, i.e. 10% of `MAX_CACHE_ENTRIES`.
const CACHE_EVICTION_PERCENTAGE: usize = 10;
15
16#[inline]
23fn ensure_https(url: &str) -> Result<()> {
24 #[cfg(not(test))]
25 if !url.starts_with("https://") {
26 return Err(DepsError::CacheError(format!("URL must use HTTPS: {url}")));
27 }
28 #[cfg(test)]
29 let _ = url; Ok(())
31}
32
/// A response body kept in the cache together with the validators needed to revalidate it.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// Raw response body.
    pub body: Bytes,
    /// `ETag` header value stored with the response, if the header was present and
    /// convertible to a string. Sent back as `If-None-Match` on revalidation.
    pub etag: Option<String>,
    /// `Last-Modified` header value stored with the response, if the header was present and
    /// convertible to a string. Sent back as `If-Modified-Since` on revalidation.
    pub last_modified: Option<String>,
    /// Moment the entry was stored; eviction compares entries by this timestamp.
    pub fetched_at: Instant,
}
64
/// URL-keyed cache of HTTP responses that revalidates entries with conditional requests.
pub struct HttpCache {
    /// Cached responses keyed by the URL they were fetched from.
    entries: DashMap<String, CachedResponse>,
    /// HTTP client used for every request issued by this cache.
    client: Client,
}
96
97impl HttpCache {
98 pub fn new() -> Self {
103 let client = Client::builder()
104 .user_agent(format!("deps-lsp/{}", env!("CARGO_PKG_VERSION")))
105 .timeout(std::time::Duration::from_secs(HTTP_TIMEOUT_SECS))
106 .build()
107 .expect("failed to create HTTP client");
108
109 Self {
110 entries: DashMap::new(),
111 client,
112 }
113 }
114
115 pub async fn get_cached(&self, url: &str) -> Result<Bytes> {
148 if self.entries.len() >= MAX_CACHE_ENTRIES {
150 self.evict_entries();
151 }
152
153 if let Some(cached) = self.entries.get(url) {
154 match self.conditional_request(url, &cached).await {
156 Ok(Some(new_body)) => {
157 return Ok(new_body);
159 }
160 Ok(None) => {
161 return Ok(cached.body.clone());
163 }
164 Err(e) => {
165 tracing::warn!("conditional request failed, using cache: {e}");
167 return Ok(cached.body.clone());
168 }
169 }
170 }
171
172 self.fetch_and_store(url).await
174 }
175
176 async fn conditional_request(
187 &self,
188 url: &str,
189 cached: &CachedResponse,
190 ) -> Result<Option<Bytes>> {
191 ensure_https(url)?;
192 let mut request = self.client.get(url);
193
194 if let Some(etag) = &cached.etag {
195 request = request.header(header::IF_NONE_MATCH, etag);
196 }
197 if let Some(last_modified) = &cached.last_modified {
198 request = request.header(header::IF_MODIFIED_SINCE, last_modified);
199 }
200
201 let response = request.send().await.map_err(|e| DepsError::RegistryError {
202 package: url.to_string(),
203 source: e,
204 })?;
205
206 if response.status() == StatusCode::NOT_MODIFIED {
207 return Ok(None);
209 }
210
211 let etag = response
213 .headers()
214 .get(header::ETAG)
215 .and_then(|v| v.to_str().ok())
216 .map(String::from);
217
218 let last_modified = response
219 .headers()
220 .get(header::LAST_MODIFIED)
221 .and_then(|v| v.to_str().ok())
222 .map(String::from);
223
224 let body = response
225 .bytes()
226 .await
227 .map_err(|e| DepsError::RegistryError {
228 package: url.to_string(),
229 source: e,
230 })?;
231
232 self.entries.insert(
234 url.to_string(),
235 CachedResponse {
236 body: body.clone(),
237 etag,
238 last_modified,
239 fetched_at: Instant::now(),
240 },
241 );
242
243 Ok(Some(body))
244 }
245
246 pub(crate) async fn fetch_and_store(&self, url: &str) -> Result<Bytes> {
257 ensure_https(url)?;
258 tracing::debug!("fetching fresh: {url}");
259
260 let response = self
261 .client
262 .get(url)
263 .send()
264 .await
265 .map_err(|e| DepsError::RegistryError {
266 package: url.to_string(),
267 source: e,
268 })?;
269
270 if !response.status().is_success() {
271 let status = response.status();
272 return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
273 }
274
275 let etag = response
276 .headers()
277 .get(header::ETAG)
278 .and_then(|v| v.to_str().ok())
279 .map(String::from);
280
281 let last_modified = response
282 .headers()
283 .get(header::LAST_MODIFIED)
284 .and_then(|v| v.to_str().ok())
285 .map(String::from);
286
287 let body = response
288 .bytes()
289 .await
290 .map_err(|e| DepsError::RegistryError {
291 package: url.to_string(),
292 source: e,
293 })?;
294
295 self.entries.insert(
296 url.to_string(),
297 CachedResponse {
298 body: body.clone(),
299 etag,
300 last_modified,
301 fetched_at: Instant::now(),
302 },
303 );
304
305 Ok(body)
306 }
307
    /// Removes every entry from the cache.
    pub fn clear(&self) {
        self.entries.clear();
    }
315
    /// Returns the number of cached entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
320
    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
325
326 fn evict_entries(&self) {
334 use std::cmp::Reverse;
335 use std::collections::BinaryHeap;
336
337 let target_removals = MAX_CACHE_ENTRIES / CACHE_EVICTION_PERCENTAGE;
338
339 let mut oldest = BinaryHeap::with_capacity(target_removals);
342
343 for entry in &self.entries {
344 let item = (entry.value().fetched_at, entry.key().clone());
345
346 if oldest.len() < target_removals {
347 oldest.push(Reverse(item));
349 } else if let Some(Reverse(newest_of_oldest)) = oldest.peek() {
350 if item.0 < newest_of_oldest.0 {
353 oldest.pop();
354 oldest.push(Reverse(item));
355 }
356 }
357 }
358
359 let removed = oldest.len();
361 for Reverse((_, url)) in oldest {
362 self.entries.remove(&url);
363 }
364
365 tracing::debug!("evicted {} cache entries (O(N) algorithm)", removed);
366 }
367
    /// Returns a clone of the cached body for `url` without issuing any request, or `None`
    /// when the URL is not cached. Hidden from the docs; named for benchmark use.
    #[doc(hidden)]
    pub fn get_for_bench(&self, url: &str) -> Option<Bytes> {
        self.entries.get(url).map(|entry| entry.body.clone())
    }
373
    /// Stores `response` under `url` directly, replacing any existing entry and bypassing the
    /// network and the capacity check. Hidden from the docs; named for benchmark use.
    #[doc(hidden)]
    pub fn insert_for_bench(&self, url: String, response: CachedResponse) {
        self.entries.insert(url, response);
    }
379}
380
impl Default for HttpCache {
    /// Equivalent to [`HttpCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
386
387#[cfg(test)]
388mod tests {
389 use super::*;
390
391 #[test]
392 fn test_cache_creation() {
393 let cache = HttpCache::new();
394 assert_eq!(cache.len(), 0);
395 assert!(cache.is_empty());
396 }
397
398 #[test]
399 fn test_cache_clear() {
400 let cache = HttpCache::new();
401 cache.entries.insert(
402 "test".into(),
403 CachedResponse {
404 body: Bytes::from_static(&[1, 2, 3]),
405 etag: None,
406 last_modified: None,
407 fetched_at: Instant::now(),
408 },
409 );
410 assert_eq!(cache.len(), 1);
411 cache.clear();
412 assert_eq!(cache.len(), 0);
413 }
414
415 #[test]
416 fn test_cached_response_clone() {
417 let response = CachedResponse {
418 body: Bytes::from_static(&[1, 2, 3]),
419 etag: Some("test".into()),
420 last_modified: Some("date".into()),
421 fetched_at: Instant::now(),
422 };
423 let cloned = response.clone();
424 assert_eq!(response.body, cloned.body);
426 assert_eq!(response.etag, cloned.etag);
427 }
428
429 #[test]
430 fn test_cache_len() {
431 let cache = HttpCache::new();
432 assert_eq!(cache.len(), 0);
433
434 cache.entries.insert(
435 "url1".into(),
436 CachedResponse {
437 body: Bytes::new(),
438 etag: None,
439 last_modified: None,
440 fetched_at: Instant::now(),
441 },
442 );
443
444 assert_eq!(cache.len(), 1);
445 }
446
447 #[tokio::test]
448 async fn test_get_cached_fresh_fetch() {
449 let mut server = mockito::Server::new_async().await;
450
451 let _m = server
452 .mock("GET", "/api/data")
453 .with_status(200)
454 .with_header("etag", "\"abc123\"")
455 .with_body("test data")
456 .create_async()
457 .await;
458
459 let cache = HttpCache::new();
460 let url = format!("{}/api/data", server.url());
461 let result: Bytes = cache.get_cached(&url).await.unwrap();
462
463 assert_eq!(result.as_ref(), b"test data");
464 assert_eq!(cache.len(), 1);
465 }
466
467 #[tokio::test]
468 async fn test_get_cached_cache_hit() {
469 let mut server = mockito::Server::new_async().await;
470 let url = format!("{}/api/data", server.url());
471
472 let cache = HttpCache::new();
473
474 let _m1 = server
475 .mock("GET", "/api/data")
476 .with_status(200)
477 .with_header("etag", "\"abc123\"")
478 .with_body("original data")
479 .create_async()
480 .await;
481
482 let result1: Bytes = cache.get_cached(&url).await.unwrap();
483 assert_eq!(result1.as_ref(), b"original data");
484 assert_eq!(cache.len(), 1);
485
486 drop(_m1);
487
488 let _m2 = server
489 .mock("GET", "/api/data")
490 .match_header("if-none-match", "\"abc123\"")
491 .with_status(304)
492 .create_async()
493 .await;
494
495 let result2: Bytes = cache.get_cached(&url).await.unwrap();
496 assert_eq!(result2.as_ref(), b"original data");
497 }
498
499 #[tokio::test]
500 async fn test_get_cached_304_not_modified() {
501 let mut server = mockito::Server::new_async().await;
502 let url = format!("{}/api/data", server.url());
503
504 let cache = HttpCache::new();
505
506 let _m1 = server
507 .mock("GET", "/api/data")
508 .with_status(200)
509 .with_header("etag", "\"abc123\"")
510 .with_body("original data")
511 .create_async()
512 .await;
513
514 let result1: Bytes = cache.get_cached(&url).await.unwrap();
515 assert_eq!(result1.as_ref(), b"original data");
516
517 drop(_m1);
518
519 let _m2 = server
520 .mock("GET", "/api/data")
521 .match_header("if-none-match", "\"abc123\"")
522 .with_status(304)
523 .create_async()
524 .await;
525
526 let result2: Bytes = cache.get_cached(&url).await.unwrap();
527 assert_eq!(result2.as_ref(), b"original data");
528 }
529
530 #[tokio::test]
531 async fn test_get_cached_etag_validation() {
532 let mut server = mockito::Server::new_async().await;
533 let url = format!("{}/api/data", server.url());
534
535 let cache = HttpCache::new();
536
537 cache.entries.insert(
538 url.clone(),
539 CachedResponse {
540 body: Bytes::from_static(b"cached"),
541 etag: Some("\"tag123\"".into()),
542 last_modified: None,
543 fetched_at: Instant::now(),
544 },
545 );
546
547 let _m = server
548 .mock("GET", "/api/data")
549 .match_header("if-none-match", "\"tag123\"")
550 .with_status(304)
551 .create_async()
552 .await;
553
554 let result: Bytes = cache.get_cached(&url).await.unwrap();
555 assert_eq!(result.as_ref(), b"cached");
556 }
557
558 #[tokio::test]
559 async fn test_get_cached_last_modified_validation() {
560 let mut server = mockito::Server::new_async().await;
561 let url = format!("{}/api/data", server.url());
562
563 let cache = HttpCache::new();
564
565 cache.entries.insert(
566 url.clone(),
567 CachedResponse {
568 body: Bytes::from_static(b"cached"),
569 etag: None,
570 last_modified: Some("Wed, 21 Oct 2024 07:28:00 GMT".into()),
571 fetched_at: Instant::now(),
572 },
573 );
574
575 let _m = server
576 .mock("GET", "/api/data")
577 .match_header("if-modified-since", "Wed, 21 Oct 2024 07:28:00 GMT")
578 .with_status(304)
579 .create_async()
580 .await;
581
582 let result: Bytes = cache.get_cached(&url).await.unwrap();
583 assert_eq!(result.as_ref(), b"cached");
584 }
585
586 #[tokio::test]
587 async fn test_get_cached_network_error_fallback() {
588 let cache = HttpCache::new();
589 let url = "http://invalid.localhost.test/data";
590
591 cache.entries.insert(
592 url.to_string(),
593 CachedResponse {
594 body: Bytes::from_static(b"stale data"),
595 etag: Some("\"old\"".into()),
596 last_modified: None,
597 fetched_at: Instant::now(),
598 },
599 );
600
601 let result: Bytes = cache.get_cached(url).await.unwrap();
602 assert_eq!(result.as_ref(), b"stale data");
603 }
604
605 #[tokio::test]
606 async fn test_fetch_and_store_http_error() {
607 let mut server = mockito::Server::new_async().await;
608
609 let _m = server
610 .mock("GET", "/api/missing")
611 .with_status(404)
612 .with_body("Not Found")
613 .create_async()
614 .await;
615
616 let cache = HttpCache::new();
617 let url = format!("{}/api/missing", server.url());
618 let result: Result<Bytes> = cache.fetch_and_store(&url).await;
619
620 assert!(result.is_err());
621 match result {
622 Err(DepsError::CacheError(msg)) => {
623 assert!(msg.contains("404"));
624 }
625 _ => panic!("Expected CacheError"),
626 }
627 }
628
629 #[tokio::test]
630 async fn test_fetch_and_store_stores_headers() {
631 let mut server = mockito::Server::new_async().await;
632
633 let _m = server
634 .mock("GET", "/api/data")
635 .with_status(200)
636 .with_header("etag", "\"abc123\"")
637 .with_header("last-modified", "Wed, 21 Oct 2024 07:28:00 GMT")
638 .with_body("test")
639 .create_async()
640 .await;
641
642 let cache = HttpCache::new();
643 let url = format!("{}/api/data", server.url());
644 let _: Bytes = cache.fetch_and_store(&url).await.unwrap();
645
646 let cached = cache.entries.get(&url).unwrap();
647 assert_eq!(cached.etag, Some("\"abc123\"".into()));
648 assert_eq!(
649 cached.last_modified,
650 Some("Wed, 21 Oct 2024 07:28:00 GMT".into())
651 );
652 }
653}