diff --git a/api/src/party/__tests__/question-generation.test.ts b/api/src/party/__tests__/question-generation.test.ts index 3048bce..24713f7 100644 --- a/api/src/party/__tests__/question-generation.test.ts +++ b/api/src/party/__tests__/question-generation.test.ts @@ -6,10 +6,12 @@ import { type PartyAnalytics, type PartyQuestionMember, pickQuestionCandidate, + selectQuestionSong, } from "../question-utils"; import { buildSocialQuestion } from "../social-question-generator"; type Db = typeof import("../../db").db; +type Song = NonNullable; function makeChoiceQuestion( text: string, @@ -61,6 +63,39 @@ function createFakeDb(trackReleaseDate: Date | null) { } as unknown as Db; } +function makeSong(id: string, platformId: string, name: string): Song { + return { + id, + albumId: "album-1", + platform: "spotify", + platform_id: platformId, + name, + popularity: 1, + duration: 1, + explicit: false, + disc_number: 1, + track_number: 1, + }; +} + +function createSongFallbackDb(rows: Song[]) { + return { + query: { + topTrack: { + findMany: vi.fn(async () => + rows.map((row, index) => ({ + position: index + 1, + track: row, + })), + ), + }, + track: { + findMany: vi.fn(async () => []), + }, + }, + } as unknown as Db; +} + describe("question generation", () => { it("skips repeated question keys, subjects, and text", () => { const history: QuizRound[] = [ @@ -230,4 +265,79 @@ describe("question generation", () => { expect(question).toBeNull(); }); + + it("selects a fresh party song when the current one was already used", async () => { + const db = createSongFallbackDb([ + makeSong("track-1", "spotify:track:one", "One"), + makeSong("track-2", "spotify:track:two", "Two"), + ]); + const question = { + type: "choice" as const, + text: "Which genre appears most in the party analytics?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:genre:pop", + subjectKey: "genre:pop", + }; + + const song = await selectQuestionSong({ + db, + analytics: null, + members: [{ userId: "a", name: "A" }], + history: [ + { + questionIndex: 0, + question: { + ...question, + song: makeSong("track-1", "spotify:track:one", "One"), + }, + responses: [], + }, + ], + question, + }); + + expect(song?.platform_id).toBe("spotify:track:two"); + }); + + it("keeps a song-target question on the same track", async () => { + const db = createSongFallbackDb([ + makeSong("track-1", "spotify:track:one", "One"), + makeSong("track-2", "spotify:track:two", "Two"), + ]); + const question = { + type: "choice" as const, + text: "What song is currently playing?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:current-song:One", + subjectKey: "track:One", + hideSongTitle: true, + song: { + ...makeSong("track-1", "spotify:track:one", "One"), + }, + }; + + const song = await selectQuestionSong({ + db, + analytics: null, + members: [{ userId: "a", name: "A" }], + history: [ + { + questionIndex: 0, + question, + responses: [], + }, + ], + question, + }); + + expect(song?.platform_id).toBe("spotify:track:one"); + }); }); diff --git a/api/src/party/__tests__/question-generator.test.ts b/api/src/party/__tests__/question-generator.test.ts index 5aef0f7..f8fb12f 100644 --- a/api/src/party/__tests__/question-generator.test.ts +++ b/api/src/party/__tests__/question-generator.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it, vi } from "vitest"; import type { QuizState } from "../../party-types"; +import * as audioQuestionGenerator from "../audio-question-generator"; vi.mock("../audio-question-generator", () => ({ buildAudioMetadataQuestion: vi.fn(async () => null), @@ -47,4 +48,59 @@ describe("generatePartyQuestion", () => { expect(question).toBeNull(); }); + + it("attaches a fallback song to generated questions", async () => { + vi.mocked( + audioQuestionGenerator.buildAudioMetadataQuestion, + ).mockResolvedValueOnce({ + type: "choice", + text: "Which genre appears most in the party analytics?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:genre:pop", + subjectKey: "genre:pop", + }); + + const quizState = { + status: "running", + workflowId: null, + questionIndex: 0, + currentQuestion: null, + answers: {}, + scores: {}, + history: [], + } as QuizState; + + const question = await generatePartyQuestion({ + db: { + query: { + partyMember: { + findMany: vi.fn(async () => [{ userId: "a", user: { name: "A" } }]), + }, + topTrack: { + findMany: vi.fn(async () => [ + { + position: 1, + track: { + id: "track-1", + platform: "spotify", + platform_id: "spotify:track:one", + name: "One", + }, + }, + ]), + }, + }, + } as never, + partyId: "party-1", + quizState, + analytics: null, + index: 0, + }); + + expect(question?.song?.platform_id).toBe("spotify:track:one"); + }); }); diff --git a/api/src/party/question-generator.ts b/api/src/party/question-generator.ts index 1a44d42..f6eec91 100644 --- a/api/src/party/question-generator.ts +++ b/api/src/party/question-generator.ts @@ -2,8 +2,11 @@ import type { db } from "../db"; import type { Question, QuizState } from "../party-types"; import { buildAudioMetadataQuestion } from "./audio-question-generator"; import { buildNumericQuestion } from "./numeric-question-generator"; -import type { PartyAnalytics } from "./question-utils"; -import { fetchPartyMembers } from "./question-utils"; +import { + fetchPartyMembers, + type PartyAnalytics, + selectQuestionSong, +} from "./question-utils"; import { buildSocialQuestion } from "./social-question-generator"; export type PartyQuestionType = "audio-metadata" | "social" | "numeric"; @@ -36,37 +39,44 @@ export async function generatePartyQuestion({ ]; for (const type of typeOrder) { + let question: Question | null = null; if (type === "audio-metadata") { - const q = await buildAudioMetadataQuestion( + question = await buildAudioMetadataQuestion( dbClient, analytics, index, quizState.history, ); - if (q) return q; - continue; - } - - if (type === "social") { - const q = await buildSocialQuestion( + } else if (type === "social") { + question = await buildSocialQuestion( dbClient, quizState, analytics, members, index, ); - if (q) return q; + } else { + question = await buildNumericQuestion({ + db: dbClient, + analytics, + index, + members, + history: quizState.history, + }); + } + + if (!question) { continue; } - const q = await buildNumericQuestion({ + const song = await selectQuestionSong({ db: dbClient, analytics, - index, members, history: quizState.history, + question, }); - if (q) return q; + return song ? { ...question, song } : question; } return null; diff --git a/api/src/party/question-utils.ts b/api/src/party/question-utils.ts index af51e04..f0cbdb3 100644 --- a/api/src/party/question-utils.ts +++ b/api/src/party/question-utils.ts @@ -1,7 +1,7 @@ import type { InferSelectModel } from "drizzle-orm"; import type { db as Db } from "../db"; import type { track as trackTable } from "../db/schema"; -import type { QuizRound } from "../party-types"; +import type { Question, QuizRound } from "../party-types"; export type PartyQuestionMember = { userId: string; @@ -262,6 +262,228 @@ export async function resolveQuestionSong( return song; } +type SongSelectionInput = { + db: typeof Db; + analytics: PartyAnalytics; + members: PartyQuestionMember[]; + history: QuizRound[]; + question: Question; +}; + +export async function selectQuestionSong({ + db, + analytics, + members, + history, + question, +}: SongSelectionInput): Promise { + const keepSpecificSong = isSongTargetQuestion(question); + const usedPlatformIds = new Set( + history + .map((round) => round.question.song?.platform_id) + .filter((value): value is string => isUsableText(value)), + ); + + const candidates = await collectSongCandidates({ + db, + analytics, + members, + question, + }); + + if (candidates.length === 0) return question.song ?? null; + if (keepSpecificSong) return candidates[0] ?? question.song ?? null; + + const freshCandidate = candidates.find( + (candidate) => + isUsableText(candidate.platform_id) && + !usedPlatformIds.has(candidate.platform_id), + ); + return freshCandidate ?? candidates[0] ?? question.song ?? null; +} + +async function collectSongCandidates({ + db, + analytics, + members, + question, +}: { + db: typeof Db; + analytics: PartyAnalytics; + members: PartyQuestionMember[]; + question: Question; +}): Promise { + const candidates: QuestionSong[] = []; + const seen = new Set(); + const push = (song: QuestionSong | null | undefined) => { + if (!song || !isUsableText(song.platform_id)) return; + if (seen.has(song.platform_id)) return; + seen.add(song.platform_id); + candidates.push(song); + }; + + push(question.song); + + const subjectSong = await resolveSongFromQuestionSubject( + db, + analytics, + question, + ); + push(subjectSong); + + const peopleSong = await resolveSongFromMentionedPeople( + db, + analytics, + question, + ); + push(peopleSong); + + const topClusterTracks = [...(analytics?.storyClusters?.[0]?.tracks ?? [])] + .filter((track) => isUsableText(track.name)) + .sort((a, b) => getTrackScore(b) - getTrackScore(a)); + + for (const track of topClusterTracks) { + const song = await resolveQuestionSong(db, analytics, { + trackName: track.name, + artistNames: track.artists?.map((artist) => artist.name), + albumName: track.albumName, + }); + push(song); + } + + if (members.length > 0) { + const topPartySongs = await fetchPartyTopSongs(db, members); + for (const song of topPartySongs) { + push(song); + } + } + + return candidates; +} + +async function resolveSongFromQuestionSubject( + db: typeof Db, + analytics: PartyAnalytics, + question: Question, +): Promise { + const subjectKey = question.subjectKey ?? ""; + if (subjectKey.startsWith("track:")) { + const trackName = subjectKey.slice("track:".length).trim(); + if (!trackName) return null; + return resolveQuestionSong(db, analytics, { trackName }); + } + + if (subjectKey.startsWith("artist:")) { + const artistName = subjectKey.slice("artist:".length).trim(); + if (!artistName) return null; + return resolveQuestionSong(db, analytics, { artistNames: [artistName] }); + } + + return null; +} + +async function resolveSongFromMentionedPeople( + db: typeof Db, + analytics: PartyAnalytics, + question: Question, +): Promise { + const subjectKey = question.subjectKey ?? ""; + const userIds = subjectKey.startsWith("member:") + ? [subjectKey.slice("member:".length).trim()].filter(Boolean) + : subjectKey.startsWith("pair:") + ? subjectKey + .slice("pair:".length) + .split("|") + .map((value) => value.trim()) + .filter(Boolean) + : []; + + if (userIds.length === 0) return null; + + const tracks = [...(analytics?.storyClusters?.[0]?.tracks ?? [])] + .filter((track) => isUsableText(track.name)) + .sort( + (a, b) => + getMemberTrackScore(b, userIds) - getMemberTrackScore(a, userIds), + ); + + for (const track of tracks) { + const song = await resolveQuestionSong(db, analytics, { + trackName: track.name, + artistNames: track.artists?.map((artist) => artist.name), + albumName: track.albumName, + }); + if (song) return song; + } + + return null; +} + +async function fetchPartyTopSongs( + db: typeof Db, + members: PartyQuestionMember[], +): Promise { + const songs: QuestionSong[] = []; + const seen = new Set(); + + for (const member of members) { + const rows = await db.query.topTrack.findMany({ + where: { + userId: member.userId, + }, + with: { + track: { + with: { + album: true, + artists: true, + }, + }, + }, + orderBy: { + position: "asc", + }, + limit: 5, + }); + + for (const row of rows) { + const song = row.track; + if (!song || !isUsableText(song.platform_id)) continue; + if (seen.has(song.platform_id)) continue; + seen.add(song.platform_id); + songs.push(song); + } + } + + return songs; +} + +function isSongTargetQuestion(question: Question): boolean { + const key = question.questionKey?.toLowerCase() ?? ""; + const text = question.text.toLowerCase(); + return ( + question.hideSongTitle === true || + key.startsWith("audio:current-song:") || + text.includes("what song") || + text.includes("which song") + ); +} + +function getTrackScore(track: { memberScores?: { score: number }[] }): number { + return (track.memberScores ?? []).reduce( + (total, entry) => total + entry.score, + 0, + ); +} + +function getMemberTrackScore( + track: { memberScores?: { userId: string; score: number }[] }, + userIds: string[], +): number { + return (track.memberScores ?? []).reduce((total, entry) => { + return userIds.includes(entry.userId) ? total + entry.score : total; + }, 0); +} + export function isUsableText( value: string | null | undefined, ): value is string {