diff --git a/api/src/db/schema.ts b/api/src/db/schema.ts index 419102a..48f4dc5 100644 --- a/api/src/db/schema.ts +++ b/api/src/db/schema.ts @@ -375,8 +375,15 @@ export const relations = defineRelations( track, trackArtist, user, + deviceConnection, }, (r) => ({ + deviceConnection: { + user: r.one.user({ + from: r.deviceConnection.userId, + to: r.user.id, + }), + }, artist: { artistGenres: r.many.artistGenre(), artistImages: r.many.artistImage(), @@ -601,6 +608,7 @@ export const relations = defineRelations( from: r.user.id, to: r.party.hostId, }), + deviceConnection: r.many.deviceConnection(), }, }), ); diff --git a/api/src/routes/device-socket.ts b/api/src/routes/device-socket.ts index 102dde7..6190c6a 100644 --- a/api/src/routes/device-socket.ts +++ b/api/src/routes/device-socket.ts @@ -10,7 +10,7 @@ import { pubsub, topic } from "./party-socket"; type DeviceSocketMessage = | { type: "device_message"; deviceId: string; payload: unknown } - | { type: "device_connected"; deviceId: string } + | { type: "device_status_request"; deviceId: string } | { type: "hello" } | { type: "device_event"; deviceId: string; event: DeviceProxyEvent }; @@ -23,7 +23,22 @@ type DeviceQuizResponsePayload = { QuizResponse: number; }; -let devProxySocket: WebSocket | null = null; +type DeviceConnectionRecord = typeof deviceConnection.$inferSelect; + +type DevProxySocket = { + send: (message: string) => unknown; +}; + +let devProxySocket: DevProxySocket | null = null; + +function withTimeout(promise: Promise, timeoutMs: number, label: string) { + return Promise.race([ + promise, + new Promise((_, reject) => { + setTimeout(() => reject(new Error(`${label} timed out`)), timeoutMs); + }), + ]); +} function isDeviceMessage( value: unknown, @@ -36,13 +51,13 @@ function isDeviceMessage( ); } -function isDeviceConnectedMessage( +function isDeviceStatusRequestMessage( value: unknown, -): value is Extract { +): value is Extract { return ( typeof value === "object" && value !== null && - (value as { type?: unknown }).type === "device_connected" && + (value as { type?: unknown }).type === "device_status_request" && typeof (value as { deviceId?: unknown }).deviceId === "string" ); } @@ -59,29 +74,58 @@ function isDeviceQuizResponsePayload( } function sendDeviceEvent(deviceId: string, event: DeviceProxyEvent) { - if (!devProxySocket || devProxySocket.readyState !== WebSocket.OPEN) return; + if (!devProxySocket) { + console.log("[device-socket] no dev proxy for event", deviceId, event.type); + return; + } - devProxySocket.send( - JSON.stringify({ - type: "device_event", - deviceId, - event, - } satisfies DeviceSocketMessage), - ); + try { + console.log("[device-socket] sending event", deviceId, event.type); + devProxySocket.send( + JSON.stringify({ + type: "device_event", + deviceId, + event, + } satisfies DeviceSocketMessage), + ); + } catch (error) { + console.error("[device-socket] failed to send event", error); + devProxySocket = null; + } } async function syncDeviceConnectionStatus(deviceId: string) { - const device = await db - .select() - .from(deviceConnection) - .where(eq(deviceConnection.id, deviceId)) - .then((rows) => rows[0]); - - if (!device) { + console.log("[device-socket] status request", deviceId); + let device: DeviceConnectionRecord | undefined; + try { + console.log("[device-socket] lookup device start", deviceId); + device = await withTimeout( + db + .select() + .from(deviceConnection) + .where(eq(deviceConnection.id, deviceId)) + .then((rows) => rows[0]), + 2_000, + `device lookup ${deviceId}`, + ); + console.log( + "[device-socket] lookup device result", + deviceId, + device ? "claimed" : "missing", + ); + } catch (error) { + console.error("[device-socket] lookup device failed", deviceId, error); sendDeviceEvent(deviceId, { type: "device_connect_required" }); return; } + if (!device) { + console.log("[device-socket] device unclaimed", deviceId); + sendDeviceEvent(deviceId, { type: "device_connect_required" }); + return; + } + + console.log("[device-socket] device claimed", deviceId, device.userId); await db .update(deviceConnection) .set({ lastSeen: new Date() }) @@ -111,7 +155,7 @@ export async function publishDeviceEventForUser( userId: string, event: PartySocketEvent, ) { - if (!devProxySocket || devProxySocket.readyState !== WebSocket.OPEN) return; + if (!devProxySocket) return; const devices = await db .select() @@ -202,31 +246,27 @@ export const deviceSocketApp = new Elysia().group("/dev-socket", (app) => .get("/test", () => ({ ok: 1 })) .ws("/ws", { open(ws) { - devProxySocket = ws as unknown as WebSocket; + console.log("[device-socket] dev proxy connected"); + devProxySocket = ws; ws.send( JSON.stringify({ type: "hello" } satisfies DeviceSocketMessage), ); }, - message: async (_ws, message) => { - if (typeof message !== "string") return; + message: async (_ws, message: DeviceSocketMessage) => { + if (typeof message !== "object") return; + console.log("[device-socket] received", message.type); - let parsed: DeviceSocketMessage; - try { - parsed = JSON.parse(message) as DeviceSocketMessage; - } catch { + if (isDeviceStatusRequestMessage(message)) { + await syncDeviceConnectionStatus(message.deviceId); return; } - if (isDeviceConnectedMessage(parsed)) { - await syncDeviceConnectionStatus(parsed.deviceId); - return; - } - - if (isDeviceMessage(parsed)) { - await forwardDevicePayload(parsed.deviceId, parsed.payload); + if (isDeviceMessage(message)) { + await forwardDevicePayload(message.deviceId, message.payload); } }, close() { + console.log("[device-socket] dev proxy disconnected"); if (devProxySocket === null) return; devProxySocket = null; }, diff --git a/dev-proxy/index.ts b/dev-proxy/index.ts index 9973615..9ed2f0c 100644 --- a/dev-proxy/index.ts +++ b/dev-proxy/index.ts @@ -27,7 +27,7 @@ type ProxyOutput = | { Error: string }; type ApiMessage = - | { type: "device_connected"; deviceId: string } + | { type: "device_status_request"; deviceId: string } | { type: "device_message"; deviceId: string; payload: unknown }; type QuizQuestion = @@ -76,7 +76,10 @@ type PartySocketEvent = const sockets = new Map(); const socketIds = new WeakMap(); -const apiSocket = new WebSocket("ws://localhost:4000/api/dev-socket/ws"); +const API_SOCKET_URL = "ws://localhost:4000/api/dev-socket/ws"; +let apiSocket: WebSocket | null = null; +let apiReconnectTimer: ReturnType | null = null; +const pendingDeviceStatus = new Set(); function socketDeviceId(socket: Socket) { return socketIds.get(socket); @@ -95,11 +98,77 @@ function writeProxyOutput(socket: Socket, output: ProxyOutput) { } function sendApiMessage(message: ApiMessage) { - if (apiSocket.readyState !== WebSocket.OPEN) return false; + if (!apiSocket || apiSocket.readyState !== WebSocket.OPEN) return false; + console.log("API send", message.type, message.deviceId); apiSocket.send(JSON.stringify(message)); return true; } +function requestDeviceStatus(deviceId: string) { + pendingDeviceStatus.add(deviceId); + if (sendApiMessage({ type: "device_status_request", deviceId })) { + console.log("Requested device status", deviceId); + pendingDeviceStatus.delete(deviceId); + return; + } + console.log("Queued device status request", deviceId); +} + +function flushPendingDeviceStatus() { + for (const deviceId of pendingDeviceStatus) { + if ( + sockets.has(deviceId) && + sendApiMessage({ type: "device_status_request", deviceId }) + ) { + console.log("Flushed device status request", deviceId); + pendingDeviceStatus.delete(deviceId); + } + } +} + +function disconnectDeviceClients(reason: string) { + console.log("Disconnecting device clients", reason, sockets.size); + for (const socket of sockets.values()) { + socket.end(); + } + sockets.clear(); + pendingDeviceStatus.clear(); +} + +function scheduleApiReconnect() { + if (apiReconnectTimer) return; + apiReconnectTimer = setTimeout(() => { + apiReconnectTimer = null; + connectApiSocket(); + }, 500); +} + +function connectApiSocket() { + if ( + apiSocket?.readyState === WebSocket.OPEN || + apiSocket?.readyState === WebSocket.CONNECTING + ) { + return; + } + + console.log("Connecting to API device socket"); + apiSocket = new WebSocket(API_SOCKET_URL); + apiSocket.onmessage = handleApiMessage; + apiSocket.onerror = (error) => { + console.error("API device socket error", error); + }; + apiSocket.onclose = () => { + console.log("API device socket closed; reconnecting"); + apiSocket = null; + disconnectDeviceClients("api socket closed"); + scheduleApiReconnect(); + }; + apiSocket.onopen = () => { + console.log("Connected to API device socket"); + flushPendingDeviceStatus(); + }; +} + function toDeviceQuestionData(quizData: QuizState): DeviceQuestionData | null { if (!quizData.currentQuestion) return null; const question = quizData.currentQuestion; @@ -137,13 +206,13 @@ const listener = Bun.listen({ if ("DeviceId" in data) { registerSocket(socket, data.DeviceId); - if ( - !sendApiMessage({ type: "device_connected", deviceId: data.DeviceId }) - ) { - writeProxyOutput(socket, { - Error: "API device socket not connected.", - }); - } + console.log( + "Requesting device status", + data.DeviceId, + "apiState", + apiSocket?.readyState, + ); + requestDeviceStatus(data.DeviceId); return; } if ("QuizResponse" in data) { @@ -162,12 +231,13 @@ const listener = Bun.listen({ const deviceId = socketDeviceId(socket); if (deviceId && sockets.get(deviceId) === socket) { sockets.delete(deviceId); + pendingDeviceStatus.delete(deviceId); } }, }, }); -apiSocket.onmessage = (e) => { +function handleApiMessage(e: MessageEvent) { let message: ApiEnvelope; try { message = JSON.parse(e.data) as ApiEnvelope; @@ -175,16 +245,23 @@ apiSocket.onmessage = (e) => { return; } + console.log("API recv", message.type); if (message.type !== "device_event") return; const socket = sockets.get(message.deviceId); - if (!socket) return; + if (!socket) { + console.log("No TCP socket for API event", message.deviceId); + return; + } const event = message.event as PartySocketEvent; + console.log("API device event", message.deviceId, event.type); if (event.type === "device_connect_required") { + console.log("Writing connect prompt", message.deviceId); writeProxyOutput(socket, { ConnectPrompt: message.deviceId }); return; } if (event.type === "device_connected") { + console.log("Writing waiting-for-party", message.deviceId); writeProxyOutput(socket, "WaitingForParty"); return; } @@ -217,14 +294,8 @@ apiSocket.onmessage = (e) => { } writeProxyOutput(socket, { Error: "Unsupported proxy event." }); -}; +} -apiSocket.onerror = (error) => { - console.error(error); -}; - -apiSocket.onopen = () => { - console.log("Connected to API device socket"); -}; +connectApiSocket(); console.log(`Started on :${listener.port}`); diff --git a/device-state/src/lib.rs b/device-state/src/lib.rs index 0789c6f..a50d739 100644 --- a/device-state/src/lib.rs +++ b/device-state/src/lib.rs @@ -16,6 +16,7 @@ const DOT: char = char::from_u32(0b1010_0101).unwrap(); #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum ViewState { Loading, + Reconnecting, ConnectPrompt, WaitingForParty, Question, @@ -105,12 +106,19 @@ impl DeviceState { } pub fn reset(&mut self) { + self.view = ViewState::Loading; self.question = None; self.wheel = WheelData::empty(); self.last_index = 0; self.title_offset = 0; } + pub fn reconnecting(&mut self) { + self.question = None; + self.wheel = WheelData::empty(); + self.view = ViewState::Reconnecting; + } + pub fn view_state(&self) -> ViewState { self.view } @@ -178,6 +186,20 @@ impl DeviceState { } pub fn render_lines(&mut self) -> Option<(OwnedStr<16>, OwnedStr<16>)> { + if self.view == ViewState::Loading { + return Some(( + OwnedStr::from_str("Connecting").unwrap(), + OwnedStr::from_str("Please wait").unwrap(), + )); + } + + if self.view == ViewState::Reconnecting { + return Some(( + OwnedStr::from_str("Reconnecting").unwrap(), + OwnedStr::from_str("Please wait").unwrap(), + )); + } + if self.view == ViewState::ConnectPrompt { let device_id = self.device_id.as_ref()?; let mut display_id: OwnedStr<16> = OwnedStr::new(); @@ -396,6 +418,28 @@ mod tests { assert_eq!(line2.as_str(), "Waiting party"); } + #[test] + fn renders_reconnecting_state() { + let mut state = DeviceState::new(); + + state.reconnecting(); + + assert_eq!(state.view_state(), ViewState::Reconnecting); + let (line1, line2) = state.render_lines().unwrap(); + assert_eq!(line1.as_str(), "Reconnecting"); + assert_eq!(line2.as_str(), "Please wait"); + } + + #[test] + fn renders_loading_state() { + let mut state = DeviceState::new(); + + let (line1, line2) = state.render_lines().unwrap(); + + assert_eq!(line1.as_str(), "Connecting"); + assert_eq!(line2.as_str(), "Please wait"); + } + #[test] fn wraps_forward_across_zero() { assert_eq!(wheel_delta(4090, 5, false), 11); diff --git a/esp32/src/main.rs b/esp32/src/main.rs index a84cbc3..85682ef 100644 --- a/esp32/src/main.rs +++ b/esp32/src/main.rs @@ -49,6 +49,11 @@ pub async fn reset_state() { state.reset(); } +pub async fn reconnecting_state() { + let mut state = STATE.lock().await; + state.reconnecting(); +} + #[panic_handler] fn panic(info: &core::panic::PanicInfo) -> ! { println!("PANIC! {:?}", info); diff --git a/esp32/src/net.rs b/esp32/src/net.rs index f647807..1f938f3 100644 --- a/esp32/src/net.rs +++ b/esp32/src/net.rs @@ -14,8 +14,9 @@ use esp_radio::wifi::sta::StationConfig; use esp_radio::wifi::{Config, ControllerConfig, scan::ScanConfig}; use crate::screen::overwrite_lcd; -use crate::{buffer::wait_for_config, tcp_read_loop, tcp_write_loop}; use crate::{WIFI_NETWORK, WIFI_PASSWORD}; +use crate::{buffer::wait_for_config, tcp_read_loop, tcp_write_loop}; +use crate::{reconnecting_state, reset_state}; pub struct NetworkConfig<'a> { pub wifi: WIFI<'a>, @@ -129,10 +130,12 @@ pub async fn network_setup_task( .await { println!("tcp connect error: {:?}", e); + reconnecting_state().await; overwrite_lcd("TCP error", &format!("{}", e)).await; Timer::after_millis(1000).await; continue; } + reset_state().await; overwrite_lcd("Connected", "").await; let cancel = Signal::::new(); @@ -149,10 +152,11 @@ pub async fn network_setup_task( if !stack.is_config_up() { println!("wifi down, reconnecting wifi"); + reconnecting_state().await; break; } - overwrite_lcd("Connection close", "").await; + reconnecting_state().await; Timer::after_millis(500).await; } } diff --git a/simulator/src/main.rs b/simulator/src/main.rs index 2d9e199..4f52b43 100644 --- a/simulator/src/main.rs +++ b/simulator/src/main.rs @@ -4,11 +4,11 @@ use std::sync::{Arc, Mutex}; use std::thread; use clap::Parser; -use crossterm::event::{Event, KeyCode, KeyEventKind, KeyModifiers}; -use crossterm::terminal::{disable_raw_mode, enable_raw_mode}; use crossterm::cursor::Show; -use crossterm::queue; use crossterm::event; +use crossterm::event::{Event, KeyCode, KeyEventKind, KeyModifiers}; +use crossterm::queue; +use crossterm::terminal::{disable_raw_mode, enable_raw_mode}; use device_state::{DeviceState, WriteType, apply_wheel_delta}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -88,6 +88,7 @@ async fn main() -> io::Result<()> { Ok(stream) => break stream, Err(err) => { log_error(&format!("connect error: {err}")); + state.lock().unwrap().reconnecting(); tokio::select! { _ = &mut sigint => { log_error("received SIGINT"); @@ -107,10 +108,12 @@ async fn main() -> io::Result<()> { } }; let (mut read, mut write) = stream.into_split(); + state.lock().unwrap().reset(); let device_id = device_state::serialize_write(&WriteType::DeviceId(DEVICE_ID)).unwrap(); if write.write_all(device_id.as_bytes()).await.is_err() { log_error("failed to send device id"); + state.lock().unwrap().reconnecting(); tokio::time::sleep(Duration::from_secs(1)).await; continue; } @@ -184,6 +187,7 @@ async fn main() -> io::Result<()> { } log_error("reconnecting in 500ms"); + state.lock().unwrap().reconnecting(); tokio::select! { _ = &mut sigint => { log_error("received SIGINT");