diff --git a/api/src/party/__tests__/question-generation.test.ts b/api/src/party/__tests__/question-generation.test.ts index b6dc36d..6f46b29 100644 --- a/api/src/party/__tests__/question-generation.test.ts +++ b/api/src/party/__tests__/question-generation.test.ts @@ -576,4 +576,62 @@ describe("question generation", () => { expect(song?.platform_id).toBe("spotify:track:one"); }); + + it("uses an adjacent song for generic metadata questions", async () => { + const db = createSongFallbackDb([ + makeSong("track-1", "spotify:track:one", "One"), + makeSong("track-2", "spotify:track:two", "Two"), + ]); + const question = { + type: "choice" as const, + text: "Which genre appears most in the party analytics?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:genre:pop", + subjectKey: "genre:pop", + song: makeSong("track-1", "spotify:track:one", "One"), + }; + + const song = await selectQuestionSong({ + db, + analytics: null, + members: [{ userId: "a", name: "A" }], + history: [], + question, + }); + + expect(song?.platform_id).toBe("spotify:track:two"); + }); + + it("keeps album questions on the referenced track", async () => { + const db = createSongFallbackDb([ + makeSong("track-1", "spotify:track:one", "One"), + makeSong("track-2", "spotify:track:two", "Two"), + ]); + const question = { + type: "choice" as const, + text: '"One" appears on which album?', + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:album:One Album", + subjectKey: "track:One", + song: makeSong("track-1", "spotify:track:one", "One"), + }; + + const song = await selectQuestionSong({ + db, + analytics: null, + members: [{ userId: "a", name: "A" }], + history: [], + question, + }); + + expect(song?.platform_id).toBe("spotify:track:one"); + }); }); diff --git a/api/src/party/question-utils.ts b/api/src/party/question-utils.ts index 78f8fee..e8f37c2 100644 --- a/api/src/party/question-utils.ts +++ b/api/src/party/question-utils.ts @@ -341,6 +341,7 @@ type SongSelectionInput = { type SongCandidate = { song: QuestionSong; fairness?: QuestionCandidateFairness; + source?: "question" | "subject"; }; export async function selectQuestionSong({ @@ -350,7 +351,7 @@ export async function selectQuestionSong({ history, question, }: SongSelectionInput): Promise { - const keepSpecificSong = isSongTargetQuestion(question); + const keepSpecificSong = shouldUseQuestionSubjectSong(question); const usedPlatformIds = new Set( history .map((round) => round.question.song?.platform_id) @@ -366,15 +367,42 @@ export async function selectQuestionSong({ }); if (candidates.length === 0) return question.song ?? null; - if (keepSpecificSong) return candidates[0]?.song ?? question.song ?? null; + if (keepSpecificSong) { + return ( + candidates.find((candidate) => candidate.source === "subject")?.song ?? + candidates.find((candidate) => candidate.source === "question")?.song ?? + candidates[0]?.song ?? + question.song ?? + null + ); + } - const freshCandidates = candidates.filter( + const exactPlatformIds = new Set( + candidates + .filter( + (candidate) => + candidate.source === "question" || candidate.source === "subject", + ) + .map((candidate) => candidate.song.platform_id) + .filter((value): value is string => isUsableText(value)), + ); + const adjacentCandidates = + exactPlatformIds.size > 0 + ? candidates.filter( + (candidate) => + !exactPlatformIds.has(candidate.song.platform_id ?? ""), + ) + : candidates; + + const freshCandidates = adjacentCandidates.filter( (candidate) => isUsableText(candidate.song.platform_id) && !usedPlatformIds.has(candidate.song.platform_id), ); const selected = - pickFairSongCandidate(freshCandidates) ?? pickFairSongCandidate(candidates); + pickFairSongCandidate(freshCandidates) ?? + pickFairSongCandidate(adjacentCandidates) ?? + pickFairSongCandidate(candidates); return selected?.song ?? question.song ?? null; } @@ -396,14 +424,15 @@ async function collectSongCandidates({ const push = ( song: QuestionSong | null | undefined, fairness?: QuestionCandidateFairness, + source?: SongCandidate["source"], ) => { if (!song || !isUsableText(song.platform_id)) return; if (seen.has(song.platform_id)) return; seen.add(song.platform_id); - candidates.push({ song, fairness }); + candidates.push({ song, fairness, source }); }; - push(question.song); + push(question.song, undefined, "question"); const subjectSong = await resolveSongFromQuestionSubject( db, @@ -413,6 +442,7 @@ async function collectSongCandidates({ push( subjectSong, getQuestionSubjectFairness(analytics, members, history, question), + "subject", ); const peopleSong = await resolveSongFromMentionedPeople( @@ -608,14 +638,16 @@ function getQuestionSubjectFairness( return undefined; } -function isSongTargetQuestion(question: Question): boolean { +function shouldUseQuestionSubjectSong(question: Question): boolean { const key = question.questionKey?.toLowerCase() ?? ""; - const text = question.text.toLowerCase(); return ( question.hideSongTitle === true || key.startsWith("audio:current-song:") || - text.includes("what song") || - text.includes("which song") + key.startsWith("audio:title:") || + key.startsWith("audio:album:") || + key.startsWith("audio:performer:") || + key.startsWith("numeric:album-year:") || + key.startsWith("numeric:track-year:") ); } diff --git a/api/src/routes/party-analysis.ts b/api/src/routes/party-analysis.ts index 7e775ff..9311f4b 100644 --- a/api/src/routes/party-analysis.ts +++ b/api/src/routes/party-analysis.ts @@ -26,10 +26,13 @@ export const partyAnalysisApp = new Elysia() return { error: "Only the host can trigger analysis." }; } - const result = await partyAnalysisWorkflow.analyzeParty( - membership.partyId, - ); - return result; + await partyAnalysisWorkflow.analyzeParty(membership.partyId); + const updatedParty = await db.query.party.findFirst({ + where: { + id: membership.partyId, + }, + }); + return updatedParty?.analysisData ?? null; }, { auth: true, diff --git a/api/src/workflows/__tests__/party-analysis.test.ts b/api/src/workflows/__tests__/party-analysis.test.ts index ad4bbab..295aabc 100644 --- a/api/src/workflows/__tests__/party-analysis.test.ts +++ b/api/src/workflows/__tests__/party-analysis.test.ts @@ -1,6 +1,7 @@ /** biome-ignore-all lint/style/noNonNullAssertion: test setup uses controlled arrays */ import { DBOS } from "@dbos-inc/dbos-sdk"; import { describe, expect, it } from "vitest"; +import { db } from "../../db"; import { addFollowedArtist, addPlaybackHistory, @@ -21,6 +22,46 @@ import "../../dbos"; await DBOS.launch(); +async function analyzeParty(partyId: string) { + await partyAnalysisWorkflow.analyzeParty(partyId); + const savedParty = await db.query.party.findFirst({ + where: { id: partyId }, + }); + return savedParty?.analysisData as { + storyClusters: Array<{ + memberIds: string[]; + memberCount: number; + tracks: Array<{ + id: string; + name: string; + memberScores: Array<{ userId: string; score: number }>; + memberCount: number; + }>; + artists: Array<{ id: string; name: string; memberCount: number }>; + genres: unknown[]; + }>; + pairwise: Array<{ + userIdA: string; + userIdB: string; + sharedTracks: number; + sharedArtists: number; + sharedGenres: number; + similarity: number; + }>; + groupSummary: { + totalMembers: number; + mostSharedGenres: unknown[]; + mostDiverseMember: { genreEntropy: number } | null; + mostAlignedPair: { userIdA: string; userIdB: string } | null; + }; + memberProfiles: Array<{ + userId: string; + trackCount: number; + artistCount: number; + }>; + }; +} + describe("PartyAnalysisWorkflow", () => { describe("analyzeParty - basic behavior", () => { it("returns empty results for party with fewer than 2 members", async () => { @@ -28,7 +69,7 @@ describe("PartyAnalysisWorkflow", () => { const party = await createParty(user.id); await joinParty(party.partyId, user.id); - const result = await partyAnalysisWorkflow.analyzeParty(party.partyId); + const result = await analyzeParty(party.partyId); expect(result.storyClusters).toHaveLength(0); expect(result.pairwise).toHaveLength(0); @@ -43,7 +84,7 @@ describe("PartyAnalysisWorkflow", () => { const party = await createParty(user.id); // Don't add any members - const result = await partyAnalysisWorkflow.analyzeParty(party.partyId); + const result = await analyzeParty(party.partyId); expect(result.storyClusters).toHaveLength(0); expect(result.groupSummary.totalMembers).toBe(0); @@ -54,7 +95,7 @@ describe("PartyAnalysisWorkflow", () => { const party = await createParty(user.id); await joinParty(party.partyId, user.id); - const result = await partyAnalysisWorkflow.analyzeParty(party.partyId); + const result = await analyzeParty(party.partyId); expect(result.storyClusters).toHaveLength(0); expect(result.groupSummary.totalMembers).toBe(1); @@ -67,7 +108,7 @@ describe("PartyAnalysisWorkflow", () => { const { partyId, userIdA, userIdB, sharedTrackId, sharedArtistId } = await seedPartyWithTwoSimilarUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); expect(result.storyClusters).toHaveLength(1); const cluster = result.storyClusters[0]!; @@ -94,8 +135,9 @@ describe("PartyAnalysisWorkflow", () => { // Should have exactly 1 pairwise comparison expect(result.pairwise).toHaveLength(1); const comparison = result.pairwise[0]!; - expect(comparison.userIdA).toBe(userIdA); - expect(comparison.userIdB).toBe(userIdB); + expect([comparison.userIdA, comparison.userIdB].sort()).toEqual( + [userIdA, userIdB].sort(), + ); expect(comparison.sharedTracks).toBeGreaterThan(0); expect(comparison.sharedArtists).toBeGreaterThan(0); expect(comparison.similarity).toBeGreaterThan(0); @@ -116,14 +158,20 @@ describe("PartyAnalysisWorkflow", () => { }); it("correctly identifies group summary", async () => { - const { partyId, userIdA } = await seedPartyWithTwoSimilarUsers(); + const { partyId, userIdA, userIdB } = + await seedPartyWithTwoSimilarUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); expect(result.groupSummary.totalMembers).toBe(2); expect(result.groupSummary.mostAlignedPair).toBeDefined(); if (result.groupSummary.mostAlignedPair) { - expect(result.groupSummary.mostAlignedPair.userIdA).toBe(userIdA); + expect( + [ + result.groupSummary.mostAlignedPair.userIdA, + result.groupSummary.mostAlignedPair.userIdB, + ].sort(), + ).toEqual([userIdA, userIdB].sort()); } expect(result.groupSummary.mostSharedGenres).toHaveLength(1); }); @@ -133,7 +181,7 @@ describe("PartyAnalysisWorkflow", () => { it("does not find shared tracks across all members", async () => { const { partyId } = await seedPartyWithThreeDiverseUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); expect(result.storyClusters).toHaveLength(3); expect(result.pairwise).toHaveLength(3); // C(3,2) = 3 pairs @@ -148,7 +196,7 @@ describe("PartyAnalysisWorkflow", () => { it("identifies pairwise comparisons for all member pairs", async () => { const { partyId } = await seedPartyWithThreeDiverseUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); expect(result.pairwise).toHaveLength(3); result.pairwise.forEach((comparison) => { @@ -160,7 +208,7 @@ describe("PartyAnalysisWorkflow", () => { it("correctly identifies genre diversity for each member", async () => { const { partyId } = await seedPartyWithThreeDiverseUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); expect(result.memberProfiles).toHaveLength(3); expect(result.groupSummary.mostDiverseMember).toBeDefined(); @@ -190,7 +238,7 @@ describe("PartyAnalysisWorkflow", () => { await addPlaybackHistory(userIdA, trackC, oldDate); await addPlaybackHistory(userIdB, trackC, oldDate); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const comparison = result.pairwise[0]!; expect(comparison.sharedTracks).toBeGreaterThan(1); // sharedTrack + trackC @@ -212,7 +260,7 @@ describe("PartyAnalysisWorkflow", () => { ]); await addSavedTrack(userIdA, extraTrack.id); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const profileA = result.memberProfiles.find((p) => p.userId === userIdA); expect(profileA).toBeDefined(); @@ -228,7 +276,7 @@ describe("PartyAnalysisWorkflow", () => { await addTopArtist(userIdA, followedArtist.id, 5); await addFollowedArtist(userIdA, followedArtist.id); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const profileA = result.memberProfiles.find((p) => p.userId === userIdA); expect(profileA).toBeDefined(); @@ -251,7 +299,7 @@ describe("PartyAnalysisWorkflow", () => { await addTopTrack(userIdB, uniqueTrack.id, 2); await addTopArtist(userIdB, uniqueArtist.id, 2); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const allTracks = result.storyClusters.flatMap( (cluster) => cluster.tracks, ); @@ -276,7 +324,7 @@ describe("PartyAnalysisWorkflow", () => { it("sorts clusters with all-member cluster first", async () => { const { partyId, sharedTrackId } = await seedPartyWithTwoSimilarUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); // The cluster with both members should be first expect(result.storyClusters[0]?.memberCount).toBe(2); @@ -302,7 +350,7 @@ describe("PartyAnalysisWorkflow", () => { await addTopTrack(userIdA, extraTrack.id, 50); await addTopTrack(userIdB, extraTrack.id, 50); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const cluster = result.storyClusters[0]!; expect(cluster.tracks.length).toBeGreaterThan(1); @@ -321,7 +369,7 @@ describe("PartyAnalysisWorkflow", () => { it("calculates Jaccard-like similarity using min/max scoring", async () => { const { partyId } = await seedPartyWithTwoSimilarUsers(); - const result = await partyAnalysisWorkflow.analyzeParty(partyId); + const result = await analyzeParty(partyId); const comparison = result.pairwise[0]!; expect(comparison.sharedTracks).toBeGreaterThanOrEqual(1); @@ -338,8 +386,6 @@ describe("PartyAnalysisWorkflow", () => { await partyAnalysisWorkflow.analyzeParty(partyId); - const { db } = await import("../../db"); - const savedParty = await db.query.party.findFirst({ where: { id: partyId }, }); diff --git a/api/src/workflows/party-analysis.ts b/api/src/workflows/party-analysis.ts index 620d272..0153f5a 100644 --- a/api/src/workflows/party-analysis.ts +++ b/api/src/workflows/party-analysis.ts @@ -91,12 +91,30 @@ type PartyAnalysisResult = { memberProfiles: MemberProfile[]; }; +type PartyAnalysisWorkflowResult = { + partyId: string; + totalMembers: number; + analyzed: boolean; +}; + +const MAX_STORY_CLUSTERS = 8; +const MAX_CLUSTER_ENTITIES = 20; +const MAX_PAIRWISE_COMPARISONS = 20; +const MAX_PROFILE_GENRES = 20; + export class PartyAnalysisWorkflow extends ConfiguredInstance { @DBOS.workflow() - async analyzeParty(partyId: string): Promise { + async analyzeParty(partyId: string): Promise { + return this.analyzeAndSaveParty(partyId); + } + + @DBOS.step() + private async analyzeAndSaveParty( + partyId: string, + ): Promise { const members = await this.fetchPartyMembers(partyId); if (members.length < 2) { - return { + await this.saveAnalysis(partyId, { storyClusters: [], pairwise: [], groupSummary: { @@ -106,7 +124,8 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { mostAlignedPair: null, }, memberProfiles: [], - }; + }); + return { partyId, totalMembers: members.length, analyzed: false }; } const memberInfos = members.map((m) => ({ @@ -160,22 +179,17 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { mostAlignedPair, }; - await this.saveAnalysis(partyId, { + const analysis = this.compactAnalysis({ storyClusters, pairwise, groupSummary, memberProfiles, }); + await this.saveAnalysis(partyId, analysis); - return { - storyClusters, - pairwise, - groupSummary, - memberProfiles, - }; + return { partyId, totalMembers: members.length, analyzed: true }; } - @DBOS.step() private async fetchPartyMembers(partyId: string): Promise { const result = await db .select() @@ -185,7 +199,6 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { return result as PartyMemberRow[]; } - @DBOS.step() private async fetchAllMemberData( members: { userId: string }[], ): Promise> { @@ -199,7 +212,6 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { return result; } - @DBOS.step() private async fetchMemberScores(userId: string): Promise { const scores: MemberScores = { tracks: new Map(), @@ -879,7 +891,6 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { return genres.filter((g) => g.memberCount >= 2).slice(0, 10); } - @DBOS.step() private async saveAnalysis( partyId: string, analysis: PartyAnalysisResult, @@ -892,6 +903,29 @@ export class PartyAnalysisWorkflow extends ConfiguredInstance { }) .where(sql`${party.id} = ${partyId}`); } + + private compactAnalysis(analysis: PartyAnalysisResult): PartyAnalysisResult { + return { + storyClusters: analysis.storyClusters + .slice(0, MAX_STORY_CLUSTERS) + .map((cluster) => ({ + ...cluster, + tracks: cluster.tracks.slice(0, MAX_CLUSTER_ENTITIES), + artists: cluster.artists.slice(0, MAX_CLUSTER_ENTITIES), + genres: cluster.genres.slice(0, MAX_CLUSTER_ENTITIES), + })), + pairwise: analysis.pairwise.slice(0, MAX_PAIRWISE_COMPARISONS), + groupSummary: analysis.groupSummary, + memberProfiles: analysis.memberProfiles.map((profile) => ({ + ...profile, + genreScores: Object.fromEntries( + Object.entries(profile.genreScores) + .sort(([, left], [, right]) => right - left) + .slice(0, MAX_PROFILE_GENRES), + ), + })), + }; + } } interface MemberScores { diff --git a/api/src/workflows/quiz.ts b/api/src/workflows/quiz.ts index 479a3ab..23a81fc 100644 --- a/api/src/workflows/quiz.ts +++ b/api/src/workflows/quiz.ts @@ -1,7 +1,7 @@ import { ConfiguredInstance, DBOS, WorkflowQueue } from "@dbos-inc/dbos-sdk"; import { eq } from "drizzle-orm"; import { db } from "../db"; -import { partyMember } from "../db/schema"; +import { party, partyMember } from "../db/schema"; import { generatePartyQuestion } from "../party/question-generator"; import type { PartyAnalytics } from "../party/question-utils"; import { updatePartyData } from "../party/state"; @@ -174,11 +174,12 @@ export class QuizWorkflow extends ConfiguredInstance { quizState: QuizState, index: number, ): Promise { - const partyRecord = await db.query.party.findFirst({ - where: { - id: partyId, - }, - }); + const partyRecord = await db + .select({ analysisData: party.analysisData }) + .from(party) + .where(eq(party.id, partyId)) + .limit(1) + .then((rows) => rows[0]); const analytics = (partyRecord?.analysisData ?? null) as PartyAnalytics; const question = await generatePartyQuestion({ db,