Skip to main content

oxide_browser/
websocket.rs

1//! Host-side WebSocket connections for Oxide guest modules.
2//!
3//! Guests call the `api_ws_*` imports to open connections, send messages, poll
4//! for incoming messages, query connection state, and close connections.
5//! All I/O is non-blocking from the guest's perspective: the host spins up a
6//! tokio task per connection that drives the underlying `tokio-tungstenite`
7//! stream, pushes received frames into a `VecDeque`, and forwards outgoing
8//! frames from an mpsc channel.
9
10use std::collections::{HashMap, VecDeque};
11use std::sync::{Arc, Mutex};
12
13use anyhow::Result;
14use futures_util::{SinkExt, StreamExt};
15use tokio::runtime::Runtime;
16use tokio::sync::mpsc;
17use tokio_tungstenite::connect_async;
18use tokio_tungstenite::tungstenite::Message;
19use wasmtime::{Caller, Linker};
20
21use crate::capabilities::{
22    console_log, read_guest_bytes, read_guest_string, write_guest_bytes, ConsoleLevel, HostState,
23};
24
25// ── Ready-state constants (mirrors browser WebSocket.readyState) ──────────
26
/// Connection is being established.
///
/// This is the initial state of every handle created by `WsState::connect`.
pub const WS_CONNECTING: u32 = 0;
/// Connection is open and ready to communicate.
pub const WS_OPEN: u32 = 1;
/// Connection is in the process of closing.
pub const WS_CLOSING: u32 = 2;
/// Connection is closed or could not be opened.
///
/// Also reported by `WsState::ready_state` for handles that are not known.
pub const WS_CLOSED: u32 = 3;
35
/// A queued incoming message (text or binary frame).
///
/// Text frames are stored as their UTF-8 bytes so both kinds share one
/// representation; `is_binary` records which kind the frame was.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RecvMsg {
    /// `true` for a binary frame, `false` for a text frame.
    is_binary: bool,
    /// Raw frame payload.
    data: Vec<u8>,
}
41
/// An outgoing message queued by the guest for the writer task.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SendMsg {
    /// A payload frame; `is_binary` selects a binary frame over a text frame.
    Data { is_binary: bool, data: Vec<u8> },
    /// Request that the background task send a close frame and stop.
    Close,
}
47
/// Per-connection state shared between the host API and the background task.
///
/// The queue and the ready state are behind `Arc<Mutex<..>>` because the
/// spawned connection task holds clones of both.
struct WsConn {
    /// Sender half of the outgoing message channel consumed by the writer task.
    send_tx: mpsc::UnboundedSender<SendMsg>,
    /// Incoming frames pushed by the reader task, drained by `api_ws_recv`.
    recv_queue: Arc<Mutex<VecDeque<RecvMsg>>>,
    /// Current connection state (one of the `WS_*` constants above).
    ready_state: Arc<Mutex<u32>>,
}
57
/// All WebSocket state for a tab. Lazily initialised on the first `api_ws_*` call.
pub struct WsState {
    /// Tokio runtime on which each connection's background task is spawned.
    runtime: Runtime,
    /// Live connections keyed by the handle returned to the guest.
    connections: HashMap<u32, WsConn>,
    /// Next handle to hand out; starts at 1.
    next_id: u32,
}
64
65impl WsState {
66    pub fn new() -> Option<Self> {
67        let runtime = Runtime::new().ok()?;
68        Some(Self {
69            runtime,
70            connections: HashMap::new(),
71            next_id: 1,
72        })
73    }
74
75    fn alloc_id(&mut self) -> u32 {
76        let id = self.next_id;
77        self.next_id = self.next_id.wrapping_add(1).max(1);
78        id
79    }
80
81    /// Open a new WebSocket connection to `url`.
82    ///
83    /// Returns a handle (`> 0`) on success or `0` if the URL is invalid.
84    /// The actual TCP/TLS handshake happens asynchronously; poll
85    /// [`WsState::ready_state`] until it reaches [`WS_OPEN`].
86    fn connect(&mut self, url: &str) -> u32 {
87        let url = url.to_string();
88        let id = self.alloc_id();
89
90        let ready_state = Arc::new(Mutex::new(WS_CONNECTING));
91        let recv_queue: Arc<Mutex<VecDeque<RecvMsg>>> = Arc::new(Mutex::new(VecDeque::new()));
92        let (send_tx, mut send_rx) = mpsc::unbounded_channel::<SendMsg>();
93
94        let rs = ready_state.clone();
95        let rq = recv_queue.clone();
96
97        self.runtime.spawn(async move {
98            let ws_stream = match connect_async(&url).await {
99                Ok((stream, _)) => stream,
100                Err(_) => {
101                    *rs.lock().unwrap() = WS_CLOSED;
102                    return;
103                }
104            };
105
106            *rs.lock().unwrap() = WS_OPEN;
107            let (mut writer, mut reader) = ws_stream.split();
108
109            // Drive reading and writing concurrently.
110            loop {
111                tokio::select! {
112                    // Incoming frame from the remote server.
113                    msg = reader.next() => {
114                        match msg {
115                            Some(Ok(Message::Text(text))) => {
116                                rq.lock().unwrap().push_back(RecvMsg {
117                                    is_binary: false,
118                                    data: text.into_bytes(),
119                                });
120                            }
121                            Some(Ok(Message::Binary(bytes))) => {
122                                rq.lock().unwrap().push_back(RecvMsg {
123                                    is_binary: true,
124                                    data: bytes.to_vec(),
125                                });
126                            }
127                            Some(Ok(Message::Close(_))) | None => {
128                                *rs.lock().unwrap() = WS_CLOSED;
129                                break;
130                            }
131                            Some(Ok(Message::Ping(payload))) => {
132                                let _ = writer.send(Message::Pong(payload)).await;
133                            }
134                            _ => {}
135                        }
136                    }
137                    // Outgoing frame queued by the guest.
138                    outgoing = send_rx.recv() => {
139                        match outgoing {
140                            Some(SendMsg::Data { is_binary, data }) => {
141                                let msg = if is_binary {
142                                    Message::Binary(data)
143                                } else {
144                                    match String::from_utf8(data) {
145                                        Ok(text) => Message::Text(text),
146                                        Err(e) => Message::Binary(e.into_bytes()),
147                                    }
148                                };
149                                if writer.send(msg).await.is_err() {
150                                    *rs.lock().unwrap() = WS_CLOSED;
151                                    break;
152                                }
153                            }
154                            Some(SendMsg::Close) => {
155                                *rs.lock().unwrap() = WS_CLOSING;
156                                let _ = writer.send(Message::Close(None)).await;
157                                *rs.lock().unwrap() = WS_CLOSED;
158                                break;
159                            }
160                            None => {
161                                // Channel closed — host dropped the connection handle.
162                                *rs.lock().unwrap() = WS_CLOSED;
163                                break;
164                            }
165                        }
166                    }
167                }
168            }
169        });
170
171        self.connections.insert(
172            id,
173            WsConn {
174                send_tx,
175                recv_queue,
176                ready_state,
177            },
178        );
179
180        id
181    }
182
183    fn send(&self, id: u32, data: Vec<u8>, is_binary: bool) -> bool {
184        if let Some(conn) = self.connections.get(&id) {
185            conn.send_tx.send(SendMsg::Data { is_binary, data }).is_ok()
186        } else {
187            false
188        }
189    }
190
191    fn recv(&self, id: u32) -> Option<RecvMsg> {
192        self.connections
193            .get(&id)?
194            .recv_queue
195            .lock()
196            .unwrap()
197            .pop_front()
198    }
199
200    fn ready_state(&self, id: u32) -> u32 {
201        self.connections
202            .get(&id)
203            .map(|c| *c.ready_state.lock().unwrap())
204            .unwrap_or(WS_CLOSED)
205    }
206
207    fn close(&mut self, id: u32) -> bool {
208        if let Some(conn) = self.connections.get(&id) {
209            *conn.ready_state.lock().unwrap() = WS_CLOSING;
210            let _ = conn.send_tx.send(SendMsg::Close);
211            true
212        } else {
213            false
214        }
215    }
216
217    fn remove(&mut self, id: u32) {
218        self.connections.remove(&id);
219    }
220}
221
222fn ensure_ws(state: &Arc<Mutex<Option<WsState>>>) -> bool {
223    let mut g = state.lock().unwrap();
224    if g.is_none() {
225        *g = WsState::new();
226    }
227    g.is_some()
228}
229
230/// Register all `api_ws_*` host functions on the given linker.
231pub fn register_ws_functions(linker: &mut Linker<HostState>) -> Result<()> {
232    // ── ws_connect ────────────────────────────────────────────────────────
233    // api_ws_connect(url_ptr: u32, url_len: u32) -> u32
234    //   Returns a connection handle (> 0), or 0 on error.
235    linker.func_wrap(
236        "oxide",
237        "api_ws_connect",
238        |caller: Caller<'_, HostState>, url_ptr: u32, url_len: u32| -> u32 {
239            let console = caller.data().console.clone();
240            let ws = caller.data().ws.clone();
241            if !ensure_ws(&ws) {
242                console_log(&console, ConsoleLevel::Error, "[WS] Init failed".into());
243                return 0;
244            }
245            let mem = match caller.data().memory {
246                Some(m) => m,
247                None => return 0,
248            };
249            let url = match read_guest_string(&mem, &caller, url_ptr, url_len) {
250                Ok(s) => s,
251                Err(_) => return 0,
252            };
253            let id = ws.lock().unwrap().as_mut().unwrap().connect(&url);
254            console_log(
255                &console,
256                ConsoleLevel::Log,
257                format!("[WS] Connecting to {url} (id={id})"),
258            );
259            id
260        },
261    )?;
262
263    // ── ws_send_text ──────────────────────────────────────────────────────
264    // api_ws_send_text(id: u32, data_ptr: u32, data_len: u32) -> i32
265    //   Returns 0 on success, -1 if the connection is unknown or closed.
266    linker.func_wrap(
267        "oxide",
268        "api_ws_send_text",
269        |caller: Caller<'_, HostState>, id: u32, ptr: u32, len: u32| -> i32 {
270            let mem = match caller.data().memory {
271                Some(m) => m,
272                None => return -1,
273            };
274            let data = match read_guest_bytes(&mem, &caller, ptr, len) {
275                Ok(b) => b,
276                Err(_) => return -1,
277            };
278            let ws = caller.data().ws.clone();
279            let g = ws.lock().unwrap();
280            if let Some(ref state) = *g {
281                if state.send(id, data, false) {
282                    0
283                } else {
284                    -1
285                }
286            } else {
287                -1
288            }
289        },
290    )?;
291
292    // ── ws_send_binary ────────────────────────────────────────────────────
293    // api_ws_send_binary(id: u32, data_ptr: u32, data_len: u32) -> i32
294    linker.func_wrap(
295        "oxide",
296        "api_ws_send_binary",
297        |caller: Caller<'_, HostState>, id: u32, ptr: u32, len: u32| -> i32 {
298            let mem = match caller.data().memory {
299                Some(m) => m,
300                None => return -1,
301            };
302            let data = match read_guest_bytes(&mem, &caller, ptr, len) {
303                Ok(b) => b,
304                Err(_) => return -1,
305            };
306            let ws = caller.data().ws.clone();
307            let g = ws.lock().unwrap();
308            if let Some(ref state) = *g {
309                if state.send(id, data, true) {
310                    0
311                } else {
312                    -1
313                }
314            } else {
315                -1
316            }
317        },
318    )?;
319
320    // ── ws_recv ───────────────────────────────────────────────────────────
321    // api_ws_recv(id: u32, out_ptr: u32, out_cap: u32) -> i64
322    //
323    // Dequeues one frame and writes its bytes into guest memory at `out_ptr`.
324    // Return value encoding (same pattern as other APIs):
325    //   -1          : no message available (queue is empty)
326    //   >= 0        : low 32 bits = byte length written;
327    //                 bit 32 set   = frame is binary (bit 32 = 0 → text)
328    linker.func_wrap(
329        "oxide",
330        "api_ws_recv",
331        |mut caller: Caller<'_, HostState>, id: u32, out_ptr: u32, out_cap: u32| -> i64 {
332            let ws = caller.data().ws.clone();
333            let msg = {
334                let g = ws.lock().unwrap();
335                g.as_ref().and_then(|s| s.recv(id))
336            };
337            let msg = match msg {
338                Some(m) => m,
339                None => return -1,
340            };
341            let mem = match caller.data().memory {
342                Some(m) => m,
343                None => return -1,
344            };
345            let to_write = if msg.data.len() > out_cap as usize {
346                &msg.data[..out_cap as usize]
347            } else {
348                &msg.data
349            };
350            if write_guest_bytes(&mem, &mut caller, out_ptr, to_write).is_err() {
351                return -1;
352            }
353            let len = to_write.len() as i64;
354            if msg.is_binary {
355                len | (1i64 << 32)
356            } else {
357                len
358            }
359        },
360    )?;
361
362    // ── ws_ready_state ────────────────────────────────────────────────────
363    // api_ws_ready_state(id: u32) -> u32
364    //   0=CONNECTING  1=OPEN  2=CLOSING  3=CLOSED
365    linker.func_wrap(
366        "oxide",
367        "api_ws_ready_state",
368        |caller: Caller<'_, HostState>, id: u32| -> u32 {
369            let ws = caller.data().ws.clone();
370            let g = ws.lock().unwrap();
371            g.as_ref().map(|s| s.ready_state(id)).unwrap_or(WS_CLOSED)
372        },
373    )?;
374
375    // ── ws_close ──────────────────────────────────────────────────────────
376    // api_ws_close(id: u32) -> i32
377    //   Returns 1 if the close was initiated, 0 if the id is unknown.
378    linker.func_wrap(
379        "oxide",
380        "api_ws_close",
381        |caller: Caller<'_, HostState>, id: u32| -> i32 {
382            let ws = caller.data().ws.clone();
383            let mut g = ws.lock().unwrap();
384            if let Some(ref mut state) = *g {
385                if state.close(id) {
386                    1
387                } else {
388                    0
389                }
390            } else {
391                0
392            }
393        },
394    )?;
395
396    // ── ws_remove ─────────────────────────────────────────────────────────
397    // api_ws_remove(id: u32)
398    //   Frees host-side resources for a closed connection.
399    linker.func_wrap(
400        "oxide",
401        "api_ws_remove",
402        |caller: Caller<'_, HostState>, id: u32| {
403            let ws = caller.data().ws.clone();
404            let mut g = ws.lock().unwrap();
405            if let Some(ref mut state) = *g {
406                state.remove(id);
407            }
408        },
409    )?;
410
411    Ok(())
412}