initial socket impl

This commit is contained in:
Daniel Bulant 2026-04-29 19:40:31 +02:00
parent 21910eca41
commit 9072ce76ba
No known key found for this signature in database
9 changed files with 370 additions and 135 deletions

View file

@ -1,6 +1,7 @@
import { eq } from "drizzle-orm";
import { db } from "./db";
import { party, partyMember } from "./db/schema";
import type { PartySnapshot } from "./party-types";
type DbClient = typeof db;
type DbTransaction = Parameters<typeof db.transaction>[0] extends (
@ -33,7 +34,9 @@ export async function getMemberRecord(dbClient: DbLike, userId: string) {
);
}
export async function getPartyStatus(partyId: string) {
export async function getPartyStatus(
partyId: string,
): Promise<PartySnapshot | null> {
const partyRecord = await db.query.party.findFirst({
where: {
id: partyId,

View file

@ -1,7 +1,4 @@
type PartySocketEvent = {
type: string;
[key: string]: unknown;
};
import type { PartySocketEvent } from "./party-types";
type WebSocketLike = {
send: (data: string) => void;

35
api/src/party-types.ts Normal file
View file

@ -0,0 +1,35 @@
import type { InferSelectModel } from "drizzle-orm";
import type { party, partyMember, user } from "./db/schema";
export type Party = InferSelectModel<typeof party>;
export type PartyMember = InferSelectModel<typeof partyMember>;
export type User = InferSelectModel<typeof user>;
export type PartyMemberWithUser = PartyMember & { user: User | null };
export const PARTY_STATUS = ["created", "started", "ended"] as const;
export type PartyStatus = (typeof PARTY_STATUS)[number];
export type PartySnapshot = {
party: Party;
members: PartyMemberWithUser[];
};
export type PartyState = {
party: Party | null;
members: PartyMemberWithUser[];
};
export type PartySocketOutgoing =
| { type: "ping" }
| { type: "member_payload"; payload: unknown };
export type PartySocketEvent =
| { type: "snapshot"; party: Party | null; members: PartyMemberWithUser[] }
| { type: "party_status"; party: Party; members: PartyMemberWithUser[] }
| { type: "member_joined"; userId: string }
| { type: "member_left"; userId: string; kickedBy?: string }
| { type: "host_changed"; hostId: string }
| { type: "member_payload"; fromUserId: string; payload: unknown }
| { type: "error"; message: string }
| { type: "pong" };

View file

@ -10,20 +10,11 @@ import {
unregisterUserSocket,
unregisterUserSocketFromAllParties,
} from "../party-sockets";
type PartySocketMessage =
| {
type: "member_payload";
payload: unknown;
}
| {
type: "ping";
};
import type { PartySocketOutgoing } from "../party-types";
const MAX_MEMBER_PAYLOAD_SIZE = 8_000;
type PartyWsData = {
user?: { id: string };
partyId?: string | null;
};
@ -38,100 +29,89 @@ function getPayloadSize(payload: unknown) {
export const partySocketApp = new Elysia()
.use(betterAuthElysia)
.group("/party-socket", (app) =>
app.ws("/ws", {
beforeHandle: async ({ request, set }) => {
const session = await auth.api.getSession({
headers: request.headers,
});
if (!session) {
set.status = 401;
return;
}
return {
user: session.user,
session: session.session,
};
},
open: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
registerUserSocket(user.id, ws);
const membership = await getMemberRecord(db, user.id);
if (!membership) {
ws.send(
JSON.stringify({
type: "snapshot",
party: null,
members: [],
}),
);
return;
}
app
.get("/test", () => ({ ok: 1 }))
.ws("/ws", {
auth: true,
open: async (ws) => {
const user = ws.data.user;
if (!user) return;
registerUserSocket(user.id, ws);
const membership = await getMemberRecord(db, user.id);
if (!membership) {
ws.send(
JSON.stringify({
type: "snapshot",
party: null,
members: [],
}),
);
return;
}
const snapshot = await getPartyStatus(membership.partyId);
data.partyId = membership.partyId;
registerPartySocket(membership.partyId, user.id, ws);
if (snapshot) {
ws.send(
JSON.stringify({
type: "snapshot",
party: snapshot.party,
members: snapshot.members,
}),
);
}
},
message: async (ws, message: PartySocketMessage) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
if (!user) return;
if (message.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
return;
}
const snapshot = await getPartyStatus(membership.partyId);
ws.data.partyId = membership.partyId;
registerPartySocket(membership.partyId, user.id, ws);
if (snapshot) {
ws.send(
JSON.stringify({
type: "snapshot",
party: snapshot.party,
members: snapshot.members,
}),
);
}
},
message: async (ws, message: PartySocketOutgoing) => {
const data = ws.data;
const user = data.user;
if (!user) return;
if (message.type === "ping") {
ws.send(JSON.stringify({ type: "pong" }));
return;
}
if (message.type !== "member_payload") return;
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
if (message.type !== "member_payload") return;
const membership = await getMemberRecord(db, user.id);
if (!membership) return;
if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) {
ws.send(
JSON.stringify({
type: "error",
message: "Payload too large.",
}),
);
return;
}
if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) {
ws.send(
JSON.stringify({
type: "error",
message: "Payload too large.",
}),
);
return;
}
const currentParty = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!currentParty) return;
const currentParty = await db.query.party.findFirst({
where: { id: membership.partyId },
});
if (!currentParty) return;
sendPartyEventToUser(membership.partyId, currentParty.hostId, {
type: "member_payload",
fromUserId: user.id,
payload: message.payload,
});
},
close: async (ws) => {
const data = ws.data as unknown as PartyWsData;
const user = data.user;
const { partyId } = data;
if (!user) return;
if (!partyId) {
unregisterUserSocketFromAllParties(user.id, ws);
sendPartyEventToUser(membership.partyId, currentParty.hostId, {
type: "member_payload",
fromUserId: user.id,
payload: message.payload,
});
},
close: async (ws) => {
const data = ws.data;
const user = data.user;
const { partyId } = data;
if (!user) return;
if (!partyId) {
unregisterUserSocketFromAllParties(user.id, ws);
unregisterUserSocket(user.id, ws);
return;
}
unregisterPartySocket(partyId, user.id, ws);
unregisterUserSocket(user.id, ws);
return;
}
unregisterPartySocket(partyId, user.id, ws);
unregisterUserSocket(user.id, ws);
},
body: t.Union([
t.Object({ type: t.Literal("ping") }),
t.Object({ type: t.Literal("member_payload"), payload: t.Any() }),
]),
}),
},
body: t.Union([
t.Object({ type: t.Literal("ping") }),
t.Object({ type: t.Literal("member_payload"), payload: t.Any() }),
]),
}),
);

View file

@ -15,12 +15,11 @@ import {
reassignUserSocketsToParty,
sendDirectEventToUser,
} from "../party-sockets";
const PARTY_STATUS = ["created", "started", "ended"] as const;
type PartyStatus = (typeof PARTY_STATUS)[number];
type PartySnapshot = NonNullable<Awaited<ReturnType<typeof getPartyStatus>>>;
import {
PARTY_STATUS,
type PartySnapshot,
type PartyStatus,
} from "../party-types";
function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) {
if (!snapshot) return;

View file

@ -1,29 +1,39 @@
import { useRouteContext } from "@tanstack/react-router";
import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar";
import {
Item,
ItemContent,
ItemDescription,
ItemMedia,
ItemTitle,
} from "./ui/item";
import { useParty } from "#/hooks/use-party";
import { useUser } from "#/hooks/user";
import { initials } from "#/lib/utils";
import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar";
import {
Item,
ItemContent,
ItemDescription,
ItemMedia,
ItemTitle,
} from "./ui/item";
export function UserInfo() {
const { user } = useUser();
return (
<Item>
<ItemMedia>
<Avatar>
<AvatarImage src={user?.image || undefined} />
<AvatarFallback>{initials(user?.name || "")}</AvatarFallback>
</Avatar>
</ItemMedia>
<ItemContent>
<ItemTitle>{user?.name}</ItemTitle>
<ItemDescription>No party yet</ItemDescription>
</ItemContent>
</Item>
);
const { user } = useUser();
const { party, members, isConnecting, isReconnecting } = useParty();
return (
<Item>
<ItemMedia>
<Avatar>
<AvatarImage src={user?.image || undefined} />
<AvatarFallback>{initials(user?.name || "")}</AvatarFallback>
</Avatar>
</ItemMedia>
<ItemContent>
<ItemTitle>{user?.name}</ItemTitle>
<ItemDescription>
{isConnecting
? "Connecting..."
: isReconnecting
? "Reconnecting..."
: party
? `${members.length} in party`
: "No party yet"}
</ItemDescription>
</ItemContent>
</Item>
);
}

View file

@ -0,0 +1,127 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { PartySocketEvent } from "../../../api/src/party-types";
type Handler = (event: PartySocketEvent) => void;
const PING_INTERVAL_MS = 30_000;
const RECONNECT_BASE_MS = 1_000;
const RECONNECT_MAX_MS = 30_000;
export function usePartySocket({
apiUrl,
onMessage,
}: {
apiUrl: string | null;
onMessage: Handler | null;
}) {
const [connectionState, setConnectionState] = useState<
"disconnected" | "connecting" | "connected" | "reconnecting"
>("disconnected");
const wsRef = useRef<WebSocket | null>(null);
const pingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
const reconnectTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const reconnectAttemptRef = useRef(0);
const handlerRef = useRef(onMessage);
useEffect(() => {
handlerRef.current = onMessage;
}, [onMessage]);
const setupWs = useCallback(
(ws: WebSocket) => {
ws.onopen = () => {
reconnectAttemptRef.current = 0;
setConnectionState("connected");
pingTimerRef.current = setInterval(() => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: "ping" }));
}
}, PING_INTERVAL_MS);
};
ws.onmessage = (event) => {
const parsed = JSON.parse(event.data) as PartySocketEvent;
handlerRef.current?.(parsed);
};
ws.onclose = () => {
if (pingTimerRef.current) {
clearInterval(pingTimerRef.current);
pingTimerRef.current = null;
}
wsRef.current = null;
setConnectionState("reconnecting");
const delay = Math.min(
RECONNECT_BASE_MS * 2 ** reconnectAttemptRef.current,
RECONNECT_MAX_MS,
);
reconnectAttemptRef.current++;
reconnectTimerRef.current = setTimeout(() => {
if (!apiUrl) return;
const protocol = apiUrl.startsWith("https") ? "wss" : "ws";
const newWs = new WebSocket(
`${protocol}://${apiUrl.replace(/https?:\/\//, "")}/api/party-socket/ws`,
);
wsRef.current = newWs;
setupWs(newWs);
}, delay);
};
},
[apiUrl],
);
useEffect(() => {
if (!apiUrl) {
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
}
if (pingTimerRef.current) {
clearInterval(pingTimerRef.current);
pingTimerRef.current = null;
}
if (reconnectTimerRef.current) {
clearTimeout(reconnectTimerRef.current);
reconnectTimerRef.current = null;
}
setConnectionState("disconnected");
reconnectAttemptRef.current = 0;
return;
}
setConnectionState("connecting");
const protocol = apiUrl.startsWith("https") ? "wss" : "ws";
const ws = new WebSocket(
`${protocol}://${apiUrl.replace(/https?:\/\//, "")}/api/party-socket/ws`,
);
wsRef.current = ws;
setupWs(ws);
return () => {
ws.close();
wsRef.current = null;
if (pingTimerRef.current) {
clearInterval(pingTimerRef.current);
pingTimerRef.current = null;
}
if (reconnectTimerRef.current) {
clearTimeout(reconnectTimerRef.current);
reconnectTimerRef.current = null;
}
};
}, [apiUrl, setupWs]);
const state = useMemo(
() => ({
connectionState,
isConnected: connectionState === "connected",
isConnecting: connectionState === "connecting",
isReconnecting: connectionState === "reconnecting",
}),
[connectionState],
);
return state;
}

View file

@ -0,0 +1,74 @@
import { useCallback, useMemo, useState } from "react";
import type {
PartyMember,
PartySocketEvent,
PartyState,
} from "../../../api/src/party-types";
import { usePartySocket } from "./use-party-socket";
import { useUser } from "./user";
function reducePartyState(
state: PartyState,
event: PartySocketEvent,
): PartyState {
switch (event.type) {
case "snapshot":
return { party: event.party, members: event.members };
case "party_status":
return { party: event.party, members: event.members };
case "member_joined":
return state;
case "member_left":
return {
...state,
members: state.members.filter(
(m: PartyMember) => m.userId !== event.userId,
),
};
case "host_changed":
if (!state.party) return state;
return {
...state,
party: { ...state.party, hostId: event.hostId },
};
case "member_payload":
case "pong":
case "error":
return state;
}
}
function getApiUrl(): string | null {
if (typeof window === "undefined") return null;
const envUrl = import.meta.env.VITE_BETTER_AUTH_URL;
if (envUrl) return envUrl;
return `${window.location.protocol}//${window.location.host}`;
}
export function useParty() {
const { session } = useUser();
const [state, setState] = useState<PartyState>({
party: null,
members: [],
});
const handleMessage = useCallback((event: PartySocketEvent) => {
setState((prev: PartyState) => reducePartyState(prev, event));
}, []);
const apiUrl = useMemo(() => {
const url = getApiUrl();
if (!url) return null;
return url;
}, []);
const wsState = usePartySocket({
apiUrl,
onMessage: session ? handleMessage : null,
});
return {
...state,
...wsState,
};
}

View file

@ -20,7 +20,17 @@ const config = defineConfig({
],
server: {
proxy: {
"/api": "http://localhost:4000",
"/api": {
target: "http://localhost:4000",
changeOrigin: true,
rewrite: (path) =>
path.replace(/^\/api/, "/api"),
},
"/api/party-socket/ws": {
target: "ws://localhost:4000",
ws: true,
rewriteWsOrigin: true,
},
},
},
});