From 9072ce76babe3fb2cf6263cf181caa266403e691 Mon Sep 17 00:00:00 2001 From: Daniel Bulant Date: Wed, 29 Apr 2026 19:40:31 +0200 Subject: [PATCH] initial socket impl --- api/src/party-data.ts | 5 +- api/src/party-sockets.ts | 5 +- api/src/party-types.ts | 35 ++++++ api/src/routes/party-socket.ts | 180 +++++++++++++----------------- api/src/routes/party.ts | 11 +- web/src/components/user-info.tsx | 56 ++++++---- web/src/hooks/use-party-socket.ts | 127 +++++++++++++++++++++ web/src/hooks/use-party.ts | 74 ++++++++++++ web/vite.config.ts | 12 +- 9 files changed, 370 insertions(+), 135 deletions(-) create mode 100644 api/src/party-types.ts create mode 100644 web/src/hooks/use-party-socket.ts create mode 100644 web/src/hooks/use-party.ts diff --git a/api/src/party-data.ts b/api/src/party-data.ts index 1ea0dd7..eef5210 100644 --- a/api/src/party-data.ts +++ b/api/src/party-data.ts @@ -1,6 +1,7 @@ import { eq } from "drizzle-orm"; import { db } from "./db"; import { party, partyMember } from "./db/schema"; +import type { PartySnapshot } from "./party-types"; type DbClient = typeof db; type DbTransaction = Parameters[0] extends ( @@ -33,7 +34,9 @@ export async function getMemberRecord(dbClient: DbLike, userId: string) { ); } -export async function getPartyStatus(partyId: string) { +export async function getPartyStatus( + partyId: string, +): Promise { const partyRecord = await db.query.party.findFirst({ where: { id: partyId, diff --git a/api/src/party-sockets.ts b/api/src/party-sockets.ts index fcd4252..796b7b2 100644 --- a/api/src/party-sockets.ts +++ b/api/src/party-sockets.ts @@ -1,7 +1,4 @@ -type PartySocketEvent = { - type: string; - [key: string]: unknown; -}; +import type { PartySocketEvent } from "./party-types"; type WebSocketLike = { send: (data: string) => void; diff --git a/api/src/party-types.ts b/api/src/party-types.ts new file mode 100644 index 0000000..9ff318d --- /dev/null +++ b/api/src/party-types.ts @@ -0,0 +1,35 @@ +import type { InferSelectModel } from "drizzle-orm"; +import type { party, partyMember, user } from "./db/schema"; + +export type Party = InferSelectModel; +export type PartyMember = InferSelectModel; +export type User = InferSelectModel; + +export type PartyMemberWithUser = PartyMember & { user: User | null }; + +export const PARTY_STATUS = ["created", "started", "ended"] as const; +export type PartyStatus = (typeof PARTY_STATUS)[number]; + +export type PartySnapshot = { + party: Party; + members: PartyMemberWithUser[]; +}; + +export type PartyState = { + party: Party | null; + members: PartyMemberWithUser[]; +}; + +export type PartySocketOutgoing = + | { type: "ping" } + | { type: "member_payload"; payload: unknown }; + +export type PartySocketEvent = + | { type: "snapshot"; party: Party | null; members: PartyMemberWithUser[] } + | { type: "party_status"; party: Party; members: PartyMemberWithUser[] } + | { type: "member_joined"; userId: string } + | { type: "member_left"; userId: string; kickedBy?: string } + | { type: "host_changed"; hostId: string } + | { type: "member_payload"; fromUserId: string; payload: unknown } + | { type: "error"; message: string } + | { type: "pong" }; diff --git a/api/src/routes/party-socket.ts b/api/src/routes/party-socket.ts index a0c36e9..f1040d9 100644 --- a/api/src/routes/party-socket.ts +++ b/api/src/routes/party-socket.ts @@ -10,20 +10,11 @@ import { unregisterUserSocket, unregisterUserSocketFromAllParties, } from "../party-sockets"; - -type PartySocketMessage = - | { - type: "member_payload"; - payload: unknown; - } - | { - type: "ping"; - }; +import type { PartySocketOutgoing } from "../party-types"; const MAX_MEMBER_PAYLOAD_SIZE = 8_000; type PartyWsData = { - user?: { id: string }; partyId?: string | null; }; @@ -38,100 +29,89 @@ function getPayloadSize(payload: unknown) { export const partySocketApp = new Elysia() .use(betterAuthElysia) .group("/party-socket", (app) => - app.ws("/ws", { - beforeHandle: async ({ request, set }) => { - const session = await auth.api.getSession({ - headers: request.headers, - }); - if (!session) { - set.status = 401; - return; - } - return { - user: session.user, - session: session.session, - }; - }, - open: async (ws) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - if (!user) return; - registerUserSocket(user.id, ws); - const membership = await getMemberRecord(db, user.id); - if (!membership) { - ws.send( - JSON.stringify({ - type: "snapshot", - party: null, - members: [], - }), - ); - return; - } + app + .get("/test", () => ({ ok: 1 })) + .ws("/ws", { + auth: true, + open: async (ws) => { + const user = ws.data.user; + if (!user) return; + registerUserSocket(user.id, ws); + const membership = await getMemberRecord(db, user.id); + if (!membership) { + ws.send( + JSON.stringify({ + type: "snapshot", + party: null, + members: [], + }), + ); + return; + } - const snapshot = await getPartyStatus(membership.partyId); - data.partyId = membership.partyId; - registerPartySocket(membership.partyId, user.id, ws); - if (snapshot) { - ws.send( - JSON.stringify({ - type: "snapshot", - party: snapshot.party, - members: snapshot.members, - }), - ); - } - }, - message: async (ws, message: PartySocketMessage) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - if (!user) return; - if (message.type === "ping") { - ws.send(JSON.stringify({ type: "pong" })); - return; - } + const snapshot = await getPartyStatus(membership.partyId); + ws.data.partyId = membership.partyId; + registerPartySocket(membership.partyId, user.id, ws); + if (snapshot) { + ws.send( + JSON.stringify({ + type: "snapshot", + party: snapshot.party, + members: snapshot.members, + }), + ); + } + }, + message: async (ws, message: PartySocketOutgoing) => { + const data = ws.data; + const user = data.user; + if (!user) return; + if (message.type === "ping") { + ws.send(JSON.stringify({ type: "pong" })); + return; + } - if (message.type !== "member_payload") return; - const membership = await getMemberRecord(db, user.id); - if (!membership) return; + if (message.type !== "member_payload") return; + const membership = await getMemberRecord(db, user.id); + if (!membership) return; - if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) { - ws.send( - JSON.stringify({ - type: "error", - message: "Payload too large.", - }), - ); - return; - } + if (getPayloadSize(message.payload) > MAX_MEMBER_PAYLOAD_SIZE) { + ws.send( + JSON.stringify({ + type: "error", + message: "Payload too large.", + }), + ); + return; + } - const currentParty = await db.query.party.findFirst({ - where: { id: membership.partyId }, - }); - if (!currentParty) return; + const currentParty = await db.query.party.findFirst({ + where: { id: membership.partyId }, + }); + if (!currentParty) return; - sendPartyEventToUser(membership.partyId, currentParty.hostId, { - type: "member_payload", - fromUserId: user.id, - payload: message.payload, - }); - }, - close: async (ws) => { - const data = ws.data as unknown as PartyWsData; - const user = data.user; - const { partyId } = data; - if (!user) return; - if (!partyId) { - unregisterUserSocketFromAllParties(user.id, ws); + sendPartyEventToUser(membership.partyId, currentParty.hostId, { + type: "member_payload", + fromUserId: user.id, + payload: message.payload, + }); + }, + close: async (ws) => { + const data = ws.data; + const user = data.user; + const { partyId } = data; + if (!user) return; + if (!partyId) { + unregisterUserSocketFromAllParties(user.id, ws); + unregisterUserSocket(user.id, ws); + return; + } + unregisterPartySocket(partyId, user.id, ws); unregisterUserSocket(user.id, ws); - return; - } - unregisterPartySocket(partyId, user.id, ws); - unregisterUserSocket(user.id, ws); - }, - body: t.Union([ - t.Object({ type: t.Literal("ping") }), - t.Object({ type: t.Literal("member_payload"), payload: t.Any() }), - ]), - }), + }, + body: t.Union([ + t.Object({ type: t.Literal("ping") }), + t.Object({ type: t.Literal("member_payload"), payload: t.Any() }), + ]), + }), ); diff --git a/api/src/routes/party.ts b/api/src/routes/party.ts index 2477a53..350d49f 100644 --- a/api/src/routes/party.ts +++ b/api/src/routes/party.ts @@ -15,12 +15,11 @@ import { reassignUserSocketsToParty, sendDirectEventToUser, } from "../party-sockets"; - -const PARTY_STATUS = ["created", "started", "ended"] as const; - -type PartyStatus = (typeof PARTY_STATUS)[number]; - -type PartySnapshot = NonNullable>>; +import { + PARTY_STATUS, + type PartySnapshot, + type PartyStatus, +} from "../party-types"; function broadcastSnapshot(partyId: string, snapshot: PartySnapshot | null) { if (!snapshot) return; diff --git a/web/src/components/user-info.tsx b/web/src/components/user-info.tsx index c9f7591..9c4c2dd 100644 --- a/web/src/components/user-info.tsx +++ b/web/src/components/user-info.tsx @@ -1,29 +1,39 @@ import { useRouteContext } from "@tanstack/react-router"; -import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"; -import { - Item, - ItemContent, - ItemDescription, - ItemMedia, - ItemTitle, -} from "./ui/item"; +import { useParty } from "#/hooks/use-party"; import { useUser } from "#/hooks/user"; import { initials } from "#/lib/utils"; +import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"; +import { + Item, + ItemContent, + ItemDescription, + ItemMedia, + ItemTitle, +} from "./ui/item"; export function UserInfo() { - const { user } = useUser(); - return ( - - - - - {initials(user?.name || "")} - - - - {user?.name} - No party yet - - - ); + const { user } = useUser(); + const { party, members, isConnecting, isReconnecting } = useParty(); + return ( + + + + + {initials(user?.name || "")} + + + + {user?.name} + + {isConnecting + ? "Connecting..." + : isReconnecting + ? "Reconnecting..." + : party + ? `${members.length} in party` + : "No party yet"} + + + + ); } diff --git a/web/src/hooks/use-party-socket.ts b/web/src/hooks/use-party-socket.ts new file mode 100644 index 0000000..2b901b9 --- /dev/null +++ b/web/src/hooks/use-party-socket.ts @@ -0,0 +1,127 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import type { PartySocketEvent } from "../../../api/src/party-types"; + +type Handler = (event: PartySocketEvent) => void; + +const PING_INTERVAL_MS = 30_000; +const RECONNECT_BASE_MS = 1_000; +const RECONNECT_MAX_MS = 30_000; + +export function usePartySocket({ + apiUrl, + onMessage, +}: { + apiUrl: string | null; + onMessage: Handler | null; +}) { + const [connectionState, setConnectionState] = useState< + "disconnected" | "connecting" | "connected" | "reconnecting" + >("disconnected"); + + const wsRef = useRef(null); + const pingTimerRef = useRef | null>(null); + const reconnectTimerRef = useRef | null>(null); + const reconnectAttemptRef = useRef(0); + const handlerRef = useRef(onMessage); + + useEffect(() => { + handlerRef.current = onMessage; + }, [onMessage]); + + const setupWs = useCallback( + (ws: WebSocket) => { + ws.onopen = () => { + reconnectAttemptRef.current = 0; + setConnectionState("connected"); + pingTimerRef.current = setInterval(() => { + if (ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: "ping" })); + } + }, PING_INTERVAL_MS); + }; + + ws.onmessage = (event) => { + const parsed = JSON.parse(event.data) as PartySocketEvent; + handlerRef.current?.(parsed); + }; + + ws.onclose = () => { + if (pingTimerRef.current) { + clearInterval(pingTimerRef.current); + pingTimerRef.current = null; + } + wsRef.current = null; + setConnectionState("reconnecting"); + + const delay = Math.min( + RECONNECT_BASE_MS * 2 ** reconnectAttemptRef.current, + RECONNECT_MAX_MS, + ); + reconnectAttemptRef.current++; + reconnectTimerRef.current = setTimeout(() => { + if (!apiUrl) return; + const protocol = apiUrl.startsWith("https") ? "wss" : "ws"; + const newWs = new WebSocket( + `${protocol}://${apiUrl.replace(/https?:\/\//, "")}/api/party-socket/ws`, + ); + wsRef.current = newWs; + setupWs(newWs); + }, delay); + }; + }, + [apiUrl], + ); + + useEffect(() => { + if (!apiUrl) { + if (wsRef.current) { + wsRef.current.close(); + wsRef.current = null; + } + if (pingTimerRef.current) { + clearInterval(pingTimerRef.current); + pingTimerRef.current = null; + } + if (reconnectTimerRef.current) { + clearTimeout(reconnectTimerRef.current); + reconnectTimerRef.current = null; + } + setConnectionState("disconnected"); + reconnectAttemptRef.current = 0; + return; + } + + setConnectionState("connecting"); + const protocol = apiUrl.startsWith("https") ? "wss" : "ws"; + const ws = new WebSocket( + `${protocol}://${apiUrl.replace(/https?:\/\//, "")}/api/party-socket/ws`, + ); + wsRef.current = ws; + setupWs(ws); + + return () => { + ws.close(); + wsRef.current = null; + if (pingTimerRef.current) { + clearInterval(pingTimerRef.current); + pingTimerRef.current = null; + } + if (reconnectTimerRef.current) { + clearTimeout(reconnectTimerRef.current); + reconnectTimerRef.current = null; + } + }; + }, [apiUrl, setupWs]); + + const state = useMemo( + () => ({ + connectionState, + isConnected: connectionState === "connected", + isConnecting: connectionState === "connecting", + isReconnecting: connectionState === "reconnecting", + }), + [connectionState], + ); + + return state; +} diff --git a/web/src/hooks/use-party.ts b/web/src/hooks/use-party.ts new file mode 100644 index 0000000..6f3e232 --- /dev/null +++ b/web/src/hooks/use-party.ts @@ -0,0 +1,74 @@ +import { useCallback, useMemo, useState } from "react"; +import type { + PartyMember, + PartySocketEvent, + PartyState, +} from "../../../api/src/party-types"; +import { usePartySocket } from "./use-party-socket"; +import { useUser } from "./user"; + +function reducePartyState( + state: PartyState, + event: PartySocketEvent, +): PartyState { + switch (event.type) { + case "snapshot": + return { party: event.party, members: event.members }; + case "party_status": + return { party: event.party, members: event.members }; + case "member_joined": + return state; + case "member_left": + return { + ...state, + members: state.members.filter( + (m: PartyMember) => m.userId !== event.userId, + ), + }; + case "host_changed": + if (!state.party) return state; + return { + ...state, + party: { ...state.party, hostId: event.hostId }, + }; + case "member_payload": + case "pong": + case "error": + return state; + } +} + +function getApiUrl(): string | null { + if (typeof window === "undefined") return null; + const envUrl = import.meta.env.VITE_BETTER_AUTH_URL; + if (envUrl) return envUrl; + return `${window.location.protocol}//${window.location.host}`; +} + +export function useParty() { + const { session } = useUser(); + const [state, setState] = useState({ + party: null, + members: [], + }); + + const handleMessage = useCallback((event: PartySocketEvent) => { + setState((prev: PartyState) => reducePartyState(prev, event)); + }, []); + + const apiUrl = useMemo(() => { + const url = getApiUrl(); + if (!url) return null; + return url; + }, []); + + const wsState = usePartySocket({ + apiUrl, + onMessage: session ? handleMessage : null, + }); + + return { + ...state, + ...wsState, + }; +} diff --git a/web/vite.config.ts b/web/vite.config.ts index d6578f7..186b21c 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -20,7 +20,17 @@ const config = defineConfig({ ], server: { proxy: { - "/api": "http://localhost:4000", + "/api": { + target: "http://localhost:4000", + changeOrigin: true, + rewrite: (path) => + path.replace(/^\/api/, "/api"), + }, + "/api/party-socket/ws": { + target: "ws://localhost:4000", + ws: true, + rewriteWsOrigin: true, + }, }, }, });