Skip to main content

oxide_browser/
download.rs

1//! Download manager for non-WASM resources.
2//!
3//! When the Oxide browser navigates to a URL that does not point to a `.wasm`
4//! module it downloads the file to the system Downloads folder instead.
5//! Multiple downloads can run in parallel, each reporting progress (bytes
6//! received, total size, speed) via shared state polled by the UI.
7
8use std::path::PathBuf;
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
/// Identifies a single download; monotonically increasing.
///
/// Ids are handed out by [`DownloadManager`], whose counter starts at 1 and is
/// incremented for every download that is started.
pub type DownloadId = u64;
14
/// Snapshot of a single download's progress at a point in time.
///
/// Worker threads update the entry stored in the shared list in place; the UI
/// clones or reads it while holding the lock.
#[derive(Clone, Debug)]
pub struct DownloadProgress {
    /// Identifier assigned when the download was started.
    pub id: DownloadId,
    /// URL the download was started with.
    pub url: String,
    /// Final path component of `destination`, lossily converted to UTF-8.
    pub filename: String,
    /// Current lifecycle state of the download.
    pub state: DownloadState,
    /// Number of body bytes received and handed to the file writer so far.
    pub bytes_downloaded: u64,
    /// `None` when the server omits `Content-Length`.
    ///
    /// Also `None` until the response headers have arrived, because entries
    /// are created before the request is sent.
    pub total_bytes: Option<u64>,
    /// Average throughput since the body started streaming, in bytes/second.
    pub speed_bytes_per_sec: f64,
    /// Full path of the file the download is written to.
    pub destination: PathBuf,
}
28
/// Lifecycle state of a download.
///
/// Every state other than [`DownloadState::InProgress`] is terminal.
/// `Eq` is derived alongside `PartialEq` because every payload (`String`)
/// supports total equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DownloadState {
    /// The worker is still connecting or transferring data.
    InProgress,
    /// The whole body was written and flushed to disk.
    Completed,
    /// The download stopped with an error; the payload is a display message.
    Failed(String),
    /// The user cancelled the download.
    Cancelled,
}
36
37impl DownloadProgress {
38    pub fn percent(&self) -> Option<f64> {
39        self.total_bytes
40            .map(|total| (self.bytes_downloaded as f64 / total as f64) * 100.0)
41    }
42
43    pub fn is_finished(&self) -> bool {
44        !matches!(self.state, DownloadState::InProgress)
45    }
46}
47
/// Thread-safe handle shared between the download worker threads and the UI.
///
/// Each element is the latest [`DownloadProgress`] for one download; workers
/// look their entry up by id and mutate it in place.
pub type SharedDownloads = Arc<Mutex<Vec<DownloadProgress>>>;
50
51/// Manages parallel file downloads. Owns a shared list of [`DownloadProgress`]
52/// entries that the UI polls each frame.
53impl Default for DownloadManager {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59#[derive(Clone)]
60pub struct DownloadManager {
61    downloads: SharedDownloads,
62    next_id: Arc<Mutex<u64>>,
63}
64
65impl DownloadManager {
66    pub fn new() -> Self {
67        Self {
68            downloads: Arc::new(Mutex::new(Vec::new())),
69            next_id: Arc::new(Mutex::new(1)),
70        }
71    }
72
73    pub fn downloads(&self) -> SharedDownloads {
74        self.downloads.clone()
75    }
76
77    /// Save arbitrary bytes as a file in the system Downloads directory.
78    /// Returns the path to the saved file on success.
79    pub fn save_data(&self, data: &[u8], filename: &str) -> std::io::Result<PathBuf> {
80        let dest_dir = dirs::download_dir()
81            .unwrap_or_else(|| dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")));
82        let dest = unique_path(&dest_dir, filename);
83        std::fs::write(&dest, data)?;
84        Ok(dest)
85    }
86
87    /// Kick off a background download for `url`.  Returns immediately.
88    /// The file is saved into the system Downloads directory.
89    pub fn start_download(&self, url: String) {
90        let dest_dir = dirs::download_dir()
91            .unwrap_or_else(|| dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")));
92        self.start_download_to(url, &dest_dir);
93    }
94
95    /// Kick off a background download for `url` into a specific directory.
96    pub fn start_download_to(&self, url: String, dest_dir: &std::path::Path) {
97        let id = {
98            let mut next = self.next_id.lock().unwrap();
99            let id = *next;
100            *next += 1;
101            id
102        };
103
104        let filename = filename_from_url(&url);
105        let dest = unique_path(dest_dir, &filename);
106
107        // Touch the file so a second concurrent download picks a different name.
108        let _ = std::fs::File::create(&dest);
109
110        let progress = DownloadProgress {
111            id,
112            url: url.clone(),
113            filename: dest
114                .file_name()
115                .unwrap_or_default()
116                .to_string_lossy()
117                .to_string(),
118            state: DownloadState::InProgress,
119            bytes_downloaded: 0,
120            total_bytes: None,
121            speed_bytes_per_sec: 0.0,
122            destination: dest.clone(),
123        };
124
125        self.downloads.lock().unwrap().push(progress);
126
127        let downloads = self.downloads.clone();
128        std::thread::spawn(move || {
129            let rt = tokio::runtime::Runtime::new().unwrap();
130            rt.block_on(run_download(id, url, dest, downloads));
131        });
132    }
133
134    /// Cancel an in-progress download.
135    pub fn cancel(&self, id: DownloadId) {
136        let mut list = self.downloads.lock().unwrap();
137        if let Some(dl) = list.iter_mut().find(|d| d.id == id) {
138            if dl.state == DownloadState::InProgress {
139                dl.state = DownloadState::Cancelled;
140            }
141        }
142    }
143
144    /// Remove a finished (or cancelled/failed) download entry from the list.
145    pub fn dismiss(&self, id: DownloadId) {
146        let mut list = self.downloads.lock().unwrap();
147        list.retain(|d| d.id != id);
148    }
149
150    pub fn has_active(&self) -> bool {
151        self.downloads
152            .lock()
153            .unwrap()
154            .iter()
155            .any(|d| d.state == DownloadState::InProgress)
156    }
157}
158
159async fn run_download(id: DownloadId, url: String, dest: PathBuf, downloads: SharedDownloads) {
160    use tokio::io::AsyncWriteExt;
161
162    let client = match reqwest::Client::builder()
163        .timeout(std::time::Duration::from_secs(600))
164        .build()
165    {
166        Ok(c) => c,
167        Err(e) => {
168            set_state(&downloads, id, DownloadState::Failed(e.to_string()));
169            return;
170        }
171    };
172
173    let response = match client.get(&url).send().await {
174        Ok(r) => r,
175        Err(e) => {
176            set_state(&downloads, id, DownloadState::Failed(e.to_string()));
177            return;
178        }
179    };
180
181    if !response.status().is_success() {
182        set_state(
183            &downloads,
184            id,
185            DownloadState::Failed(format!("HTTP {}", response.status())),
186        );
187        return;
188    }
189
190    let total = response.content_length();
191    {
192        let mut list = downloads.lock().unwrap();
193        if let Some(dl) = list.iter_mut().find(|d| d.id == id) {
194            dl.total_bytes = total;
195        }
196    }
197
198    let file = match tokio::fs::File::create(&dest).await {
199        Ok(f) => f,
200        Err(e) => {
201            set_state(&downloads, id, DownloadState::Failed(e.to_string()));
202            return;
203        }
204    };
205    let mut writer = tokio::io::BufWriter::new(file);
206    let mut stream = response.bytes_stream();
207    let started = Instant::now();
208    let mut downloaded: u64 = 0;
209
210    use futures_util::StreamExt;
211    while let Some(chunk_result) = stream.next().await {
212        let cancelled = {
213            let list = downloads.lock().unwrap();
214            list.iter()
215                .any(|d| d.id == id && d.state == DownloadState::Cancelled)
216        };
217        if cancelled {
218            let _ = tokio::fs::remove_file(&dest).await;
219            return;
220        }
221
222        let chunk: bytes::Bytes = match chunk_result {
223            Ok(c) => c,
224            Err(e) => {
225                set_state(&downloads, id, DownloadState::Failed(e.to_string()));
226                return;
227            }
228        };
229        if let Err(e) = writer.write_all(&chunk).await {
230            set_state(&downloads, id, DownloadState::Failed(e.to_string()));
231            return;
232        }
233        downloaded += chunk.len() as u64;
234        let elapsed = started.elapsed().as_secs_f64().max(0.001);
235        let speed = downloaded as f64 / elapsed;
236
237        {
238            let mut list = downloads.lock().unwrap();
239            if let Some(dl) = list.iter_mut().find(|d| d.id == id) {
240                dl.bytes_downloaded = downloaded;
241                dl.speed_bytes_per_sec = speed;
242            }
243        }
244    }
245
246    if let Err(e) = writer.flush().await {
247        set_state(&downloads, id, DownloadState::Failed(e.to_string()));
248        return;
249    }
250
251    set_state(&downloads, id, DownloadState::Completed);
252}
253
254fn set_state(downloads: &SharedDownloads, id: DownloadId, state: DownloadState) {
255    let mut list = downloads.lock().unwrap();
256    if let Some(dl) = list.iter_mut().find(|d| d.id == id) {
257        dl.state = state;
258    }
259}
260
261/// Extract a reasonable filename from a URL, falling back to `"download"`.
262fn filename_from_url(url: &str) -> String {
263    url::Url::parse(url)
264        .ok()
265        .and_then(|u| {
266            u.path_segments()
267                .and_then(|mut segs| segs.next_back().map(|s| s.to_string()))
268                .filter(|s| !s.is_empty())
269        })
270        .map(|name| {
271            percent_encoding::percent_decode_str(&name)
272                .decode_utf8_lossy()
273                .into_owned()
274        })
275        .unwrap_or_else(|| "download".to_string())
276}
277
/// If `dir/name` exists, try `name (1)`, `name (2)`, etc.
///
/// The counter is inserted before the extension (`test.txt` becomes
/// `test (1).txt`). A path counts as taken when anything exists at it,
/// including a dangling symlink: the check does not follow links, so a later
/// create cannot be redirected through a link to some other location.
///
/// Returns the first free path. If every numbered candidate is taken the
/// plain `dir/name` is returned as a last resort.
pub(crate) fn unique_path(dir: &std::path::Path, name: &str) -> PathBuf {
    /// `true` when something (file, directory or symlink) occupies `path`.
    fn is_taken(path: &std::path::Path) -> bool {
        std::fs::symlink_metadata(path).is_ok()
    }

    let candidate = dir.join(name);
    if !is_taken(&candidate) {
        return candidate;
    }

    let requested = std::path::Path::new(name);
    let stem = requested.file_stem().unwrap_or_default().to_string_lossy();
    let ext = requested
        .extension()
        .map(|e| format!(".{}", e.to_string_lossy()))
        .unwrap_or_default();

    // An inclusive range ends cleanly at `u32::MAX` instead of overflowing.
    (1..=u32::MAX)
        .map(|i| dir.join(format!("{stem} ({i}){ext}")))
        .find(|p| !is_taken(p))
        .unwrap_or(candidate)
}
301
/// Format byte count for display: "1.2 MB", "340 KB", etc.
///
/// Uses binary units (1 KB = 1024 bytes). GB and MB are shown with one
/// decimal place, KB is rounded to a whole number, and anything below 1 KB is
/// printed as an exact byte count.
pub fn format_bytes(bytes: u64) -> String {
    // (threshold in bytes, unit label, decimal places), largest unit first.
    const UNITS: [(f64, &str, usize); 3] = [
        (1024.0 * 1024.0 * 1024.0, "GB", 1),
        (1024.0 * 1024.0, "MB", 1),
        (1024.0, "KB", 0),
    ];

    let value = bytes as f64;
    match UNITS.iter().find(|&&(scale, _, _)| value >= scale) {
        Some(&(scale, unit, decimals)) => {
            format!("{:.prec$} {}", value / scale, unit, prec = decimals)
        }
        None => format!("{bytes} B"),
    }
}
318
// Unit tests for the download manager.
//
// NOTE: every test that calls `start_download_to` fetches
// `https://github.com/robots.txt`, so those tests need network access and
// depend on that resource staying available.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pure test: file names derived from URLs, including the fallback name
    /// and percent-decoding.
    #[test]
    fn filename_extraction() {
        assert_eq!(
            filename_from_url("https://github.com/robots.txt"),
            "robots.txt"
        );
        assert_eq!(
            filename_from_url("https://example.com/path/to/file.zip"),
            "file.zip"
        );
        assert_eq!(filename_from_url("https://example.com/"), "download");
        assert_eq!(filename_from_url("https://example.com"), "download");
        assert_eq!(
            filename_from_url("https://example.com/hello%20world.pdf"),
            "hello world.pdf"
        );
    }

    /// A name that is free in the directory is returned unchanged.
    #[test]
    fn unique_path_no_conflict() {
        let dir = tempfile::tempdir().unwrap();
        let p = unique_path(dir.path(), "test.txt");
        assert_eq!(p, dir.path().join("test.txt"));
    }

    /// A taken name gets a ` (1)` suffix before the extension.
    #[test]
    fn unique_path_with_conflict() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("test.txt"), "existing").unwrap();
        let p = unique_path(dir.path(), "test.txt");
        assert_eq!(p, dir.path().join("test (1).txt"));
    }

    /// Unit selection and rounding of the human-readable byte formatter.
    #[test]
    fn format_bytes_display() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1024), "1 KB");
        assert_eq!(format_bytes(1_500_000), "1.4 MB");
        assert_eq!(format_bytes(2_000_000_000), "1.9 GB");
    }

    /// Network test: a single download completes and the file lands on disk.
    #[test]
    fn download_github_robots_txt() {
        let dir = tempfile::tempdir().unwrap();
        let dm = DownloadManager::new();

        dm.start_download_to("https://github.com/robots.txt".to_string(), dir.path());

        // Poll until complete (up to 30 s).
        let deadline = Instant::now() + std::time::Duration::from_secs(30);
        loop {
            std::thread::sleep(std::time::Duration::from_millis(200));
            let list = dm.downloads().lock().unwrap().clone();
            assert_eq!(list.len(), 1);
            let dl = &list[0];
            if dl.is_finished() {
                assert_eq!(dl.state, DownloadState::Completed);
                assert!(dl.bytes_downloaded > 0, "should have downloaded some bytes");
                break;
            }
            assert!(
                Instant::now() < deadline,
                "download did not complete within 30 seconds"
            );
        }

        let saved = dir.path().join("robots.txt");
        assert!(saved.exists(), "robots.txt should exist on disk");
        let content = std::fs::read_to_string(&saved).unwrap();
        assert!(
            content.contains("User-agent"),
            "robots.txt should contain 'User-agent'"
        );
    }

    /// Network test: two downloads of the same URL run side by side and are
    /// given distinct file names.
    #[test]
    fn parallel_downloads() {
        let dir = tempfile::tempdir().unwrap();
        let dm = DownloadManager::new();

        dm.start_download_to("https://github.com/robots.txt".to_string(), dir.path());
        dm.start_download_to("https://github.com/robots.txt".to_string(), dir.path());

        assert_eq!(dm.downloads().lock().unwrap().len(), 2);

        let deadline = Instant::now() + std::time::Duration::from_secs(30);
        loop {
            std::thread::sleep(std::time::Duration::from_millis(200));
            let list = dm.downloads().lock().unwrap().clone();
            if list.iter().all(|d| d.is_finished()) {
                assert!(list.iter().all(|d| d.state == DownloadState::Completed));
                // Second download gets a deduplicated filename.
                let names: Vec<_> = list.iter().map(|d| d.filename.clone()).collect();
                assert!(names.contains(&"robots.txt".to_string()));
                assert!(names.contains(&"robots (1).txt".to_string()));
                break;
            }
            assert!(
                Instant::now() < deadline,
                "parallel downloads did not complete within 30 seconds"
            );
        }
    }

    /// Cancelling right after start leaves the entry in the `Cancelled` state.
    ///
    /// NOTE(review): `cancel` flips the state synchronously and `is_finished`
    /// is true for `Cancelled`, so the polling loop exits on its first pass;
    /// this does not verify that the worker thread actually stopped.
    #[test]
    fn cancel_download() {
        let dm = DownloadManager::new();
        let dir = tempfile::tempdir().unwrap();

        dm.start_download_to("https://github.com/robots.txt".to_string(), dir.path());

        let id = dm.downloads().lock().unwrap()[0].id;
        dm.cancel(id);

        let deadline = Instant::now() + std::time::Duration::from_secs(10);
        loop {
            std::thread::sleep(std::time::Duration::from_millis(100));
            let list = dm.downloads().lock().unwrap().clone();
            let dl = &list[0];
            if dl.is_finished() {
                assert_eq!(dl.state, DownloadState::Cancelled);
                break;
            }
            assert!(
                Instant::now() < deadline,
                "cancelled download did not finish within 10 seconds"
            );
        }
    }

    /// Dismissing a finished (here: cancelled) entry removes it from the list.
    #[test]
    fn dismiss_removes_entry() {
        let dm = DownloadManager::new();
        let dir = tempfile::tempdir().unwrap();

        dm.start_download_to("https://github.com/robots.txt".to_string(), dir.path());

        let id = dm.downloads().lock().unwrap()[0].id;
        dm.cancel(id);

        let deadline = Instant::now() + std::time::Duration::from_secs(10);
        loop {
            std::thread::sleep(std::time::Duration::from_millis(100));
            if dm.downloads().lock().unwrap()[0].is_finished() {
                break;
            }
            assert!(Instant::now() < deadline);
        }

        dm.dismiss(id);
        assert!(dm.downloads().lock().unwrap().is_empty());
    }
}