Compare commits

..

2 commits

Author SHA1 Message Date
Daniel Bulant
b14ac917d6
small improvements 2026-05-12 16:39:01 +02:00
Daniel Bulant
2bcdb34515
add serde 2026-05-12 15:55:51 +02:00
14 changed files with 401 additions and 123 deletions

View file

@ -6,11 +6,11 @@ import { syncApp } from "./routes/sync";
import "./workflows/sync";
import "./workflows/party-analysis";
import "./dbos.ts";
import { deviceClaimApp, deviceSocketApp } from "./routes/device-socket.ts";
import { partyApp } from "./routes/party";
import { partySocketApp, pubsub } from "./routes/party-socket";
import { quizRoutes } from "./routes/quiz.ts";
import { statsApp } from "./routes/stats.ts";
import { deviceClaimApp, deviceSocketApp } from "./routes/device-socket.ts";
const app = new Elysia()
.use(betterAuthElysia)

View file

@ -1,6 +1,9 @@
import { and, eq, inArray } from "drizzle-orm";
import type { db } from "../db";
import { topArtist as topArtistTable, topTrack as topTrackTable } from "../db/schema";
import {
topArtist as topArtistTable,
topTrack as topTrackTable,
} from "../db/schema";
import type { Question } from "../party-types";
import {
buildQuestionWindow,
@ -53,13 +56,20 @@ async function countTopTrackListeners({
}: BuildNumericQuestionInput): Promise<NumericQuestion | null> {
const trackName = analytics?.storyClusters?.[0]?.tracks?.[0]?.name;
if (!trackName || members.length === 0) return null;
const dbTrack = await db.query.track.findFirst({ where: { name: trackName } });
const dbTrack = await db.query.track.findFirst({
where: { name: trackName },
});
if (!dbTrack) return null;
const memberIds = members.map((m) => m.userId);
const entries = await db
.select({ userId: topTrackTable.userId })
.from(topTrackTable)
.where(and(eq(topTrackTable.trackId, dbTrack.id), inArray(topTrackTable.userId, memberIds)));
.where(
and(
eq(topTrackTable.trackId, dbTrack.id),
inArray(topTrackTable.userId, memberIds),
),
);
const correct = new Set(entries.map((e) => e.userId)).size;
return {
type: "numeric",
@ -77,13 +87,20 @@ async function countFavouriteArtistListeners({
}: BuildNumericQuestionInput): Promise<NumericQuestion | null> {
const artistName = analytics?.storyClusters?.[0]?.artists?.[0]?.name;
if (!artistName || members.length === 0) return null;
const dbArtist = await db.query.artist.findFirst({ where: { name: artistName } });
const dbArtist = await db.query.artist.findFirst({
where: { name: artistName },
});
if (!dbArtist) return null;
const memberIds = members.map((m) => m.userId);
const entries = await db
.select({ userId: topArtistTable.userId })
.from(topArtistTable)
.where(and(eq(topArtistTable.artistId, dbArtist.id), inArray(topArtistTable.userId, memberIds)));
.where(
and(
eq(topArtistTable.artistId, dbArtist.id),
inArray(topArtistTable.userId, memberIds),
),
);
const correct = new Set(entries.map((e) => e.userId)).size;
return {
type: "numeric",
@ -110,4 +127,4 @@ export async function buildNumericQuestion(
const question = questions[input.index % questions.length] ?? questions[0];
if (!question) throw new Error("Question not found");
return buildQuestionWindow(question);
}
}

View file

@ -242,7 +242,10 @@ export function buildMemberPairOptions(
const pairs: string[] = [correctPair];
for (let i = 0; i < members.length; i++) {
for (let j = i + 1; j < members.length; j++) {
const pair = `${members[i]!.name} & ${members[j]!.name}`;
const left = members[i];
const right = members[j];
if (!left || !right) continue;
const pair = `${left.name} & ${right.name}`;
if (pair !== correctPair) pairs.push(pair);
}
}

View file

@ -1,7 +1,8 @@
import { eq } from "drizzle-orm";
import type { db as Db } from "../db";
import { party } from "../db/schema";
import type { QuizState } from "../party-types";
import type { PartySocketEvent, QuizState } from "../party-types";
import { publishDeviceEventForUser } from "../routes/device-socket";
import { pubsub } from "../routes/party-socket";
export async function updatePartyData(
@ -32,6 +33,19 @@ export async function updatePartyData(
},
members,
});
const event: PartySocketEvent = {
type: "party_status",
party: {
...partyObject,
data,
},
members,
};
for (const member of members) {
if (!member.userId) continue;
void publishDeviceEventForUser(member.userId, event);
}
await db
.update(party)
.set({

View file

@ -1,10 +1,11 @@
import { DBOS } from "@dbos-inc/dbos-sdk";
import { eq } from "drizzle-orm";
import Elysia from "elysia";
import { betterAuthElysia } from "../auth";
import { db } from "../db";
import { deviceConnection } from "../db/schema";
import { getMemberRecord } from "../party-data";
import type { PartySocketEvent } from "../party-types";
import type { PartySocketEvent, QuizState } from "../party-types";
import { pubsub, topic } from "./party-socket";
type DeviceSocketMessage =
@ -12,6 +13,10 @@ type DeviceSocketMessage =
| { type: "hello" }
| { type: "device_event"; deviceId: string; event: PartySocketEvent };
type DeviceQuizResponsePayload = {
QuizResponse: number;
};
let devProxySocket: WebSocket | null = null;
function isDeviceMessage(
@ -25,6 +30,29 @@ function isDeviceMessage(
);
}
function isDeviceQuizResponsePayload(
value: unknown,
): value is DeviceQuizResponsePayload {
return (
typeof value === "object" &&
value !== null &&
"QuizResponse" in value &&
Number.isInteger((value as DeviceQuizResponsePayload).QuizResponse)
);
}
function sendDeviceEvent(deviceId: string, event: PartySocketEvent) {
if (!devProxySocket || devProxySocket.readyState !== WebSocket.OPEN) return;
devProxySocket.send(
JSON.stringify({
type: "device_event",
deviceId,
event,
} satisfies DeviceSocketMessage),
);
}
export async function claimDeviceForUser(deviceId: string, userId: string) {
await db
.insert(deviceConnection)
@ -55,29 +83,72 @@ export async function publishDeviceEventForUser(
.where(eq(deviceConnection.userId, userId));
for (const device of devices) {
devProxySocket.send(
JSON.stringify({
type: "device_event",
deviceId: device.id,
event,
} satisfies DeviceSocketMessage),
);
sendDeviceEvent(device.id, event);
}
}
async function forwardDevicePayload(deviceId: string, payload: unknown) {
if (!isDeviceQuizResponsePayload(payload)) {
sendDeviceEvent(deviceId, {
type: "error",
message: "Unsupported device payload.",
});
return;
}
const device = await db
.select()
.from(deviceConnection)
.where(eq(deviceConnection.id, deviceId))
.then((rows) => rows[0]);
if (!device) return;
if (!device) {
sendDeviceEvent(deviceId, {
type: "error",
message: "Device not linked to a user.",
});
return;
}
const membership = await getMemberRecord(db, device.userId);
if (!membership) return;
if (!membership) {
sendDeviceEvent(deviceId, {
type: "error",
message: "Device not linked to a party member.",
});
return;
}
const payloadString = JSON.stringify(payload);
if (payloadString.length > 8_000) return;
const party = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!party) {
sendDeviceEvent(deviceId, {
type: "error",
message: "Party not found.",
});
return;
}
const quizData = party.data as QuizState | null;
if (!quizData || quizData.status !== "running") {
sendDeviceEvent(deviceId, {
type: "error",
message: "Quiz not running.",
});
return;
}
if (!quizData.workflowId) {
sendDeviceEvent(deviceId, {
type: "error",
message: "Workflow ID not found.",
});
return;
}
await DBOS.send(
quizData.workflowId,
{ playerId: device.userId, selected: payload.QuizResponse },
"quiz_responses",
);
pubsub.publish(
topic.party(membership.partyId),

View file

@ -29,7 +29,10 @@ export const pubsub = {
},
};
export async function broadcastQuizState(ws: any, partyId: string) {
export async function broadcastQuizState(
ws: { publish: (topic: string, message: string) => void },
partyId: string,
) {
const partyRecord = await db.query.party.findFirst({
where: { id: partyId },
});

View file

@ -331,7 +331,10 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{
const user = await createUser(`Diverse User ${i}`);
userIds.push(user.id);
const genre = genres[i]!;
const genre = genres[i];
if (!genre) {
throw new Error("Missing genre");
}
const artist = await createArtist(`artist-${i}-${randomUUID()}`, [
genre.id,
]);
@ -349,7 +352,10 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{
}
// Create party with first user as host
const hostId = userIds[0]!;
const hostId = userIds[0];
if (!hostId) {
throw new Error("Missing host user");
}
const party = await createParty(hostId);
for (const userId of userIds) {
@ -358,8 +364,14 @@ export async function seedPartyWithThreeDiverseUsers(): Promise<{
// Give each user their unique track
for (let i = 0; i < 3; i++) {
await addTopTrack(userIds[i]!, tracks[i]!.id, 1);
await addTopArtist(userIds[i]!, artists[i]!.id, 1);
const userId = userIds[i];
const track = tracks[i];
const artist = artists[i];
if (!userId || !track || !artist) {
throw new Error("Missing seeded test data");
}
await addTopTrack(userId, track.id, 1);
await addTopArtist(userId, artist.id, 1);
}
return {

View file

@ -12,7 +12,7 @@ let pool: Pool | null = null;
const getPool = (): Pool => {
if (!pool) {
pool = new Pool({ connectionString: url! });
pool = new Pool({ connectionString: url });
}
return pool;
};

View file

@ -1,4 +1,4 @@
/** biome-ignore-all lint/style/noNonNullAssertion: <explanation> */
/** biome-ignore-all lint/style/noNonNullAssertion: test setup uses controlled arrays */
import { DBOS } from "@dbos-inc/dbos-sdk";
import { describe, expect, it } from "vitest";
import {

View file

@ -4,6 +4,55 @@ type ApiEnvelope =
| { type: "hello" }
| { type: "device_event"; deviceId: string; event: unknown };
type DeviceMessage = {
DeviceId: string;
} | {
QuizResponse: number;
}
type DeviceQuestionData = {
text: string;
points: number;
index: number;
q_type: "Choice" | { Numeric: { min: number; max: number } }
}
type QuizQuestion =
| {
type: "choice";
text: string;
points: number;
}
| {
type: "numeric";
text: string;
points: number;
range: { min: number; max: number };
};
type QuizState = {
status: "running" | "results";
questionIndex: number;
currentQuestion: QuizQuestion | null;
};
type PartyStatusEvent = {
type: "party_status";
party: { data?: QuizState } | null;
};
type QuizStateEvent = {
type: "quiz_state";
quiz: QuizState;
};
type ErrorEvent = {
type: "error";
message: string;
};
type PartySocketEvent = PartyStatusEvent | QuizStateEvent | ErrorEvent;
const sockets = new Map<string, Socket>();
const socketIds = new WeakMap<Socket, string>();
const apiSocket = new WebSocket("ws://localhost:4000/api/dev-socket/ws");
@ -16,35 +65,66 @@ function registerSocket(socket: Socket, deviceId: string) {
const existing = sockets.get(deviceId);
if (existing && existing !== socket) existing.end();
sockets.set(deviceId, socket);
socketIds.set(socket, deviceId);
socketIds.set(socket, deviceId);
console.log("Registered", socket.remoteAddress, deviceId);
}
function toDeviceQuestionData(
quizData: QuizState,
): DeviceQuestionData | null {
if (!quizData.currentQuestion) return null;
const question = quizData.currentQuestion;
const q_type =
question.type === "choice"
? "Choice"
: { Numeric: { min: question.range.min, max: question.range.max } };
return {
text: question.text,
points: question.points,
index: quizData.questionIndex,
q_type,
};
}
const listener = Bun.listen({
port: 7070,
hostname: "0.0.0.0",
socket: {
open(socket) {
socket.setKeepAlive(true);
},
data(socket, buf) {
const raw = new TextDecoder().decode(buf).trim();
if (!raw) return;
socket: {
open(socket) {
socket.setKeepAlive(true);
console.log("Connection", socket.remoteAddress, socket.remotePort);
},
data(socket, buf) {
const raw = new TextDecoder().decode(buf).trim();
let data: DeviceMessage;
try {
data = JSON.parse(raw);
} catch {
return;
}
console.log("Data", socket.remoteAddress, data);
if (!data) return;
const currentDeviceId = socketDeviceId(socket);
if (!currentDeviceId) {
registerSocket(socket, raw);
return;
}
apiSocket?.send(
JSON.stringify({
type: "device_message",
deviceId: currentDeviceId,
payload: raw,
}),
);
if ("DeviceId" in data) {
registerSocket(socket, data.DeviceId);
return;
}
if ("QuizResponse" in data) {
const deviceId = socketDeviceId(socket);
if (!deviceId) return;
apiSocket?.send(
JSON.stringify({
type: "device_message",
deviceId,
payload: { QuizResponse: data.QuizResponse },
}),
);
return;
}
},
close(socket) {
console.log("Connection", socket.remoteAddress);
const deviceId = socketDeviceId(socket);
if (deviceId && sockets.get(deviceId) === socket) {
sockets.delete(deviceId);
@ -64,9 +144,35 @@ apiSocket.onmessage = (e) => {
if (message.type !== "device_event") return;
const socket = sockets.get(message.deviceId);
if (!socket) return;
const event = message.event as PartySocketEvent;
if (event.type === "error") {
socket.write(`${JSON.stringify({ Error: event.message })}\n`);
return;
}
if (event.type === "party_status") {
const quizData = event.party?.data ?? null;
if (!quizData) return;
const question = toDeviceQuestionData(quizData);
socket.write(
`${JSON.stringify({ Question: question, Status: quizData.status })}\n`,
);
return;
}
if (event.type === "quiz_state") {
const question = toDeviceQuestionData(event.quiz);
socket.write(
`${JSON.stringify({ Question: question, Status: event.quiz.status })}\n`,
);
return;
}
socket.write(`${JSON.stringify(message.event)}\n`);
};
apiSocket.onerror = (error) => {
console.error(error);
};

21
esp32/Cargo.lock generated
View file

@ -1019,6 +1019,8 @@ dependencies = [
"owned_str",
"panic-probe",
"portable-atomic",
"serde",
"serde_json",
"static_cell",
"ufmt 0.2.0",
]
@ -1698,6 +1700,19 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "serde_json"
version = "1.0.149"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
dependencies = [
"itoa",
"memchr",
"serde",
"serde_core",
"zmij",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
@ -2178,3 +2193,9 @@ dependencies = [
"quote",
"syn 2.0.117",
]
[[package]]
name = "zmij"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"

View file

@ -12,6 +12,8 @@ codegen-units = 1
[dependencies]
[target.'cfg(target_arch = "xtensa")'.dependencies]
serde = { version = "1.0.228", default-features = false, features = ["derive", "alloc"] }
serde_json = { version = "1.0", default-features = false, features = ["alloc"] }
arrayvec = { version = "0.7.6", default-features = false }
ag-lcd = { version = "0.3", features = ["ufmt"] }
as5600 = "0.8.0"

View file

@ -7,7 +7,7 @@ use core::str::FromStr;
use embassy_executor::Spawner;
use embassy_futures::select::{Either, select};
use embassy_net::tcp::{TcpReader, TcpWriter};
use embassy_net::tcp::{State, TcpReader, TcpWriter};
use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
use embassy_sync::mutex::Mutex;
use embassy_sync::signal::Signal;
@ -21,6 +21,7 @@ use esp_hal::{
};
use esp_println::println;
use owned_str::OwnedStr;
use serde::{Deserialize, Serialize};
use ufmt::uwrite;
mod buffer;
@ -41,6 +42,7 @@ const WHEEL_PRECISION: i32 = 32;
const WHEEL_INVERTED: bool = false;
const DEVICE_ID: &str = "esp32-1";
#[derive(Deserialize)]
enum QuestionType {
Choice,
Numeric { min: i32, max: i32 },
@ -58,7 +60,33 @@ struct QuestionData {
text: OwnedStr<256>,
q_type: QuestionType,
points: i32,
generation: usize,
index: usize,
}
#[derive(Deserialize)]
struct QuestionDataNet<'a> {
text: &'a str,
q_type: QuestionType,
points: i32,
index: usize,
}
#[derive(Deserialize)]
enum ProxyOutput<'a> {
Question(QuestionDataNet<'a>),
Results,
Error(&'a str),
}
impl<'a> From<QuestionDataNet<'a>> for QuestionData {
fn from(value: QuestionDataNet<'a>) -> Self {
QuestionData {
text: OwnedStr::from_str(value.text).unwrap(),
q_type: value.q_type,
points: value.points,
index: value.index,
}
}
}
#[derive(Clone, Copy)]
@ -69,6 +97,14 @@ struct WheelData {
accumulated: i32,
}
#[derive(Clone, Copy)]
enum MainState {
Loading,
Question,
Results,
}
static MAIN_STATE: Mutex<CriticalSectionRawMutex, MainState> = Mutex::new(MainState::Loading);
static QUESTION: Mutex<CriticalSectionRawMutex, Option<QuestionData>> = Mutex::new(None);
static QUESTION_UPDATE: Signal<CriticalSectionRawMutex, ()> = Signal::new();
static WHEEL_VALUE: Mutex<CriticalSectionRawMutex, WheelData> = Mutex::new(WheelData {
@ -99,7 +135,6 @@ pub async fn tcp_read_loop(
cancel: &Signal<CriticalSectionRawMutex, ()>,
) -> Result<(), TcpDisconnect> {
let mut buf = [0u8; 1024];
let mut generation = 1;
loop {
let read_fut = read.read(&mut buf);
@ -120,7 +155,6 @@ pub async fn tcp_read_loop(
let Ok(str) = core::str::from_utf8(&buf[..len]) else {
continue;
};
let mut counter = 0;
let mut question_data = None;
let mut future_wheel = WheelData {
value: 0,
@ -128,89 +162,61 @@ pub async fn tcp_read_loop(
max: 0,
accumulated: 0,
};
for line in str.lines() {
if line == "##" {
overwrite_lcd("Waiting", DEVICE_ID).await;
break;
}
if line == "$$" {
counter = 1;
if let Some(last) = str.lines().last() {
let Ok(data) = serde_json::from_str::<ProxyOutput>(last) else {
continue;
}
if counter == 1 {
let mut q_type = QuestionType::Choice;
let mut points = -1;
for pairs in line.split(' ') {
let (key, value) = pairs.split_once('=').unwrap();
if key == "type" {
q_type = if value == "choice" {
QuestionType::Choice
} else {
QuestionType::Numeric { min: -1, max: -1 }
};
}
if key == "points" {
points = value.parse().unwrap();
}
if key == "rangeMin" || key == "rangeMax" {
match q_type {
QuestionType::Choice => {}
QuestionType::Numeric {
ref mut min,
ref mut max,
} => {
if key == "rangeMin" {
*min = value.parse().unwrap();
future_wheel.min = *min;
} else {
*max = value.parse().unwrap();
future_wheel.max = *max;
}
}
};
match data {
ProxyOutput::Question(data) => {
let data: QuestionData = data.into();
match data.q_type {
QuestionType::Numeric { min, max } => {
future_wheel.max = max;
future_wheel.min = min;
future_wheel.value = (min + max) / 2;
}
}
_ => {}
};
question_data = Some(data);
}
if let QuestionType::Numeric { min, max } = q_type {
let diff = max - min;
future_wheel.value = min + diff / 2;
future_wheel.accumulated = 0;
ProxyOutput::Results => {
*MAIN_STATE.lock().await = MainState::Results;
}
question_data = Some(QuestionData {
text: OwnedStr::new(),
q_type,
points,
generation,
});
counter = 2;
continue;
}
if counter == 2 {
question_data.as_mut().unwrap().text = OwnedStr::from_str(line).unwrap();
generation += 1;
counter = 0;
ProxyOutput::Error(e) => {}
}
}
if let Some(question_data) = question_data {
*QUESTION.lock().await = Some(question_data);
*MAIN_STATE.lock().await = MainState::Question;
*WHEEL_VALUE.lock().await = future_wheel;
QUESTION_UPDATE.signal(());
}
}
}
#[derive(Serialize)]
enum WriteType<'a> {
QuizResponse(i32),
DeviceId(&'a str),
}
pub async fn tcp_write_loop(
mut write: TcpWriter<'_>,
cancel: &Signal<CriticalSectionRawMutex, ()>,
) -> Result<(), TcpDisconnect> {
if write.write(DEVICE_ID.as_bytes()).await.is_err() {
if write
.write(
serde_json::to_string(&WriteType::DeviceId(DEVICE_ID))
.unwrap()
.as_bytes(),
)
.await
.is_err()
{
cancel.signal(());
return Err(TcpDisconnect::WriteError);
}
let mut buffer = buffer::WriteBuffer::<256>::new();
loop {
let input_fut = input::INPUT.receive();
let cancel_fut = cancel.wait();
@ -219,14 +225,24 @@ pub async fn tcp_write_loop(
Either::Second(()) => return Err(TcpDisconnect::Cancelled),
};
println!("button={}", data);
let angle = *ANGLE.lock().await;
core::writeln!(buffer, "button={} angle={}", data, angle).ok();
let value = {
let question = QUESTION.lock().await;
let wheel = *WHEEL_VALUE.lock().await;
match question.as_ref() {
Some(q) => match q.q_type {
QuestionType::Numeric { .. } => wheel.value,
QuestionType::Choice => data as _,
},
_ => data as _,
}
};
let data = WriteType::QuizResponse(value);
let buffer = serde_json::to_string(&data).unwrap();
println!("write: {}", &buffer);
if write.write(&buffer).await.is_err() {
if write.write(buffer.as_bytes()).await.is_err() {
cancel.signal(());
return Err(TcpDisconnect::WriteError);
}
buffer.clear();
}
}
@ -236,19 +252,32 @@ const DOT: char = char::from_u32(0b1010_0101).unwrap();
#[embassy_executor::task]
pub async fn main_loop() {
let mut last_gen = 0;
let mut last_index = 0;
let mut title_offset = 0;
println!("Main loop started");
loop {
embassy_time::Timer::after_millis(50).await;
let state = *MAIN_STATE.lock().await;
match state {
MainState::Loading => {
continue;
}
MainState::Question => {}
MainState::Results => {
overwrite_lcd("Results", "").await;
continue;
}
}
let wheel = *WHEEL_VALUE.lock().await;
let question = QUESTION.lock().await;
let Some(question) = question.as_ref() else {
continue;
};
title_offset += 1;
if question.generation != last_gen {
last_gen = question.generation;
if question.index != last_index {
last_index = question.index;
title_offset = 0;
}
let title_line = if question.text.len() > 16 {

View file

@ -21,8 +21,8 @@ pub async fn overwrite_lcd(line1: &str, line2: &str) {
let mut buffer = SCREEN_BUFFER.lock().await;
buffer.line1_ptr = 0;
buffer.line2_ptr = 0;
buffer.line1.fill(0);
buffer.line2.fill(0);
buffer.line1.fill(32);
buffer.line2.fill(32);
let len1 = line1.len().min(buffer.line1.len());
let len2 = line2.len().min(buffer.line2.len());
buffer.line1[..len1].copy_from_slice(line1[..len1].as_bytes());