diff --git a/api/src/db/schema.ts b/api/src/db/schema.ts index 1446d59..143885f 100644 --- a/api/src/db/schema.ts +++ b/api/src/db/schema.ts @@ -52,7 +52,10 @@ export const partyMember = pgTable( joinedAt: timestamp().defaultNow().notNull(), lastSeen: timestamp().defaultNow().notNull(), }, - (partyMember) => [uniqueIndex().on(partyMember.partyId, partyMember.userId)], + (partyMember) => [ + uniqueIndex().on(partyMember.partyId, partyMember.userId), + index().on(partyMember.userId, partyMember.joinedAt), + ], ); export const platform = pgEnum("enum_platform", ["spotify", "apple"]); diff --git a/api/src/index.ts b/api/src/index.ts index 13ab176..11bb024 100644 --- a/api/src/index.ts +++ b/api/src/index.ts @@ -7,7 +7,7 @@ import "./workflows/sync"; import "./workflows/party-analysis"; import "./dbos.ts"; import { partyApp } from "./routes/party"; -import { partySocketApp } from "./routes/party-socket"; +import { partySocketApp, pubsub } from "./routes/party-socket"; import { statsApp } from "./routes/stats.ts"; const app = new Elysia() @@ -23,6 +23,8 @@ const app = new Elysia() ) .listen(4000); +pubsub.setServer(app.server); + export type App = typeof app; await DBOS.launch({ diff --git a/api/src/party-data.ts b/api/src/party-data.ts index eef5210..e2839d3 100644 --- a/api/src/party-data.ts +++ b/api/src/party-data.ts @@ -30,6 +30,9 @@ export async function getMemberRecord(dbClient: DbLike, userId: string) { where: { userId, }, + orderBy: { + joinedAt: "desc", + }, })) ?? null ); } diff --git a/api/src/party-sockets.ts b/api/src/party-sockets.ts deleted file mode 100644 index 796b7b2..0000000 --- a/api/src/party-sockets.ts +++ /dev/null @@ -1,156 +0,0 @@ -import type { PartySocketEvent } from "./party-types"; - -type WebSocketLike = { - send: (data: string) => void; - close?: (code?: number, reason?: string) => void; -}; - -const partySockets = new Map>>(); -const userSockets = new Map>(); - -function getPartyUserSockets(partyId: string, userId: string) { - const partyMap = partySockets.get(partyId); - if (!partyMap) return null; - return partyMap.get(userId) ?? null; -} - -export function registerPartySocket( - partyId: string, - userId: string, - ws: WebSocketLike, -) { - let partyMap = partySockets.get(partyId); - if (!partyMap) { - partyMap = new Map(); - partySockets.set(partyId, partyMap); - } - - let userSockets = partyMap.get(userId); - if (!userSockets) { - userSockets = new Set(); - partyMap.set(userId, userSockets); - } - - userSockets.add(ws); -} - -export function unregisterPartySocket( - partyId: string, - userId: string, - ws: WebSocketLike, -) { - const partyMap = partySockets.get(partyId); - if (!partyMap) return; - - const userSockets = partyMap.get(userId); - if (!userSockets) return; - - userSockets.delete(ws); - - if (userSockets.size === 0) { - partyMap.delete(userId); - } - - if (partyMap.size === 0) { - partySockets.delete(partyId); - } -} - -export function registerUserSocket(userId: string, ws: WebSocketLike) { - let sockets = userSockets.get(userId); - if (!sockets) { - sockets = new Set(); - userSockets.set(userId, sockets); - } - - sockets.add(ws); -} - -export function unregisterUserSocket(userId: string, ws: WebSocketLike) { - const sockets = userSockets.get(userId); - if (!sockets) return; - - sockets.delete(ws); - - if (sockets.size === 0) { - userSockets.delete(userId); - } -} - -export function unregisterUserSocketFromAllParties( - userId: string, - ws: WebSocketLike, -) { - for (const [partyId, partyMap] of partySockets) { - const userSockets = partyMap.get(userId); - if (!userSockets) continue; - userSockets.delete(ws); - if (userSockets.size === 0) { - partyMap.delete(userId); - } - if (partyMap.size === 0) { - partySockets.delete(partyId); - } - } -} - -export function broadcastPartyEvent(partyId: string, event: PartySocketEvent) { - const partyMap = partySockets.get(partyId); - if (!partyMap) return; - - const payload = JSON.stringify(event); - for (const userSockets of partyMap.values()) { - for (const ws of userSockets) { - ws.send(payload); - } - } -} - -export function sendPartyEventToUser( - partyId: string, - userId: string, - event: PartySocketEvent, -) { - const userSockets = getPartyUserSockets(partyId, userId); - if (!userSockets) return; - - const payload = JSON.stringify(event); - for (const ws of userSockets) { - ws.send(payload); - } -} - -export function sendDirectEventToUser(userId: string, event: PartySocketEvent) { - const sockets = userSockets.get(userId); - if (!sockets) return; - - const payload = JSON.stringify(event); - for (const ws of sockets) { - ws.send(payload); - } -} - -export function reassignUserSocketsToParty( - userId: string, - partyId: string | null, -) { - for (const [existingPartyId, partyMap] of partySockets) { - if (!partyMap.has(userId)) continue; - partyMap.delete(userId); - if (partyMap.size === 0) { - partySockets.delete(existingPartyId); - } - } - - if (!partyId) return; - const sockets = userSockets.get(userId); - if (!sockets) return; - - let partyMap = partySockets.get(partyId); - if (!partyMap) { - partyMap = new Map(); - partySockets.set(partyId, partyMap); - } - - partyMap.set(userId, new Set(sockets)); -} diff --git a/api/src/party-types.ts b/api/src/party-types.ts index 9ff318d..3d3a1dd 100644 --- a/api/src/party-types.ts +++ b/api/src/party-types.ts @@ -27,9 +27,6 @@ export type PartySocketOutgoing = export type PartySocketEvent = | { type: "snapshot"; party: Party | null; members: PartyMemberWithUser[] } | { type: "party_status"; party: Party; members: PartyMemberWithUser[] } - | { type: "member_joined"; userId: string } - | { type: "member_left"; userId: string; kickedBy?: string } - | { type: "host_changed"; hostId: string } | { type: "member_payload"; fromUserId: string; payload: unknown } | { type: "error"; message: string } | { type: "pong" }; diff --git a/api/src/routes/party-socket.ts b/api/src/routes/party-socket.ts index f1040d9..13d75ff 100644 --- a/api/src/routes/party-socket.ts +++ b/api/src/routes/party-socket.ts @@ -1,30 +1,34 @@ -import Elysia, { t } from "elysia"; -import { auth, betterAuthElysia } from "../auth"; +import { Elysia } from "elysia"; + +import { betterAuthElysia } from "../auth"; + import { db } from "../db"; import { getMemberRecord, getPartyStatus } from "../party-data"; -import { - registerPartySocket, - registerUserSocket, - sendPartyEventToUser, - unregisterPartySocket, - unregisterUserSocket, - unregisterUserSocketFromAllParties, -} from "../party-sockets"; -import type { PartySocketOutgoing } from "../party-types"; -const MAX_MEMBER_PAYLOAD_SIZE = 8_000; +function userTopic(userId: string) { + return `user:${userId}`; +} -type PartyWsData = { - partyId?: string | null; +function partyTopic(partyId: string) { + return `party:${partyId}`; +} + +const socketPartyId = new WeakMap(); + +export const pubsub = { + _server: null as ReturnType | null, + setServer(server: ReturnType | null) { + this._server = server; + }, + publish(topic: string, data: string) { + this._server?.publish(topic, data); + }, }; -function getPayloadSize(payload: unknown) { - try { - return JSON.stringify(payload).length; - } catch { - return Infinity; - } -} +export const topic = { + user: userTopic, + party: partyTopic, +}; export const partySocketApp = new Elysia() .use(betterAuthElysia) @@ -33,10 +37,13 @@ export const partySocketApp = new Elysia() .get("/test", () => ({ ok: 1 })) .ws("/ws", { auth: true, + publishToSelf: true, open: async (ws) => { const user = ws.data.user; if (!user) return; - registerUserSocket(user.id, ws); + + ws.subscribe(userTopic(user.id)); + const membership = await getMemberRecord(db, user.id); if (!membership) { ws.send( @@ -49,9 +56,10 @@ export const partySocketApp = new Elysia() return; } + socketPartyId.set(ws, membership.partyId); + ws.subscribe(partyTopic(membership.partyId)); + const snapshot = await getPartyStatus(membership.partyId); - ws.data.partyId = membership.partyId; - registerPartySocket(membership.partyId, user.id, ws); if (snapshot) { ws.send( JSON.stringify({ @@ -62,56 +70,64 @@ export const partySocketApp = new Elysia() ); } }, - message: async (ws, message: PartySocketOutgoing) => { + message: async (ws, message) => { const data = ws.data; const user = data.user; if (!user) return; - if (message.type === "ping") { + + if (typeof message !== "string") return; + + let parsed: { type: string; payload?: unknown }; + try { + parsed = JSON.parse(message); + } catch { + ws.send(JSON.stringify({ type: "error", message: "Invalid JSON" })); + return; + } + + if (parsed.type === "ping") { ws.send(JSON.stringify({ type: "pong" })); return; } - if (message.type !== "member_payload") return; - const membership = await getMemberRecord(db, user.id); - if (!membership) return; + if (parsed.type !== "member_payload") return; - if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) { + const MAX_MEMBER_PAYLOAD_SIZE = 8_000; + const payloadString = JSON.stringify(parsed.payload); + if (payloadString.length > MAX_MEMBER_PAYLOAD_SIZE) { ws.send( - JSON.stringify({ - type: "error", - message: "Payload too large.", - }), + JSON.stringify({ type: "error", message: "Payload too large." }), ); return; } + const membership = await getMemberRecord(db, user.id); + if (!membership) return; + const currentParty = await db.query.party.findFirst({ where: { id: membership.partyId }, }); if (!currentParty) return; - sendPartyEventToUser(membership.partyId, currentParty.hostId, { - type: "member_payload", - fromUserId: user.id, - payload: message.payload, - }); + ws.publish( + partyTopic(membership.partyId), + JSON.stringify({ + type: "member_payload", + fromUserId: user.id, + payload: parsed.payload, + }), + ); }, close: async (ws) => { - const data = ws.data; - const user = data.user; - const { partyId } = data; + const user = ws.data.user; if (!user) return; - if (!partyId) { - unregisterUserSocketFromAllParties(user.id, ws); - unregisterUserSocket(user.id, ws); - return; - } - unregisterPartySocket(partyId, user.id, ws); - unregisterUserSocket(user.id, ws); + + ws.unsubscribe(userTopic(user.id)); + + const partyId = socketPartyId.get(ws); + if (!partyId) return; + + ws.unsubscribe(partyTopic(partyId)); }, - body: t.Union([ - t.Object({ type: t.Literal("ping") }), - t.Object({ type: t.Literal("member_payload"), payload: t.Any() }), - ]), }), ); diff --git a/api/src/routes/party.ts b/api/src/routes/party.ts index 350d49f..96477b1 100644 --- a/api/src/routes/party.ts +++ b/api/src/routes/party.ts @@ -10,28 +10,29 @@ import { getPartyStatus, leaveParty, } from "../party-data"; -import { - broadcastPartyEvent, - reassignUserSocketsToParty, - sendDirectEventToUser, -} from "../party-sockets"; -import { - PARTY_STATUS, - type PartySnapshot, - type PartyStatus, -} from "../party-types"; +import type { PartySnapshot } from "../party-types"; +import { pubsub, topic } from "./party-socket"; function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) { if (!snapshot) return; - broadcastPartyEvent(partyId, { - type: "party_status", - party: snapshot.party, - members: snapshot.members, - }); + pubsub.publish( + topic.party(partyId), + JSON.stringify({ + type: "party_status", + party: snapshot.party, + members: snapshot.members, + }), + ); } -function isValidStatus(status: string): status is PartyStatus { - return PARTY_STATUS.includes(status as PartyStatus); +function broadcastToUser(userId: string, event: Record) { + pubsub.publish(topic.user(userId), JSON.stringify(event)); +} + +function isValidStatus( + status: string, +): status is import("../party-types").PartyStatus { + return ["created", "started", "ended"].includes(status); } export const partyApp = new Elysia() @@ -126,31 +127,19 @@ export const partyApp = new Elysia() if (!partyId) return { party: null, members: [] }; const status = await getPartyStatus(partyId); if (leaveResult?.newHostId) { - broadcastPartyEvent(leaveResult.partyId, { - type: "host_changed", - hostId: leaveResult.newHostId, - }); + broadcastSnapshot(leaveResult.partyId, status); } if (hostChanged) { - broadcastPartyEvent(partyId, { - type: "host_changed", - hostId: targetUserId, - }); + broadcastSnapshot(partyId, status); } - broadcastPartyEvent(partyId, { - type: "member_joined", - userId: user.id, - }); broadcastSnapshot(partyId, status); - reassignUserSocketsToParty(user.id, partyId); - reassignUserSocketsToParty(targetUserId, partyId); if (status) { - sendDirectEventToUser(targetUserId, { + broadcastToUser(targetUserId, { type: "party_status", party: status.party, members: status.members, }); - sendDirectEventToUser(user.id, { + broadcastToUser(user.id, { type: "party_status", party: status.party, members: status.members, @@ -173,18 +162,7 @@ export const partyApp = new Elysia() }); if (!result) return { party: null, members: [] }; const status = await getPartyStatus(result.partyId); - broadcastPartyEvent(result.partyId, { - type: "member_left", - userId: user.id, - }); - if (result.newHostId) { - broadcastPartyEvent(result.partyId, { - type: "host_changed", - hostId: result.newHostId, - }); - } broadcastSnapshot(result.partyId, status); - reassignUserSocketsToParty(user.id, null); return status ?? { party: null, members: [] }; }, { auth: true }, @@ -225,13 +203,7 @@ export const partyApp = new Elysia() await cleanupPartyIfEmpty(tx, currentMembership.partyId); }); const status = await getPartyStatus(currentMembership.partyId); - broadcastPartyEvent(currentMembership.partyId, { - type: "member_left", - userId: body.memberUserId, - kickedBy: user.id, - }); broadcastSnapshot(currentMembership.partyId, status); - reassignUserSocketsToParty(body.memberUserId, null); return status ?? { party: null, members: [] }; }, { diff --git a/web/src/hooks/use-party.ts b/web/src/hooks/use-party.ts index 6f3e232..6ea1243 100644 --- a/web/src/hooks/use-party.ts +++ b/web/src/hooks/use-party.ts @@ -13,24 +13,8 @@ function reducePartyState( ): PartyState { switch (event.type) { case "snapshot": - return { party: event.party, members: event.members }; case "party_status": return { party: event.party, members: event.members }; - case "member_joined": - return state; - case "member_left": - return { - ...state, - members: state.members.filter( - (m: PartyMember) => m.userId !== event.userId, - ), - }; - case "host_changed": - if (!state.party) return state; - return { - ...state, - party: { ...state.party, hostId: event.hostId }, - }; case "member_payload": case "pong": case "error":