diff --git a/api/src/index.ts b/api/src/index.ts index d8d6caa..a195b48 100644 --- a/api/src/index.ts +++ b/api/src/index.ts @@ -6,11 +6,11 @@ import { syncApp } from "./routes/sync"; import "./workflows/sync"; import "./workflows/party-analysis"; import "./dbos.ts"; +import { deviceClaimApp, deviceSocketApp } from "./routes/device-socket.ts"; import { partyApp } from "./routes/party"; import { partySocketApp, pubsub } from "./routes/party-socket"; import { quizRoutes } from "./routes/quiz.ts"; import { statsApp } from "./routes/stats.ts"; -import { deviceClaimApp, deviceSocketApp } from "./routes/device-socket.ts"; const app = new Elysia() .use(betterAuthElysia) diff --git a/api/src/party/numeric-question-generator.ts b/api/src/party/numeric-question-generator.ts index 5092ee9..e17d93e 100644 --- a/api/src/party/numeric-question-generator.ts +++ b/api/src/party/numeric-question-generator.ts @@ -1,6 +1,9 @@ import { and, eq, inArray } from "drizzle-orm"; import type { db } from "../db"; -import { topArtist as topArtistTable, topTrack as topTrackTable } from "../db/schema"; +import { + topArtist as topArtistTable, + topTrack as topTrackTable, +} from "../db/schema"; import type { Question } from "../party-types"; import { buildQuestionWindow, @@ -53,13 +56,20 @@ async function countTopTrackListeners({ }: BuildNumericQuestionInput): Promise { const trackName = analytics?.storyClusters?.[0]?.tracks?.[0]?.name; if (!trackName || members.length === 0) return null; - const dbTrack = await db.query.track.findFirst({ where: { name: trackName } }); + const dbTrack = await db.query.track.findFirst({ + where: { name: trackName }, + }); if (!dbTrack) return null; const memberIds = members.map((m) => m.userId); const entries = await db .select({ userId: topTrackTable.userId }) .from(topTrackTable) - .where(and(eq(topTrackTable.trackId, dbTrack.id), inArray(topTrackTable.userId, memberIds))); + .where( + and( + eq(topTrackTable.trackId, dbTrack.id), + inArray(topTrackTable.userId, memberIds), + ), + ); const correct = new Set(entries.map((e) => e.userId)).size; return { type: "numeric", @@ -77,13 +87,20 @@ async function countFavouriteArtistListeners({ }: BuildNumericQuestionInput): Promise { const artistName = analytics?.storyClusters?.[0]?.artists?.[0]?.name; if (!artistName || members.length === 0) return null; - const dbArtist = await db.query.artist.findFirst({ where: { name: artistName } }); + const dbArtist = await db.query.artist.findFirst({ + where: { name: artistName }, + }); if (!dbArtist) return null; const memberIds = members.map((m) => m.userId); const entries = await db .select({ userId: topArtistTable.userId }) .from(topArtistTable) - .where(and(eq(topArtistTable.artistId, dbArtist.id), inArray(topArtistTable.userId, memberIds))); + .where( + and( + eq(topArtistTable.artistId, dbArtist.id), + inArray(topArtistTable.userId, memberIds), + ), + ); const correct = new Set(entries.map((e) => e.userId)).size; return { type: "numeric", @@ -110,4 +127,4 @@ export async function buildNumericQuestion( const question = questions[input.index % questions.length] ?? questions[0]; if (!question) throw new Error("Question not found"); return buildQuestionWindow(question); -} \ No newline at end of file +} diff --git a/api/src/party/question-utils.ts b/api/src/party/question-utils.ts index 5c1dec4..1797ea3 100644 --- a/api/src/party/question-utils.ts +++ b/api/src/party/question-utils.ts @@ -242,7 +242,10 @@ export function buildMemberPairOptions( const pairs: string[] = [correctPair]; for (let i = 0; i < members.length; i++) { for (let j = i + 1; j < members.length; j++) { - const pair = `${members[i]!.name} & ${members[j]!.name}`; + const left = members[i]; + const right = members[j]; + if (!left || !right) continue; + const pair = `${left.name} & ${right.name}`; if (pair !== correctPair) pairs.push(pair); } } diff --git a/api/src/party/state.ts b/api/src/party/state.ts index df01141..838e253 100644 --- a/api/src/party/state.ts +++ b/api/src/party/state.ts @@ -1,7 +1,8 @@ import { eq } from "drizzle-orm"; import type { db as Db } from "../db"; import { party } from "../db/schema"; -import type { QuizState } from "../party-types"; +import type { PartySocketEvent, QuizState } from "../party-types"; +import { publishDeviceEventForUser } from "../routes/device-socket"; import { pubsub } from "../routes/party-socket"; export async function updatePartyData( @@ -32,6 +33,19 @@ export async function updatePartyData( }, members, }); + + const event: PartySocketEvent = { + type: "party_status", + party: { + ...partyObject, + data, + }, + members, + }; + for (const member of members) { + if (!member.userId) continue; + void publishDeviceEventForUser(member.userId, event); + } await db .update(party) .set({ diff --git a/api/src/routes/device-socket.ts b/api/src/routes/device-socket.ts index 419cc21..b98be62 100644 --- a/api/src/routes/device-socket.ts +++ b/api/src/routes/device-socket.ts @@ -1,10 +1,11 @@ +import { DBOS } from "@dbos-inc/dbos-sdk"; import { eq } from "drizzle-orm"; import Elysia from "elysia"; import { betterAuthElysia } from "../auth"; import { db } from "../db"; import { deviceConnection } from "../db/schema"; import { getMemberRecord } from "../party-data"; -import type { PartySocketEvent } from "../party-types"; +import type { PartySocketEvent, QuizState } from "../party-types"; import { pubsub, topic } from "./party-socket"; type DeviceSocketMessage = @@ -12,6 +13,10 @@ type DeviceSocketMessage = | { type: "hello" } | { type: "device_event"; deviceId: string; event: PartySocketEvent }; +type DeviceQuizResponsePayload = { + QuizResponse: number; +}; + let devProxySocket: WebSocket | null = null; function isDeviceMessage( @@ -25,6 +30,29 @@ function isDeviceMessage( ); } +function isDeviceQuizResponsePayload( + value: unknown, +): value is DeviceQuizResponsePayload { + return ( + typeof value === "object" && + value !== null && + "QuizResponse" in value && + Number.isInteger((value as DeviceQuizResponsePayload).QuizResponse) + ); +} + +function sendDeviceEvent(deviceId: string, event: PartySocketEvent) { + if (!devProxySocket || devProxySocket.readyState !== WebSocket.OPEN) return; + + devProxySocket.send( + JSON.stringify({ + type: "device_event", + deviceId, + event, + } satisfies DeviceSocketMessage), + ); +} + export async function claimDeviceForUser(deviceId: string, userId: string) { await db .insert(deviceConnection) @@ -55,29 +83,72 @@ export async function publishDeviceEventForUser( .where(eq(deviceConnection.userId, userId)); for (const device of devices) { - devProxySocket.send( - JSON.stringify({ - type: "device_event", - deviceId: device.id, - event, - } satisfies DeviceSocketMessage), - ); + sendDeviceEvent(device.id, event); } } async function forwardDevicePayload(deviceId: string, payload: unknown) { + if (!isDeviceQuizResponsePayload(payload)) { + sendDeviceEvent(deviceId, { + type: "error", + message: "Unsupported device payload.", + }); + return; + } + const device = await db .select() .from(deviceConnection) .where(eq(deviceConnection.id, deviceId)) .then((rows) => rows[0]); - if (!device) return; + if (!device) { + sendDeviceEvent(deviceId, { + type: "error", + message: "Device not linked to a user.", + }); + return; + } const membership = await getMemberRecord(db, device.userId); - if (!membership) return; + if (!membership) { + sendDeviceEvent(deviceId, { + type: "error", + message: "Device not linked to a party member.", + }); + return; + } - const payloadString = JSON.stringify(payload); - if (payloadString.length > 8_000) return; + const party = await db.query.party.findFirst({ + where: { id: membership.partyId }, + }); + if (!party) { + sendDeviceEvent(deviceId, { + type: "error", + message: "Party not found.", + }); + return; + } + const quizData = party.data as QuizState | null; + if (!quizData || quizData.status !== "running") { + sendDeviceEvent(deviceId, { + type: "error", + message: "Quiz not running.", + }); + return; + } + if (!quizData.workflowId) { + sendDeviceEvent(deviceId, { + type: "error", + message: "Workflow ID not found.", + }); + return; + } + + await DBOS.send( + quizData.workflowId, + { playerId: device.userId, selected: payload.QuizResponse }, + "quiz_responses", + ); pubsub.publish( topic.party(membership.partyId), diff --git a/api/src/routes/party-socket.ts b/api/src/routes/party-socket.ts index 9bc0330..213e322 100644 --- a/api/src/routes/party-socket.ts +++ b/api/src/routes/party-socket.ts @@ -29,7 +29,10 @@ export const pubsub = { }, }; -export async function broadcastQuizState(ws: any, partyId: string) { +export async function broadcastQuizState( + ws: { publish: (topic: string, message: string) => void }, + partyId: string, +) { const partyRecord = await db.query.party.findFirst({ where: { id: partyId }, }); diff --git a/api/src/test/factories.ts b/api/src/test/factories.ts index 93f343a..cff8946 100644 --- a/api/src/test/factories.ts +++ b/api/src/test/factories.ts @@ -331,7 +331,10 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{ const user = await createUser(`Diverse User ${i}`); userIds.push(user.id); - const genre = genres[i]!; + const genre = genres[i]; + if (!genre) { + throw new Error("Missing genre"); + } const artist = await createArtist(`artist-${i}-${randomUUID()}`, [ genre.id, ]); @@ -349,7 +352,10 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{ } // Create party with first user as host - const hostId = userIds[0]!; + const hostId = userIds[0]; + if (!hostId) { + throw new Error("Missing host user"); + } const party = await createParty(hostId); for (const userId of userIds) { @@ -358,8 +364,14 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{ // Give each user their unique track for (let i = 0; i < 3; i++) { - await addTopTrack(userIds[i]!, tracks[i]!.id, 1); - await addTopArtist(userIds[i]!, artists[i]!.id, 1); + const userId = userIds[i]; + const track = tracks[i]; + const artist = artists[i]; + if (!userId || !track || !artist) { + throw new Error("Missing seeded test data"); + } + await addTopTrack(userId, track.id, 1); + await addTopArtist(userId, artist.id, 1); } return { diff --git a/api/src/test/setup.ts b/api/src/test/setup.ts index 007809d..db63867 100644 --- a/api/src/test/setup.ts +++ b/api/src/test/setup.ts @@ -12,7 +12,7 @@ let pool: Pool | null = null; const getPool = (): Pool => { if (!pool) { - pool = new Pool({ connectionString: url! }); + pool = new Pool({ connectionString: url }); } return pool; }; diff --git a/api/src/workflows/__tests__/party-analysis.test.ts b/api/src/workflows/__tests__/party-analysis.test.ts index 1fe6142..c1ab22c 100644 --- a/api/src/workflows/__tests__/party-analysis.test.ts +++ b/api/src/workflows/__tests__/party-analysis.test.ts @@ -1,4 +1,4 @@ -/** biome-ignore-all lint/style/noNonNullAssertion: */ +/** biome-ignore-all lint/style/noNonNullAssertion: test setup uses controlled arrays */ import { DBOS } from "@dbos-inc/dbos-sdk"; import { describe, expect, it } from "vitest"; import { diff --git a/dev-proxy/index.ts b/dev-proxy/index.ts index f7ad30e..956863a 100644 --- a/dev-proxy/index.ts +++ b/dev-proxy/index.ts @@ -10,6 +10,49 @@ type DeviceMessage = { QuizResponse: number; } +type DeviceQuestionData = { + text: string; + points: number; + index: number; + q_type: "Choice" | { Numeric: { min: number; max: number } } +} + +type QuizQuestion = + | { + type: "choice"; + text: string; + points: number; + } + | { + type: "numeric"; + text: string; + points: number; + range: { min: number; max: number }; + }; + +type QuizState = { + status: "running" | "results"; + questionIndex: number; + currentQuestion: QuizQuestion | null; +}; + +type PartyStatusEvent = { + type: "party_status"; + party: { data?: QuizState } | null; +}; + +type QuizStateEvent = { + type: "quiz_state"; + quiz: QuizState; +}; + +type ErrorEvent = { + type: "error"; + message: string; +}; + +type PartySocketEvent = PartyStatusEvent | QuizStateEvent | ErrorEvent; + const sockets = new Map(); const socketIds = new WeakMap(); const apiSocket = new WebSocket("ws://localhost:4000/api/dev-socket/ws"); @@ -26,15 +69,33 @@ function registerSocket(socket: Socket, deviceId: string) { console.log("Registered", socket.remoteAddress, deviceId); } +function toDeviceQuestionData( + quizData: QuizState, +): DeviceQuestionData | null { + if (!quizData.currentQuestion) return null; + const question = quizData.currentQuestion; + const q_type = + question.type === "choice" + ? "Choice" + : { Numeric: { min: question.range.min, max: question.range.max } }; + + return { + text: question.text, + points: question.points, + index: quizData.questionIndex, + q_type, + }; +} + const listener = Bun.listen({ port: 7070, hostname: "0.0.0.0", - socket: { - open(socket) { + socket: { + open(socket) { socket.setKeepAlive(true); console.log("Connection", socket.remoteAddress, socket.remotePort); - }, - data(socket, buf) { + }, + data(socket, buf) { const raw = new TextDecoder().decode(buf).trim(); let data: DeviceMessage; try { @@ -50,16 +111,17 @@ const listener = Bun.listen({ return; } if ("QuizResponse" in data) { - + const deviceId = socketDeviceId(socket); + if (!deviceId) return; + apiSocket?.send( + JSON.stringify({ + type: "device_message", + deviceId, + payload: { QuizResponse: data.QuizResponse }, + }), + ); + return; } - - // apiSocket?.send( - // JSON.stringify({ - // type: "device_message", - // deviceId: currentDeviceId, - // payload: raw, - // }), - // ); }, close(socket) { console.log("Connection", socket.remoteAddress); @@ -82,6 +144,30 @@ apiSocket.onmessage = (e) => { if (message.type !== "device_event") return; const socket = sockets.get(message.deviceId); if (!socket) return; + const event = message.event as PartySocketEvent; + if (event.type === "error") { + socket.write(`${JSON.stringify({ Error: event.message })}\n`); + return; + } + + if (event.type === "party_status") { + const quizData = event.party?.data ?? null; + if (!quizData) return; + const question = toDeviceQuestionData(quizData); + socket.write( + `${JSON.stringify({ Question: question, Status: quizData.status })}\n`, + ); + return; + } + + if (event.type === "quiz_state") { + const question = toDeviceQuestionData(event.quiz); + socket.write( + `${JSON.stringify({ Question: question, Status: event.quiz.status })}\n`, + ); + return; + } + socket.write(`${JSON.stringify(message.event)}\n`); }; diff --git a/esp32/src/main.rs b/esp32/src/main.rs index 134f3f9..9d057c1 100644 --- a/esp32/src/main.rs +++ b/esp32/src/main.rs @@ -7,7 +7,7 @@ use core::str::FromStr; use embassy_executor::Spawner; use embassy_futures::select::{Either, select}; -use embassy_net::tcp::{TcpReader, TcpWriter}; +use embassy_net::tcp::{State, TcpReader, TcpWriter}; use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex; use embassy_sync::mutex::Mutex; use embassy_sync::signal::Signal; @@ -32,6 +32,8 @@ mod screen; pub use input::ANGLE; +use crate::screen::overwrite_lcd; + const WIFI_NETWORK: &str = "flamme"; const WIFI_PASSWORD: &str = "12345678"; const TARGET_IP: &str = "84.238.32.253"; @@ -69,6 +71,13 @@ struct QuestionDataNet<'a> { index: usize, } +#[derive(Deserialize)] +enum ProxyOutput<'a> { + Question(QuestionDataNet<'a>), + Results, + Error(&'a str), +} + impl<'a> From> for QuestionData { fn from(value: QuestionDataNet<'a>) -> Self { QuestionData { @@ -88,6 +97,14 @@ struct WheelData { accumulated: i32, } +#[derive(Clone, Copy)] +enum MainState { + Loading, + Question, + Results, +} + +static MAIN_STATE: Mutex = Mutex::new(MainState::Loading); static QUESTION: Mutex> = Mutex::new(None); static QUESTION_UPDATE: Signal = Signal::new(); static WHEEL_VALUE: Mutex = Mutex::new(WheelData { @@ -146,23 +163,32 @@ pub async fn tcp_read_loop( accumulated: 0, }; if let Some(last) = str.lines().last() { - let Ok(data) = serde_json::from_str::(last) else { + let Ok(data) = serde_json::from_str::(last) else { continue; }; - let data: QuestionData = data.into(); - match data.q_type { - QuestionType::Numeric { min, max } => { - future_wheel.max = max; - future_wheel.min = min; - future_wheel.value = (min + max) / 2; + match data { + ProxyOutput::Question(data) => { + let data: QuestionData = data.into(); + match data.q_type { + QuestionType::Numeric { min, max } => { + future_wheel.max = max; + future_wheel.min = min; + future_wheel.value = (min + max) / 2; + } + _ => {} + }; + question_data = Some(data); } - _ => {} - }; - question_data = Some(data); + ProxyOutput::Results => { + *MAIN_STATE.lock().await = MainState::Results; + } + ProxyOutput::Error(e) => {} + } } if let Some(question_data) = question_data { *QUESTION.lock().await = Some(question_data); + *MAIN_STATE.lock().await = MainState::Question; *WHEEL_VALUE.lock().await = future_wheel; QUESTION_UPDATE.signal(()); } @@ -231,6 +257,19 @@ pub async fn main_loop() { println!("Main loop started"); loop { embassy_time::Timer::after_millis(50).await; + let state = *MAIN_STATE.lock().await; + + match state { + MainState::Loading => { + continue; + } + MainState::Question => {} + MainState::Results => { + overwrite_lcd("Results", "").await; + continue; + } + } + let wheel = *WHEEL_VALUE.lock().await; let question = QUESTION.lock().await; let Some(question) = question.as_ref() else {