From 2d16fb8ecc23e40d43babba8ef6b4b76c743d479 Mon Sep 17 00:00:00 2001 From: Daniel Bulant Date: Tue, 21 Apr 2026 23:00:21 +0200 Subject: [PATCH] refactor socket --- api/src/index.ts | 5 +- api/src/party-data.ts | 106 ++++++++++++++ api/src/party-sockets.ts | 74 ++++++++++ api/src/routes/party-socket.ts | 137 ++++++++++++++++++ api/src/routes/party.ts | 252 ++++----------------------------- 5 files changed, 347 insertions(+), 227 deletions(-) create mode 100644 api/src/party-data.ts create mode 100644 api/src/routes/party-socket.ts diff --git a/api/src/index.ts b/api/src/index.ts index a23f591..ee3f361 100644 --- a/api/src/index.ts +++ b/api/src/index.ts @@ -5,11 +5,14 @@ import { syncApp } from "./routes/sync"; import "./workflows/sync"; import "./dbos.ts"; import { partyApp } from "./routes/party"; +import { partySocketApp } from "./routes/party-socket"; import { statsApp } from "./routes/stats.ts"; const app = new Elysia() .use(betterAuthElysia) - .group("/api", (app) => app.use(syncApp).use(statsApp).use(partyApp)) + .group("/api", (app) => + app.use(syncApp).use(statsApp).use(partyApp).use(partySocketApp), + ) .listen(4000); export type App = typeof app; diff --git a/api/src/party-data.ts b/api/src/party-data.ts new file mode 100644 index 0000000..1ea0dd7 --- /dev/null +++ b/api/src/party-data.ts @@ -0,0 +1,106 @@ +import { eq } from "drizzle-orm"; +import { db } from "./db"; +import { party, partyMember } from "./db/schema"; + +type DbClient = typeof db; +type DbTransaction = Parameters[0] extends ( + tx: infer T, +) => Promise + ? T + : never; +export type DbLike = DbClient | DbTransaction; + +export async function getPartyForUser(userId: string) { + const memberships = await db.query.partyMember.findMany({ + where: { + userId, + }, + with: { + party: true, + }, + limit: 1, + }); + return memberships[0]?.party ?? null; +} + +export async function getMemberRecord(dbClient: DbLike, userId: string) { + return ( + (await dbClient.query.partyMember.findFirst({ + where: { + userId, + }, + })) ?? null + ); +} + +export async function getPartyStatus(partyId: string) { + const partyRecord = await db.query.party.findFirst({ + where: { + id: partyId, + }, + }); + if (!partyRecord) return null; + const members = await db.query.partyMember.findMany({ + where: { + partyId, + }, + with: { + user: true, + }, + orderBy: { + joinedAt: "asc", + }, + }); + return { + party: partyRecord, + members, + }; +} + +export async function cleanupPartyIfEmpty(dbClient: DbLike, partyId: string) { + const members = await dbClient.query.partyMember.findMany({ + where: { + partyId, + }, + limit: 1, + }); + if (members.length > 0) return; + await dbClient.delete(party).where(eq(party.id, partyId)); +} + +export async function leaveParty(dbClient: DbLike, userId: string) { + const member = await getMemberRecord(dbClient, userId); + if (!member) return null; + await dbClient.delete(partyMember).where(eq(partyMember.id, member.id)); + const nextHost = await dbClient.query.partyMember.findFirst({ + where: { + partyId: member.partyId, + }, + orderBy: { + joinedAt: "asc", + }, + }); + let newHostId: string | null = null; + if (nextHost) { + const currentParty = await dbClient.query.party.findFirst({ + where: { + id: member.partyId, + }, + }); + if (currentParty?.hostId === userId) { + await dbClient + .update(party) + .set({ + hostId: nextHost.userId, + lastUpdated: new Date(), + }) + .where(eq(party.id, member.partyId)); + newHostId = nextHost.userId; + } + } + await cleanupPartyIfEmpty(dbClient, member.partyId); + return { + partyId: member.partyId, + newHostId, + }; +} diff --git a/api/src/party-sockets.ts b/api/src/party-sockets.ts index 6db0bd6..fcd4252 100644 --- a/api/src/party-sockets.ts +++ b/api/src/party-sockets.ts @@ -9,6 +9,7 @@ type WebSocketLike = { }; const partySockets = new Map>>(); +const userSockets = new Map>(); function getPartyUserSockets(partyId: string, userId: string) { const partyMap = partySockets.get(partyId); @@ -58,6 +59,44 @@ export function unregisterPartySocket( } } +export function registerUserSocket(userId: string, ws: WebSocketLike) { + let sockets = userSockets.get(userId); + if (!sockets) { + sockets = new Set(); + userSockets.set(userId, sockets); + } + + sockets.add(ws); +} + +export function unregisterUserSocket(userId: string, ws: WebSocketLike) { + const sockets = userSockets.get(userId); + if (!sockets) return; + + sockets.delete(ws); + + if (sockets.size === 0) { + userSockets.delete(userId); + } +} + +export function unregisterUserSocketFromAllParties( + userId: string, + ws: WebSocketLike, +) { + for (const [partyId, partyMap] of partySockets) { + const userSockets = partyMap.get(userId); + if (!userSockets) continue; + userSockets.delete(ws); + if (userSockets.size === 0) { + partyMap.delete(userId); + } + if (partyMap.size === 0) { + partySockets.delete(partyId); + } + } +} + export function broadcastPartyEvent(partyId: string, event: PartySocketEvent) { const partyMap = partySockets.get(partyId); if (!partyMap) return; @@ -83,3 +122,38 @@ export function sendPartyEventToUser( ws.send(payload); } } + +export function sendDirectEventToUser(userId: string, event: PartySocketEvent) { + const sockets = userSockets.get(userId); + if (!sockets) return; + + const payload = JSON.stringify(event); + for (const ws of sockets) { + ws.send(payload); + } +} + +export function reassignUserSocketsToParty( + userId: string, + partyId: string | null, +) { + for (const [existingPartyId, partyMap] of partySockets) { + if (!partyMap.has(userId)) continue; + partyMap.delete(userId); + if (partyMap.size === 0) { + partySockets.delete(existingPartyId); + } + } + + if (!partyId) return; + const sockets = userSockets.get(userId); + if (!sockets) return; + + let partyMap = partySockets.get(partyId); + if (!partyMap) { + partyMap = new Map(); + partySockets.set(partyId, partyMap); + } + + partyMap.set(userId, new Set(sockets)); +} diff --git a/api/src/routes/party-socket.ts b/api/src/routes/party-socket.ts new file mode 100644 index 0000000..a0c36e9 --- /dev/null +++ b/api/src/routes/party-socket.ts @@ -0,0 +1,137 @@ +import Elysia, { t } from "elysia"; +import { auth, betterAuthElysia } from "../auth"; +import { db } from "../db"; +import { getMemberRecord, getPartyStatus } from "../party-data"; +import { + registerPartySocket, + registerUserSocket, + sendPartyEventToUser, + unregisterPartySocket, + unregisterUserSocket, + unregisterUserSocketFromAllParties, +} from "../party-sockets"; + +type PartySocketMessage = + | { + type: "member_payload"; + payload: unknown; + } + | { + type: "ping"; + }; + +const MAX_MEMBER_PAYLOAD_SIZE = 8_000; + +type PartyWsData = { + user?: { id: string }; + partyId?: string | null; +}; + +function getPayloadSize(payload: unknown) { + try { + return JSON.stringify(payload).length; + } catch { + return Infinity; + } +} + +export const partySocketApp = new Elysia() + .use(betterAuthElysia) + .group("/party-socket", (app) => + app.ws("/ws", { + beforeHandle: async ({ request, set }) => { + const session = await auth.api.getSession({ + headers: request.headers, + }); + if (!session) { + set.status = 401; + return; + } + return { + user: session.user, + session: session.session, + }; + }, + open: async (ws) => { + const data = ws.data as unknown as PartyWsData; + const user = data.user; + if (!user) return; + registerUserSocket(user.id, ws); + const membership = await getMemberRecord(db, user.id); + if (!membership) { + ws.send( + JSON.stringify({ + type: "snapshot", + party: null, + members: [], + }), + ); + return; + } + + const snapshot = await getPartyStatus(membership.partyId); + data.partyId = membership.partyId; + registerPartySocket(membership.partyId, user.id, ws); + if (snapshot) { + ws.send( + JSON.stringify({ + type: "snapshot", + party: snapshot.party, + members: snapshot.members, + }), + ); + } + }, + message: async (ws, message: PartySocketMessage) => { + const data = ws.data as unknown as PartyWsData; + const user = data.user; + if (!user) return; + if (message.type === "ping") { + ws.send(JSON.stringify({ type: "pong" })); + return; + } + + if (message.type !== "member_payload") return; + const membership = await getMemberRecord(db, user.id); + if (!membership) return; + + if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) { + ws.send( + JSON.stringify({ + type: "error", + message: "Payload too large.", + }), + ); + return; + } + + const currentParty = await db.query.party.findFirst({ + where: { id: membership.partyId }, + }); + if (!currentParty) return; + + sendPartyEventToUser(membership.partyId, currentParty.hostId, { + type: "member_payload", + fromUserId: user.id, + payload: message.payload, + }); + }, + close: async (ws) => { + const data = ws.data as unknown as PartyWsData; + const user = data.user; + const { partyId } = data; + if (!user) return; + if (!partyId) { + unregisterUserSocketFromAllParties(user.id, ws); + unregisterUserSocket(user.id, ws); + return; + } + unregisterPartySocket(partyId, user.id, ws); + unregisterUserSocket(user.id, ws); + }, + body: t.Union([ + t.Object({ type: t.Literal("ping") }), + t.Object({ type: t.Literal("member_payload"), payload: t.Any() }), + ]), + }), + ); diff --git a/api/src/routes/party.ts b/api/src/routes/party.ts index 4e48af8..2477a53 100644 --- a/api/src/routes/party.ts +++ b/api/src/routes/party.ts @@ -1,92 +1,27 @@ import { and, eq } from "drizzle-orm"; import Elysia, { t } from "elysia"; -import { auth, betterAuthElysia } from "../auth"; +import { betterAuthElysia } from "../auth"; import { db } from "../db"; import { party, partyMember } from "../db/schema"; +import { + cleanupPartyIfEmpty, + getMemberRecord, + getPartyForUser, + getPartyStatus, + leaveParty, +} from "../party-data"; import { broadcastPartyEvent, - registerPartySocket, - sendPartyEventToUser, - unregisterPartySocket, + reassignUserSocketsToParty, + sendDirectEventToUser, } from "../party-sockets"; const PARTY_STATUS = ["created", "started", "ended"] as const; type PartyStatus = (typeof PARTY_STATUS)[number]; -type DbClient = typeof db; -type DbTransaction = Parameters[0] extends ( - tx: infer T, -) => Promise - ? T - : never; -type DbLike = DbClient | DbTransaction; - type PartySnapshot = NonNullable>>; -type PartySocketMessage = - | { - type: "member_payload"; - payload: unknown; - } - | { - type: "ping"; - }; - -const MAX_MEMBER_PAYLOAD_SIZE = 8_000; - -type PartyWsData = { - user?: { id: string }; - partyId?: string; -}; - -async function getPartyForUser(userId: string) { - const memberships = await db.query.partyMember.findMany({ - where: { - userId, - }, - with: { - party: true, - }, - limit: 1, - }); - return memberships[0]?.party ?? null; -} - -async function getMemberRecord(dbClient: DbLike, userId: string) { - return ( - (await dbClient.query.partyMember.findFirst({ - where: { - userId, - }, - })) ?? null - ); -} - -async function getPartyStatus(partyId: string) { - const party = await db.query.party.findFirst({ - where: { - id: partyId, - }, - }); - if (!party) return null; - const members = await db.query.partyMember.findMany({ - where: { - partyId, - }, - with: { - user: true, - }, - orderBy: { - joinedAt: "asc", - }, - }); - return { - party, - members, - }; -} - function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) { if (!snapshot) return; broadcastPartyEvent(partyId, { @@ -96,62 +31,6 @@ function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) { }); } -function getPayloadSize(payload: unknown) { - try { - return JSON.stringify(payload).length; - } catch { - return Infinity; - } -} - -async function cleanupPartyIfEmpty(dbClient: DbLike, partyId: string) { - const members = await dbClient.query.partyMember.findMany({ - where: { - partyId, - }, - limit: 1, - }); - if (members.length > 0) return; - await dbClient.delete(party).where(eq(party.id, partyId)); -} - -async function leaveParty(dbClient: DbLike, userId: string) { - const member = await getMemberRecord(dbClient, userId); - if (!member) return null; - await dbClient.delete(partyMember).where(eq(partyMember.id, member.id)); - const nextHost = await dbClient.query.partyMember.findFirst({ - where: { - partyId: member.partyId, - }, - orderBy: { - joinedAt: "asc", - }, - }); - let newHostId: string | null = null; - if (nextHost) { - const currentParty = await dbClient.query.party.findFirst({ - where: { - id: member.partyId, - }, - }); - if (currentParty?.hostId === userId) { - await dbClient - .update(party) - .set({ - hostId: nextHost.userId, - lastUpdated: new Date(), - }) - .where(eq(party.id, member.partyId)); - newHostId = nextHost.userId; - } - } - await cleanupPartyIfEmpty(dbClient, member.partyId); - return { - partyId: member.partyId, - newHostId, - }; -} - function isValidStatus(status: string): status is PartyStatus { return PARTY_STATUS.includes(status as PartyStatus); } @@ -160,101 +39,6 @@ export const partyApp = new Elysia() .use(betterAuthElysia) .group("/party", (app) => app - .ws("/ws", { - beforeHandle: async ({ request, set }) => { - const session = await auth.api.getSession({ - headers: request.headers, - }); - if (!session) { - set.status = 401; - return; - } - return { - user: session.user, - session: session.session, - }; - }, - open: async (ws) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - if (!user) return; - const membership = await getMemberRecord(db, user.id); - if (!membership) { - ws.send( - JSON.stringify({ - type: "error", - message: "You are not in a party.", - }), - ); - ws.close?.(1008, "Not in a party"); - return; - } - - const snapshot = await getPartyStatus(membership.partyId); - data.partyId = membership.partyId; - registerPartySocket(membership.partyId, user.id, ws); - if (snapshot) { - ws.send( - JSON.stringify({ - type: "snapshot", - party: snapshot.party, - members: snapshot.members, - }), - ); - } - }, - message: async (ws, message: PartySocketMessage) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - if (!user) return; - if (message.type === "ping") { - ws.send(JSON.stringify({ type: "pong" })); - return; - } - - if (message.type !== "member_payload") return; - const membership = await getMemberRecord(db, user.id); - if (!membership) return; - - if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) { - ws.send( - JSON.stringify({ - type: "error", - message: "Payload too large.", - }), - ); - return; - } - - const currentParty = await db.query.party.findFirst({ - where: { id: membership.partyId }, - }); - if (!currentParty) return; - - sendPartyEventToUser(membership.partyId, currentParty.hostId, { - type: "member_payload", - fromUserId: user.id, - payload: message.payload, - }); - }, - close: async (ws) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - const { partyId } = data; - if (!user) return; - if (!partyId) { - const membership = await getMemberRecord(db, user.id); - if (!membership) return; - unregisterPartySocket(membership.partyId, user.id, ws); - return; - } - unregisterPartySocket(partyId, user.id, ws); - }, - body: t.Union([ - t.Object({ type: t.Literal("ping") }), - t.Object({ type: t.Literal("member_payload"), payload: t.Any() }), - ]), - }) .get( "/status", async ({ user }) => { @@ -359,6 +143,20 @@ export const partyApp = new Elysia() userId: user.id, }); broadcastSnapshot(partyId, status); + reassignUserSocketsToParty(user.id, partyId); + reassignUserSocketsToParty(targetUserId, partyId); + if (status) { + sendDirectEventToUser(targetUserId, { + type: "party_status", + party: status.party, + members: status.members, + }); + sendDirectEventToUser(user.id, { + type: "party_status", + party: status.party, + members: status.members, + }); + } return status ?? { party: null, members: [] }; }, { @@ -387,6 +185,7 @@ export const partyApp = new Elysia() }); } broadcastSnapshot(result.partyId, status); + reassignUserSocketsToParty(user.id, null); return status ?? { party: null, members: [] }; }, { auth: true }, @@ -433,6 +232,7 @@ export const partyApp = new Elysia() kickedBy: user.id, }); broadcastSnapshot(currentMembership.partyId, status); + reassignUserSocketsToParty(body.memberUserId, null); return status ?? { party: null, members: [] }; }, {