working improved sockets

This commit is contained in:
Daniel Bulant 2026-04-29 22:36:48 +02:00
parent 9072ce76ba
commit 58878752a8
No known key found for this signature in database
8 changed files with 101 additions and 280 deletions

View file

@ -52,7 +52,10 @@ export const partyMember = pgTable(
joinedAt: timestamp().defaultNow().notNull(),
lastSeen: timestamp().defaultNow().notNull(),
},
(partyMember) => [uniqueIndex().on(partyMember.partyId, partyMember.userId)],
(partyMember) => [
uniqueIndex().on(partyMember.partyId, partyMember.userId),
index().on(partyMember.userId, partyMember.joinedAt),
],
);
export const platform = pgEnum("enum_platform", ["spotify", "apple"]);

View file

@ -7,7 +7,7 @@ import "./workflows/sync";
import "./workflows/party-analysis";
import "./dbos.ts";
import { partyApp } from "./routes/party";
import { partySocketApp } from "./routes/party-socket";
import { partySocketApp, pubsub } from "./routes/party-socket";
import { statsApp } from "./routes/stats.ts";
const app = new Elysia()
@ -23,6 +23,8 @@ const app = new Elysia()
)
.listen(4000);
pubsub.setServer(app.server);
export type App = typeof app;
await DBOS.launch({

View file

@ -30,6 +30,9 @@ export async function getMemberRecord(dbClient: DbLike, userId: string) {
where: {
userId,
},
orderBy: {
joinedAt: "desc",
},
})) ?? null
);
}

View file

@ -1,156 +0,0 @@
import type { PartySocketEvent } from "./party-types";
type WebSocketLike = {
send: (data: string) => void;
close?: (code?: number, reason?: string) => void;
};
const partySockets = new Map<string, Map<string, Set<WebSocketLike>>>();
const userSockets = new Map<string, Set<WebSocketLike>>();
function getPartyUserSockets(partyId: string, userId: string) {
const partyMap = partySockets.get(partyId);
if (!partyMap) return null;
return partyMap.get(userId) ?? null;
}
export function registerPartySocket(
partyId: string,
userId: string,
ws: WebSocketLike,
) {
let partyMap = partySockets.get(partyId);
if (!partyMap) {
partyMap = new Map();
partySockets.set(partyId, partyMap);
}
let userSockets = partyMap.get(userId);
if (!userSockets) {
userSockets = new Set();
partyMap.set(userId, userSockets);
}
userSockets.add(ws);
}
export function unregisterPartySocket(
partyId: string,
userId: string,
ws: WebSocketLike,
) {
const partyMap = partySockets.get(partyId);
if (!partyMap) return;
const userSockets = partyMap.get(userId);
if (!userSockets) return;
userSockets.delete(ws);
if (userSockets.size === 0) {
partyMap.delete(userId);
}
if (partyMap.size === 0) {
partySockets.delete(partyId);
}
}
export function registerUserSocket(userId: string, ws: WebSocketLike) {
let sockets = userSockets.get(userId);
if (!sockets) {
sockets = new Set();
userSockets.set(userId, sockets);
}
sockets.add(ws);
}
export function unregisterUserSocket(userId: string, ws: WebSocketLike) {
const sockets = userSockets.get(userId);
if (!sockets) return;
sockets.delete(ws);
if (sockets.size === 0) {
userSockets.delete(userId);
}
}
export function unregisterUserSocketFromAllParties(
userId: string,
ws: WebSocketLike,
) {
for (const [partyId, partyMap] of partySockets) {
const userSockets = partyMap.get(userId);
if (!userSockets) continue;
userSockets.delete(ws);
if (userSockets.size === 0) {
partyMap.delete(userId);
}
if (partyMap.size === 0) {
partySockets.delete(partyId);
}
}
}
export function broadcastPartyEvent(partyId: string, event: PartySocketEvent) {
const partyMap = partySockets.get(partyId);
if (!partyMap) return;
const payload = JSON.stringify(event);
for (const userSockets of partyMap.values()) {
for (const ws of userSockets) {
ws.send(payload);
}
}
}
export function sendPartyEventToUser(
partyId: string,
userId: string,
event: PartySocketEvent,
) {
const userSockets = getPartyUserSockets(partyId, userId);
if (!userSockets) return;
const payload = JSON.stringify(event);
for (const ws of userSockets) {
ws.send(payload);
}
}
export function sendDirectEventToUser(userId: string, event: PartySocketEvent) {
const sockets = userSockets.get(userId);
if (!sockets) return;
const payload = JSON.stringify(event);
for (const ws of sockets) {
ws.send(payload);
}
}
export function reassignUserSocketsToParty(
userId: string,
partyId: string | null,
) {
for (const [existingPartyId, partyMap] of partySockets) {
if (!partyMap.has(userId)) continue;
partyMap.delete(userId);
if (partyMap.size === 0) {
partySockets.delete(existingPartyId);
}
}
if (!partyId) return;
const sockets = userSockets.get(userId);
if (!sockets) return;
let partyMap = partySockets.get(partyId);
if (!partyMap) {
partyMap = new Map();
partySockets.set(partyId, partyMap);
}
partyMap.set(userId, new Set(sockets));
}

View file

@ -27,9 +27,6 @@ export type PartySocketOutgoing =
export type PartySocketEvent =
| { type: "snapshot"; party: Party | null; members: PartyMemberWithUser[] }
| { type: "party_status"; party: Party; members: PartyMemberWithUser[] }
| { type: "member_joined"; userId: string }
| { type: "member_left"; userId: string; kickedBy?: string }
| { type: "host_changed"; hostId: string }
| { type: "member_payload"; fromUserId: string; payload: unknown }
| { type: "error"; message: string }
| { type: "pong" };

View file

@ -1,30 +1,34 @@
import Elysia, { t } from "elysia";
import { auth, betterAuthElysia } from "../auth";
import { Elysia } from "elysia";
import { betterAuthElysia } from "../auth";
import { db } from "../db";
import { getMemberRecord, getPartyStatus } from "../party-data";
import {
registerPartySocket,
registerUserSocket,
sendPartyEventToUser,
unregisterPartySocket,
unregisterUserSocket,
unregisterUserSocketFromAllParties,
} from "../party-sockets";
import type { PartySocketOutgoing } from "../party-types";
const MAX_MEMBER_PAYLOAD_SIZE = 8_000;
function userTopic(userId: string) {
return `user:${userId}`;
}
type PartyWsData = {
partyId?: string | null;
function partyTopic(partyId: string) {
return `party:${partyId}`;
}
const socketPartyId = new WeakMap<object, string>();
export const pubsub = {
_server: null as ReturnType<typeof Bun.serve> | null,
setServer(server: ReturnType<typeof Bun.serve> | null) {
this._server = server;
},
publish(topic: string, data: string) {
this._server?.publish(topic, data);
},
};
function getPayloadSize(payload: unknown) {
try {
return JSON.stringify(payload).length;
} catch {
return Infinity;
}
}
export const topic = {
user: userTopic,
party: partyTopic,
};
export const partySocketApp = new Elysia()
.use(betterAuthElysia)
@ -33,10 +37,13 @@ export const partySocketApp = new Elysia()
.get("/test", () => ({ ok: 1 }))
.ws("/ws", {
auth: true,
publishToSelf: true,
open: async (ws) => {
const user = ws.data.user;
if (!user) return;
registerUserSocket(user.id, ws);
ws.subscribe(userTopic(user.id));
const membership = await getMemberRecord(db, user.id);
if (!membership) {
ws.send(
@ -49,9 +56,10 @@ export const partySocketApp = new Elysia()
return;
}
socketPartyId.set(ws, membership.partyId);
ws.subscribe(partyTopic(membership.partyId));
const snapshot = await getPartyStatus(membership.partyId);
ws.data.partyId = membership.partyId;
registerPartySocket(membership.partyId, user.id, ws);
if (snapshot) {
ws.send(
JSON.stringify({
@ -62,56 +70,64 @@ export const partySocketApp = new Elysia()
);
}
},
message: async (ws, message: PartySocketOutgoing) => {
message: async (ws, message) => {
const data = ws.data;
const user = data.user;
if (!user) return;
if (message.type === "ping") {
if (typeof message !== "string") return;
let parsed: { type: string; payload?: unknown };
try {
parsed = JSON.parse(message);
} catch {
ws.send(JSON.stringify({ type: "error", message: "Invalid JSON" }));
return;
}
if (parsed.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
return;
}
if (message.type !== "member_payload") return;
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
if (parsed.type !== "member_payload") return;
if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) {
const MAX_MEMBER_PAYLOAD_SIZE = 8_000;
const payloadString = JSON.stringify(parsed.payload);
if (payloadString.length > MAX_MEMBER_PAYLOAD_SIZE) {
ws.send(
JSON.stringify({
type: "error",
message: "Payload too large.",
}),
JSON.stringify({ type: "error", message: "Payload too large." }),
);
return;
}
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
const currentParty = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!currentParty) return;
sendPartyEventToUser(membership.partyId, currentParty.hostId, {
type: "member_payload",
fromUserId: user.id,
payload: message.payload,
});
ws.publish(
partyTopic(membership.partyId),
JSON.stringify({
type: "member_payload",
fromUserId: user.id,
payload: parsed.payload,
}),
);
},
close: async (ws) => {
const data = ws.data;
const user = data.user;
const { partyId } = data;
const user = ws.data.user;
if (!user) return;
if (!partyId) {
unregisterUserSocketFromAllParties(user.id, ws);
unregisterUserSocket(user.id, ws);
return;
}
unregisterPartySocket(partyId, user.id, ws);
unregisterUserSocket(user.id, ws);
ws.unsubscribe(userTopic(user.id));
const partyId = socketPartyId.get(ws);
if (!partyId) return;
ws.unsubscribe(partyTopic(partyId));
},
body: t.Union([
t.Object({ type: t.Literal("ping") }),
t.Object({ type: t.Literal("member_payload"), payload: t.Any() }),
]),
}),
);

View file

@ -10,28 +10,29 @@ import {
getPartyStatus,
leaveParty,
} from "../party-data";
import {
broadcastPartyEvent,
reassignUserSocketsToParty,
sendDirectEventToUser,
} from "../party-sockets";
import {
PARTY_STATUS,
type PartySnapshot,
type PartyStatus,
} from "../party-types";
import type { PartySnapshot } from "../party-types";
import { pubsub, topic } from "./party-socket";
function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) {
if (!snapshot) return;
broadcastPartyEvent(partyId, {
type: "party_status",
party: snapshot.party,
members: snapshot.members,
});
pubsub.publish(
topic.party(partyId),
JSON.stringify({
type: "party_status",
party: snapshot.party,
members: snapshot.members,
}),
);
}
function isValidStatus(status: string): status is PartyStatus {
return PARTY_STATUS.includes(status as PartyStatus);
function broadcastToUser(userId: string, event: Record<string, unknown>) {
pubsub.publish(topic.user(userId), JSON.stringify(event));
}
function isValidStatus(
status: string,
): status is import("../party-types").PartyStatus {
return ["created", "started", "ended"].includes(status);
}
export const partyApp = new Elysia()
@ -126,31 +127,19 @@ export const partyApp = new Elysia()
if (!partyId) return { party: null, members: [] };
const status = await getPartyStatus(partyId);
if (leaveResult?.newHostId) {
broadcastPartyEvent(leaveResult.partyId, {
type: "host_changed",
hostId: leaveResult.newHostId,
});
broadcastSnapshot(leaveResult.partyId, status);
}
if (hostChanged) {
broadcastPartyEvent(partyId, {
type: "host_changed",
hostId: targetUserId,
});
broadcastSnapshot(partyId, status);
}
broadcastPartyEvent(partyId, {
type: "member_joined",
userId: user.id,
});
broadcastSnapshot(partyId, status);
reassignUserSocketsToParty(user.id, partyId);
reassignUserSocketsToParty(targetUserId, partyId);
if (status) {
sendDirectEventToUser(targetUserId, {
broadcastToUser(targetUserId, {
type: "party_status",
party: status.party,
members: status.members,
});
sendDirectEventToUser(user.id, {
broadcastToUser(user.id, {
type: "party_status",
party: status.party,
members: status.members,
@ -173,18 +162,7 @@ export const partyApp = new Elysia()
});
if (!result) return { party: null, members: [] };
const status = await getPartyStatus(result.partyId);
broadcastPartyEvent(result.partyId, {
type: "member_left",
userId: user.id,
});
if (result.newHostId) {
broadcastPartyEvent(result.partyId, {
type: "host_changed",
hostId: result.newHostId,
});
}
broadcastSnapshot(result.partyId, status);
reassignUserSocketsToParty(user.id, null);
return status ?? { party: null, members: [] };
},
{ auth: true },
@ -225,13 +203,7 @@ export const partyApp = new Elysia()
await cleanupPartyIfEmpty(tx, currentMembership.partyId);
});
const status = await getPartyStatus(currentMembership.partyId);
broadcastPartyEvent(currentMembership.partyId, {
type: "member_left",
userId: body.memberUserId,
kickedBy: user.id,
});
broadcastSnapshot(currentMembership.partyId, status);
reassignUserSocketsToParty(body.memberUserId, null);
return status ?? { party: null, members: [] };
},
{

View file

@ -13,24 +13,8 @@ function reducePartyState(
): PartyState {
switch (event.type) {
case "snapshot":
return { party: event.party, members: event.members };
case "party_status":
return { party: event.party, members: event.members };
case "member_joined":
return state;
case "member_left":
return {
...state,
members: state.members.filter(
(m: PartyMember) => m.userId !== event.userId,
),
};
case "host_changed":
if (!state.party) return state;
return {
...state,
party: { ...state.party, hostId: event.hostId },
};
case "member_payload":
case "pong":
case "error":