refactor socket

This commit is contained in:
Daniel Bulant 2026-04-21 23:00:21 +02:00
parent 49390778c2
commit 2d16fb8ecc
No known key found for this signature in database
5 changed files with 347 additions and 227 deletions

View file

@ -5,11 +5,14 @@ import { syncApp } from "./routes/sync";
import "./workflows/sync";
import "./dbos.ts";
import { partyApp } from "./routes/party";
import { partySocketApp } from "./routes/party-socket";
import { statsApp } from "./routes/stats.ts";
const app = new Elysia()
.use(betterAuthElysia)
.group("/api", (app) => app.use(syncApp).use(statsApp).use(partyApp))
.group("/api", (app) =>
app.use(syncApp).use(statsApp).use(partyApp).use(partySocketApp),
)
.listen(4000);
export type App = typeof app;

106
api/src/party-data.ts Normal file
View file

@ -0,0 +1,106 @@
import { eq } from "drizzle-orm";
import { db } from "./db";
import { party, partyMember } from "./db/schema";
type DbClient = typeof db;
type DbTransaction = Parameters<typeof db.transaction>[0] extends (
tx: infer T,
) => Promise<unknown>
? T
: never;
export type DbLike = DbClient | DbTransaction;
export async function getPartyForUser(userId: string) {
const memberships = await db.query.partyMember.findMany({
where: {
userId,
},
with: {
party: true,
},
limit: 1,
});
return memberships[0]?.party ?? null;
}
export async function getMemberRecord(dbClient: DbLike, userId: string) {
return (
(await dbClient.query.partyMember.findFirst({
where: {
userId,
},
})) ?? null
);
}
export async function getPartyStatus(partyId: string) {
const partyRecord = await db.query.party.findFirst({
where: {
id: partyId,
},
});
if (!partyRecord) return null;
const members = await db.query.partyMember.findMany({
where: {
partyId,
},
with: {
user: true,
},
orderBy: {
joinedAt: "asc",
},
});
return {
party: partyRecord,
members,
};
}
export async function cleanupPartyIfEmpty(dbClient: DbLike, partyId: string) {
const members = await dbClient.query.partyMember.findMany({
where: {
partyId,
},
limit: 1,
});
if (members.length > 0) return;
await dbClient.delete(party).where(eq(party.id, partyId));
}
export async function leaveParty(dbClient: DbLike, userId: string) {
const member = await getMemberRecord(dbClient, userId);
if (!member) return null;
await dbClient.delete(partyMember).where(eq(partyMember.id, member.id));
const nextHost = await dbClient.query.partyMember.findFirst({
where: {
partyId: member.partyId,
},
orderBy: {
joinedAt: "asc",
},
});
let newHostId: string | null = null;
if (nextHost) {
const currentParty = await dbClient.query.party.findFirst({
where: {
id: member.partyId,
},
});
if (currentParty?.hostId === userId) {
await dbClient
.update(party)
.set({
hostId: nextHost.userId,
lastUpdated: new Date(),
})
.where(eq(party.id, member.partyId));
newHostId = nextHost.userId;
}
}
await cleanupPartyIfEmpty(dbClient, member.partyId);
return {
partyId: member.partyId,
newHostId,
};
}

View file

@ -9,6 +9,7 @@ type WebSocketLike = {
};
const partySockets = new Map<string, Map<string, Set<WebSocketLike>>>();
const userSockets = new Map<string, Set<WebSocketLike>>();
function getPartyUserSockets(partyId: string, userId: string) {
const partyMap = partySockets.get(partyId);
@ -58,6 +59,44 @@ export function unregisterPartySocket(
}
}
export function registerUserSocket(userId: string, ws: WebSocketLike) {
let sockets = userSockets.get(userId);
if (!sockets) {
sockets = new Set();
userSockets.set(userId, sockets);
}
sockets.add(ws);
}
export function unregisterUserSocket(userId: string, ws: WebSocketLike) {
const sockets = userSockets.get(userId);
if (!sockets) return;
sockets.delete(ws);
if (sockets.size === 0) {
userSockets.delete(userId);
}
}
export function unregisterUserSocketFromAllParties(
userId: string,
ws: WebSocketLike,
) {
for (const [partyId, partyMap] of partySockets) {
const userSockets = partyMap.get(userId);
if (!userSockets) continue;
userSockets.delete(ws);
if (userSockets.size === 0) {
partyMap.delete(userId);
}
if (partyMap.size === 0) {
partySockets.delete(partyId);
}
}
}
export function broadcastPartyEvent(partyId: string, event: PartySocketEvent) {
const partyMap = partySockets.get(partyId);
if (!partyMap) return;
@ -83,3 +122,38 @@ export function sendPartyEventToUser(
ws.send(payload);
}
}
export function sendDirectEventToUser(userId: string, event: PartySocketEvent) {
const sockets = userSockets.get(userId);
if (!sockets) return;
const payload = JSON.stringify(event);
for (const ws of sockets) {
ws.send(payload);
}
}
export function reassignUserSocketsToParty(
userId: string,
partyId: string | null,
) {
for (const [existingPartyId, partyMap] of partySockets) {
if (!partyMap.has(userId)) continue;
partyMap.delete(userId);
if (partyMap.size === 0) {
partySockets.delete(existingPartyId);
}
}
if (!partyId) return;
const sockets = userSockets.get(userId);
if (!sockets) return;
let partyMap = partySockets.get(partyId);
if (!partyMap) {
partyMap = new Map();
partySockets.set(partyId, partyMap);
}
partyMap.set(userId, new Set(sockets));
}

View file

@ -0,0 +1,137 @@
import Elysia, { t } from "elysia";
import { auth, betterAuthElysia } from "../auth";
import { db } from "../db";
import { getMemberRecord, getPartyStatus } from "../party-data";
import {
registerPartySocket,
registerUserSocket,
sendPartyEventToUser,
unregisterPartySocket,
unregisterUserSocket,
unregisterUserSocketFromAllParties,
} from "../party-sockets";
type PartySocketMessage =
| {
type: "member_payload";
payload: unknown;
}
| {
type: "ping";
};
const MAX_MEMBER_PAYLOAD_SIZE = 8_000;
type PartyWsData = {
user?: { id: string };
partyId?: string | null;
};
function getPayloadSize(payload: unknown) {
try {
return JSON.stringify(payload).length;
} catch {
return Infinity;
}
}
export const partySocketApp = new Elysia()
.use(betterAuthElysia)
.group("/party-socket", (app) =>
app.ws("/ws", {
beforeHandle: async ({ request, set }) => {
const session = await auth.api.getSession({
headers: request.headers,
});
if (!session) {
set.status = 401;
return;
}
return {
user: session.user,
session: session.session,
};
},
open: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
registerUserSocket(user.id, ws);
const membership = await getMemberRecord(db, user.id);
if (!membership) {
ws.send(
JSON.stringify({
type: "snapshot",
party: null,
members: [],
}),
);
return;
}
const snapshot = await getPartyStatus(membership.partyId);
data.partyId = membership.partyId;
registerPartySocket(membership.partyId, user.id, ws);
if (snapshot) {
ws.send(
JSON.stringify({
type: "snapshot",
party: snapshot.party,
members: snapshot.members,
}),
);
}
},
message: async (ws, message: PartySocketMessage) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
if (message.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
return;
}
if (message.type !== "member_payload") return;
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) {
ws.send(
JSON.stringify({
type: "error",
message: "Payload too large.",
}),
);
return;
}
const currentParty = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!currentParty) return;
sendPartyEventToUser(membership.partyId, currentParty.hostId, {
type: "member_payload",
fromUserId: user.id,
payload: message.payload,
});
},
close: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
const { partyId } = data;
if (!user) return;
if (!partyId) {
unregisterUserSocketFromAllParties(user.id, ws);
unregisterUserSocket(user.id, ws);
return;
}
unregisterPartySocket(partyId, user.id, ws);
unregisterUserSocket(user.id, ws);
},
body: t.Union([
t.Object({ type: t.Literal("ping") }),
t.Object({ type: t.Literal("member_payload"), payload: t.Any() }),
]),
}),
);

View file

@ -1,92 +1,27 @@
import { and, eq } from "drizzle-orm";
import Elysia, { t } from "elysia";
import { auth, betterAuthElysia } from "../auth";
import { betterAuthElysia } from "../auth";
import { db } from "../db";
import { party, partyMember } from "../db/schema";
import {
cleanupPartyIfEmpty,
getMemberRecord,
getPartyForUser,
getPartyStatus,
leaveParty,
} from "../party-data";
import {
broadcastPartyEvent,
registerPartySocket,
sendPartyEventToUser,
unregisterPartySocket,
reassignUserSocketsToParty,
sendDirectEventToUser,
} from "../party-sockets";
const PARTY_STATUS = ["created", "started", "ended"] as const;
type PartyStatus = (typeof PARTY_STATUS)[number];
type DbClient = typeof db;
type DbTransaction = Parameters<typeof db.transaction>[0] extends (
tx: infer T,
) => Promise<unknown>
? T
: never;
type DbLike = DbClient | DbTransaction;
type PartySnapshot = NonNullable<Awaited<ReturnType<typeof getPartyStatus>>>;
type PartySocketMessage =
| {
type: "member_payload";
payload: unknown;
}
| {
type: "ping";
};
const MAX_MEMBER_PAYLOAD_SIZE = 8_000;
type PartyWsData = {
user?: { id: string };
partyId?: string;
};
async function getPartyForUser(userId: string) {
const memberships = await db.query.partyMember.findMany({
where: {
userId,
},
with: {
party: true,
},
limit: 1,
});
return memberships[0]?.party ?? null;
}
async function getMemberRecord(dbClient: DbLike, userId: string) {
return (
(await dbClient.query.partyMember.findFirst({
where: {
userId,
},
})) ?? null
);
}
async function getPartyStatus(partyId: string) {
const party = await db.query.party.findFirst({
where: {
id: partyId,
},
});
if (!party) return null;
const members = await db.query.partyMember.findMany({
where: {
partyId,
},
with: {
user: true,
},
orderBy: {
joinedAt: "asc",
},
});
return {
party,
members,
};
}
function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) {
if (!snapshot) return;
broadcastPartyEvent(partyId, {
@ -96,62 +31,6 @@ function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) {
});
}
function getPayloadSize(payload: unknown) {
try {
return JSON.stringify(payload).length;
} catch {
return Infinity;
}
}
async function cleanupPartyIfEmpty(dbClient: DbLike, partyId: string) {
const members = await dbClient.query.partyMember.findMany({
where: {
partyId,
},
limit: 1,
});
if (members.length > 0) return;
await dbClient.delete(party).where(eq(party.id, partyId));
}
async function leaveParty(dbClient: DbLike, userId: string) {
const member = await getMemberRecord(dbClient, userId);
if (!member) return null;
await dbClient.delete(partyMember).where(eq(partyMember.id, member.id));
const nextHost = await dbClient.query.partyMember.findFirst({
where: {
partyId: member.partyId,
},
orderBy: {
joinedAt: "asc",
},
});
let newHostId: string | null = null;
if (nextHost) {
const currentParty = await dbClient.query.party.findFirst({
where: {
id: member.partyId,
},
});
if (currentParty?.hostId === userId) {
await dbClient
.update(party)
.set({
hostId: nextHost.userId,
lastUpdated: new Date(),
})
.where(eq(party.id, member.partyId));
newHostId = nextHost.userId;
}
}
await cleanupPartyIfEmpty(dbClient, member.partyId);
return {
partyId: member.partyId,
newHostId,
};
}
function isValidStatus(status: string): status is PartyStatus {
return PARTY_STATUS.includes(status as PartyStatus);
}
@ -160,101 +39,6 @@ export const partyApp = new Elysia()
.use(betterAuthElysia)
.group("/party", (app) =>
app
.ws("/ws", {
beforeHandle: async ({ request, set }) => {
const session = await auth.api.getSession({
headers: request.headers,
});
if (!session) {
set.status = 401;
return;
}
return {
user: session.user,
session: session.session,
};
},
open: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
const membership = await getMemberRecord(db, user.id);
if (!membership) {
ws.send(
JSON.stringify({
type: "error",
message: "You are not in a party.",
}),
);
ws.close?.(1008, "Not in a party");
return;
}
const snapshot = await getPartyStatus(membership.partyId);
data.partyId = membership.partyId;
registerPartySocket(membership.partyId, user.id, ws);
if (snapshot) {
ws.send(
JSON.stringify({
type: "snapshot",
party: snapshot.party,
members: snapshot.members,
}),
);
}
},
message: async (ws, message: PartySocketMessage) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
if (message.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
return;
}
if (message.type !== "member_payload") return;
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) {
ws.send(
JSON.stringify({
type: "error",
message: "Payload too large.",
}),
);
return;
}
const currentParty = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!currentParty) return;
sendPartyEventToUser(membership.partyId, currentParty.hostId, {
type: "member_payload",
fromUserId: user.id,
payload: message.payload,
});
},
close: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
const { partyId } = data;
if (!user) return;
if (!partyId) {
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
unregisterPartySocket(membership.partyId, user.id, ws);
return;
}
unregisterPartySocket(partyId, user.id, ws);
},
body: t.Union([
t.Object({ type: t.Literal("ping") }),
t.Object({ type: t.Literal("member_payload"), payload: t.Any() }),
]),
})
.get(
"/status",
async ({ user }) => {
@ -359,6 +143,20 @@ export const partyApp = new Elysia()
userId: user.id,
});
broadcastSnapshot(partyId, status);
reassignUserSocketsToParty(user.id, partyId);
reassignUserSocketsToParty(targetUserId, partyId);
if (status) {
sendDirectEventToUser(targetUserId, {
type: "party_status",
party: status.party,
members: status.members,
});
sendDirectEventToUser(user.id, {
type: "party_status",
party: status.party,
members: status.members,
});
}
return status ?? { party: null, members: [] };
},
{
@ -387,6 +185,7 @@ export const partyApp = new Elysia()
});
}
broadcastSnapshot(result.partyId, status);
reassignUserSocketsToParty(user.id, null);
return status ?? { party: null, members: [] };
},
{ auth: true },
@ -433,6 +232,7 @@ export const partyApp = new Elysia()
kickedBy: user.id,
});
broadcastSnapshot(currentMembership.partyId, status);
reassignUserSocketsToParty(body.memberUserId, null);
return status ?? { party: null, members: [] };
},
{