1use std::collections::{HashMap, VecDeque};
11use std::sync::{Arc, Mutex};
12
13use anyhow::Result;
14use futures_util::{SinkExt, StreamExt};
15use tokio::runtime::Runtime;
16use tokio::sync::mpsc;
17use tokio_tungstenite::connect_async;
18use tokio_tungstenite::tungstenite::Message;
19use wasmtime::{Caller, Linker};
20
21use crate::capabilities::{
22 console_log, read_guest_bytes, read_guest_string, write_guest_bytes, ConsoleLevel, HostState,
23};
24
/// Ready state: the handshake has not completed yet.
///
/// The four `WS_*` values are consecutive, mirroring the numbering of the DOM
/// `WebSocket.readyState` constants.
pub const WS_CONNECTING: u32 = 0;
/// Ready state: the connection is established and frames can flow.
pub const WS_OPEN: u32 = WS_CONNECTING + 1;
/// Ready state: a close has been requested but the connection is not yet down.
pub const WS_CLOSING: u32 = WS_OPEN + 1;
/// Ready state: the connection is finished (or never opened).
pub const WS_CLOSED: u32 = WS_CLOSING + 1;
35
/// One message received from the peer, queued until the guest reads it.
struct RecvMsg {
    /// Payload bytes; for text frames these are the UTF-8 bytes of the text.
    data: Vec<u8>,
    /// `true` when the frame was binary, `false` when it was text.
    is_binary: bool,
}
41
/// Commands passed from the host functions to a connection's background task.
enum SendMsg {
    /// Send a close frame, then stop the task.
    Close,
    /// Transmit `data` as a binary frame (`is_binary`) or as a text frame.
    Data {
        is_binary: bool,
        data: Vec<u8>,
    },
}
47
48struct WsConn {
50 send_tx: mpsc::UnboundedSender<SendMsg>,
52 recv_queue: Arc<Mutex<VecDeque<RecvMsg>>>,
54 ready_state: Arc<Mutex<u32>>,
56}
57
58pub struct WsState {
60 runtime: Runtime,
61 connections: HashMap<u32, WsConn>,
62 next_id: u32,
63}
64
65impl WsState {
66 pub fn new() -> Option<Self> {
67 let runtime = Runtime::new().ok()?;
68 Some(Self {
69 runtime,
70 connections: HashMap::new(),
71 next_id: 1,
72 })
73 }
74
75 fn alloc_id(&mut self) -> u32 {
76 let id = self.next_id;
77 self.next_id = self.next_id.wrapping_add(1).max(1);
78 id
79 }
80
81 fn connect(&mut self, url: &str) -> u32 {
87 let url = url.to_string();
88 let id = self.alloc_id();
89
90 let ready_state = Arc::new(Mutex::new(WS_CONNECTING));
91 let recv_queue: Arc<Mutex<VecDeque<RecvMsg>>> = Arc::new(Mutex::new(VecDeque::new()));
92 let (send_tx, mut send_rx) = mpsc::unbounded_channel::<SendMsg>();
93
94 let rs = ready_state.clone();
95 let rq = recv_queue.clone();
96
97 self.runtime.spawn(async move {
98 let ws_stream = match connect_async(&url).await {
99 Ok((stream, _)) => stream,
100 Err(_) => {
101 *rs.lock().unwrap() = WS_CLOSED;
102 return;
103 }
104 };
105
106 *rs.lock().unwrap() = WS_OPEN;
107 let (mut writer, mut reader) = ws_stream.split();
108
109 loop {
111 tokio::select! {
112 msg = reader.next() => {
114 match msg {
115 Some(Ok(Message::Text(text))) => {
116 rq.lock().unwrap().push_back(RecvMsg {
117 is_binary: false,
118 data: text.into_bytes(),
119 });
120 }
121 Some(Ok(Message::Binary(bytes))) => {
122 rq.lock().unwrap().push_back(RecvMsg {
123 is_binary: true,
124 data: bytes.to_vec(),
125 });
126 }
127 Some(Ok(Message::Close(_))) | None => {
128 *rs.lock().unwrap() = WS_CLOSED;
129 break;
130 }
131 Some(Ok(Message::Ping(payload))) => {
132 let _ = writer.send(Message::Pong(payload)).await;
133 }
134 _ => {}
135 }
136 }
137 outgoing = send_rx.recv() => {
139 match outgoing {
140 Some(SendMsg::Data { is_binary, data }) => {
141 let msg = if is_binary {
142 Message::Binary(data)
143 } else {
144 match String::from_utf8(data) {
145 Ok(text) => Message::Text(text),
146 Err(e) => Message::Binary(e.into_bytes()),
147 }
148 };
149 if writer.send(msg).await.is_err() {
150 *rs.lock().unwrap() = WS_CLOSED;
151 break;
152 }
153 }
154 Some(SendMsg::Close) => {
155 *rs.lock().unwrap() = WS_CLOSING;
156 let _ = writer.send(Message::Close(None)).await;
157 *rs.lock().unwrap() = WS_CLOSED;
158 break;
159 }
160 None => {
161 *rs.lock().unwrap() = WS_CLOSED;
163 break;
164 }
165 }
166 }
167 }
168 }
169 });
170
171 self.connections.insert(
172 id,
173 WsConn {
174 send_tx,
175 recv_queue,
176 ready_state,
177 },
178 );
179
180 id
181 }
182
183 fn send(&self, id: u32, data: Vec<u8>, is_binary: bool) -> bool {
184 if let Some(conn) = self.connections.get(&id) {
185 conn.send_tx.send(SendMsg::Data { is_binary, data }).is_ok()
186 } else {
187 false
188 }
189 }
190
191 fn recv(&self, id: u32) -> Option<RecvMsg> {
192 self.connections
193 .get(&id)?
194 .recv_queue
195 .lock()
196 .unwrap()
197 .pop_front()
198 }
199
200 fn ready_state(&self, id: u32) -> u32 {
201 self.connections
202 .get(&id)
203 .map(|c| *c.ready_state.lock().unwrap())
204 .unwrap_or(WS_CLOSED)
205 }
206
207 fn close(&mut self, id: u32) -> bool {
208 if let Some(conn) = self.connections.get(&id) {
209 *conn.ready_state.lock().unwrap() = WS_CLOSING;
210 let _ = conn.send_tx.send(SendMsg::Close);
211 true
212 } else {
213 false
214 }
215 }
216
217 fn remove(&mut self, id: u32) {
218 self.connections.remove(&id);
219 }
220}
221
222fn ensure_ws(state: &Arc<Mutex<Option<WsState>>>) -> bool {
223 let mut g = state.lock().unwrap();
224 if g.is_none() {
225 *g = WsState::new();
226 }
227 g.is_some()
228}
229
230pub fn register_ws_functions(linker: &mut Linker<HostState>) -> Result<()> {
232 linker.func_wrap(
236 "oxide",
237 "api_ws_connect",
238 |caller: Caller<'_, HostState>, url_ptr: u32, url_len: u32| -> u32 {
239 let console = caller.data().console.clone();
240 let ws = caller.data().ws.clone();
241 if !ensure_ws(&ws) {
242 console_log(&console, ConsoleLevel::Error, "[WS] Init failed".into());
243 return 0;
244 }
245 let mem = match caller.data().memory {
246 Some(m) => m,
247 None => return 0,
248 };
249 let url = match read_guest_string(&mem, &caller, url_ptr, url_len) {
250 Ok(s) => s,
251 Err(_) => return 0,
252 };
253 let id = ws.lock().unwrap().as_mut().unwrap().connect(&url);
254 console_log(
255 &console,
256 ConsoleLevel::Log,
257 format!("[WS] Connecting to {url} (id={id})"),
258 );
259 id
260 },
261 )?;
262
263 linker.func_wrap(
267 "oxide",
268 "api_ws_send_text",
269 |caller: Caller<'_, HostState>, id: u32, ptr: u32, len: u32| -> i32 {
270 let mem = match caller.data().memory {
271 Some(m) => m,
272 None => return -1,
273 };
274 let data = match read_guest_bytes(&mem, &caller, ptr, len) {
275 Ok(b) => b,
276 Err(_) => return -1,
277 };
278 let ws = caller.data().ws.clone();
279 let g = ws.lock().unwrap();
280 if let Some(ref state) = *g {
281 if state.send(id, data, false) {
282 0
283 } else {
284 -1
285 }
286 } else {
287 -1
288 }
289 },
290 )?;
291
292 linker.func_wrap(
295 "oxide",
296 "api_ws_send_binary",
297 |caller: Caller<'_, HostState>, id: u32, ptr: u32, len: u32| -> i32 {
298 let mem = match caller.data().memory {
299 Some(m) => m,
300 None => return -1,
301 };
302 let data = match read_guest_bytes(&mem, &caller, ptr, len) {
303 Ok(b) => b,
304 Err(_) => return -1,
305 };
306 let ws = caller.data().ws.clone();
307 let g = ws.lock().unwrap();
308 if let Some(ref state) = *g {
309 if state.send(id, data, true) {
310 0
311 } else {
312 -1
313 }
314 } else {
315 -1
316 }
317 },
318 )?;
319
320 linker.func_wrap(
329 "oxide",
330 "api_ws_recv",
331 |mut caller: Caller<'_, HostState>, id: u32, out_ptr: u32, out_cap: u32| -> i64 {
332 let ws = caller.data().ws.clone();
333 let msg = {
334 let g = ws.lock().unwrap();
335 g.as_ref().and_then(|s| s.recv(id))
336 };
337 let msg = match msg {
338 Some(m) => m,
339 None => return -1,
340 };
341 let mem = match caller.data().memory {
342 Some(m) => m,
343 None => return -1,
344 };
345 let to_write = if msg.data.len() > out_cap as usize {
346 &msg.data[..out_cap as usize]
347 } else {
348 &msg.data
349 };
350 if write_guest_bytes(&mem, &mut caller, out_ptr, to_write).is_err() {
351 return -1;
352 }
353 let len = to_write.len() as i64;
354 if msg.is_binary {
355 len | (1i64 << 32)
356 } else {
357 len
358 }
359 },
360 )?;
361
362 linker.func_wrap(
366 "oxide",
367 "api_ws_ready_state",
368 |caller: Caller<'_, HostState>, id: u32| -> u32 {
369 let ws = caller.data().ws.clone();
370 let g = ws.lock().unwrap();
371 g.as_ref().map(|s| s.ready_state(id)).unwrap_or(WS_CLOSED)
372 },
373 )?;
374
375 linker.func_wrap(
379 "oxide",
380 "api_ws_close",
381 |caller: Caller<'_, HostState>, id: u32| -> i32 {
382 let ws = caller.data().ws.clone();
383 let mut g = ws.lock().unwrap();
384 if let Some(ref mut state) = *g {
385 if state.close(id) {
386 1
387 } else {
388 0
389 }
390 } else {
391 0
392 }
393 },
394 )?;
395
396 linker.func_wrap(
400 "oxide",
401 "api_ws_remove",
402 |caller: Caller<'_, HostState>, id: u32| {
403 let ws = caller.data().ws.clone();
404 let mut g = ws.lock().unwrap();
405 if let Some(ref mut state) = *g {
406 state.remove(id);
407 }
408 },
409 )?;
410
411 Ok(())
412}