From cd2ec0731448ac991e98209180a034d0d630e726 Mon Sep 17 00:00:00 2001 From: Daniel Bulant Date: Thu, 28 May 2026 17:46:57 +0200 Subject: [PATCH] improve random --- .../__tests__/question-generation.test.ts | 105 ++++++++++++++- .../__tests__/question-generator.test.ts | 77 ++++++++++- api/src/party/audio-question-generator.ts | 33 +++-- api/src/party/numeric-question-generator.ts | 23 +++- api/src/party/question-generator.ts | 51 ++++++-- api/src/party/question-utils.ts | 121 +++++++++++++----- 6 files changed, 341 insertions(+), 69 deletions(-) diff --git a/api/src/party/__tests__/question-generation.test.ts b/api/src/party/__tests__/question-generation.test.ts index 6f46b29..8554c4c 100644 --- a/api/src/party/__tests__/question-generation.test.ts +++ b/api/src/party/__tests__/question-generation.test.ts @@ -261,6 +261,53 @@ describe("question generation", () => { } }); + it("randomizes among top-tier candidates instead of only the highest score", () => { + const randomSpy = vi.spyOn(Math, "random").mockReturnValue(0.99); + + try { + const question = pickQuestionCandidate( + [ + { + key: "audio:track:highest", + subjectKey: "track:highest", + fairness: { memberIds: ["a", "b"], memberCount: 2, score: 100 }, + question: makeChoiceQuestion( + "Highest question", + "audio:track:highest", + "track:highest", + ), + }, + { + key: "audio:track:middle", + subjectKey: "track:middle", + fairness: { memberIds: ["a", "b"], memberCount: 2, score: 50 }, + question: makeChoiceQuestion( + "Middle question", + "audio:track:middle", + "track:middle", + ), + }, + { + key: "audio:track:lower", + subjectKey: "track:lower", + fairness: { memberIds: ["a", "b"], memberCount: 2, score: 25 }, + question: makeChoiceQuestion( + "Lower question", + "audio:track:lower", + "track:lower", + ), + }, + ], + [], + 0, + ); + + expect(question?.subjectKey).toBe("track:lower"); + } finally { + randomSpy.mockRestore(); + } + }); + it("orders fair tracks by party coverage before score", () => { const members: PartyQuestionMember[] = [ { userId: "a", name: "A" }, @@ -495,13 +542,40 @@ describe("question generation", () => { expect(question?.type).toBe("choice"); if (question?.type === "choice") { expect(question.options).toHaveLength(2); - expect(question.text).toContain("Shared Track Two"); } } finally { randomSpy.mockRestore(); } }); + it("builds metadata questions for non-top genres", async () => { + const randomSpy = vi.spyOn(Math, "random").mockReturnValue(0.99); + const db = createFakeDb(null); + const analytics = { + storyClusters: [], + groupSummary: { + mostSharedGenres: [{ name: "pop" }, { name: "rock" }, { name: "jazz" }], + }, + } as PartyAnalytics; + + try { + const question = await buildAudioMetadataQuestion( + db, + analytics, + [], + 0, + [], + ); + + expect(question?.questionKey).toBe("audio:genre:jazz"); + expect(question?.text).toBe( + "Which of these genres appears in the party analytics?", + ); + } finally { + randomSpy.mockRestore(); + } + }); + it("selects a fresh party song when the current one was already used", async () => { const db = createSongFallbackDb([ makeSong("track-1", "spotify:track:one", "One"), @@ -606,6 +680,35 @@ describe("question generation", () => { expect(song?.platform_id).toBe("spotify:track:two"); }); + it("prefers the referenced song for non-social subject questions", async () => { + const db = createSongFallbackDb([ + makeSong("track-1", "spotify:track:one", "One"), + makeSong("track-2", "spotify:track:two", "Two"), + ]); + const question = { + type: "choice" as const, + text: 'Who performs "One"?', + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:performer:One", + subjectKey: "track:One", + song: makeSong("track-1", "spotify:track:one", "One"), + }; + + const song = await selectQuestionSong({ + db, + analytics: null, + members: [{ userId: "a", name: "A" }], + history: [], + question, + }); + + expect(song?.platform_id).toBe("spotify:track:one"); + }); + it("keeps album questions on the referenced track", async () => { const db = createSongFallbackDb([ makeSong("track-1", "spotify:track:one", "One"), diff --git a/api/src/party/__tests__/question-generator.test.ts b/api/src/party/__tests__/question-generator.test.ts index f8fb12f..f45097e 100644 --- a/api/src/party/__tests__/question-generator.test.ts +++ b/api/src/party/__tests__/question-generator.test.ts @@ -1,6 +1,7 @@ -import { describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import type { QuizState } from "../../party-types"; import * as audioQuestionGenerator from "../audio-question-generator"; +import * as socialQuestionGenerator from "../social-question-generator"; vi.mock("../audio-question-generator", () => ({ buildAudioMetadataQuestion: vi.fn(async () => null), @@ -22,11 +23,24 @@ function createFakeDb() { partyMember: { findMany: vi.fn(async () => [{ userId: "a", user: { name: "A" } }]), }, + topTrack: { + findMany: vi.fn(async () => []), + }, }, }; } +function mockResolvedQuestion(fn: unknown, question: unknown) { + ( + fn as { mockResolvedValueOnce: (value: unknown) => void } + ).mockResolvedValueOnce(question); +} + describe("generatePartyQuestion", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + it("returns null when all real question sources are exhausted", async () => { const quizState = { status: "running", @@ -50,9 +64,7 @@ describe("generatePartyQuestion", () => { }); it("attaches a fallback song to generated questions", async () => { - vi.mocked( - audioQuestionGenerator.buildAudioMetadataQuestion, - ).mockResolvedValueOnce({ + mockResolvedQuestion(audioQuestionGenerator.buildAudioMetadataQuestion, { type: "choice", text: "Which genre appears most in the party analytics?", correct: 0, @@ -103,4 +115,61 @@ describe("generatePartyQuestion", () => { expect(question?.song?.platform_id).toBe("spotify:track:one"); }); + + it("prefers metadata questions over social questions when available", async () => { + mockResolvedQuestion(audioQuestionGenerator.buildAudioMetadataQuestion, { + type: "choice", + text: "Which track appears in the party analytics?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "audio:track:A", + subjectKey: "track:A", + }); + mockResolvedQuestion(socialQuestionGenerator.buildSocialQuestion, { + type: "choice", + text: "Who is leading the quiz right now?", + correct: 0, + startTimestamp: 1, + endTimestamp: 2, + points: 10, + options: ["A", "B"], + questionKey: "social:leader", + subjectKey: "member:a", + }); + const randomSpy = vi + .spyOn(Math, "random") + .mockReturnValueOnce(0.1) + .mockReturnValueOnce(0.9) + .mockReturnValueOnce(0.2); + + try { + const quizState = { + status: "running", + workflowId: null, + questionIndex: 0, + currentQuestion: null, + answers: {}, + scores: {}, + history: [], + } as QuizState; + + const question = await generatePartyQuestion({ + db: createFakeDb() as never, + partyId: "party-1", + quizState, + analytics: null, + index: 0, + }); + + expect(question?.questionKey).toBe("audio:track:A"); + expect( + socialQuestionGenerator.buildSocialQuestion, + ).not.toHaveBeenCalled(); + } finally { + randomSpy.mockRestore(); + } + }); }); diff --git a/api/src/party/audio-question-generator.ts b/api/src/party/audio-question-generator.ts index cbddef4..e9a322d 100644 --- a/api/src/party/audio-question-generator.ts +++ b/api/src/party/audio-question-generator.ts @@ -17,6 +17,8 @@ import { resolveQuestionSong, } from "./question-utils"; +const METADATA_ENTITY_POOL_SIZE = 8; + type TrackDetails = { id: string; name: string | null; @@ -98,16 +100,18 @@ export async function buildAudioMetadataQuestion( const genreNames = buildOrderedOptions(getMostSharedGenreNames(analytics), 4); if (genreNames) { - const topGenre = genreNames[0]; - if (topGenre) { - const genreOptions = buildOptionsWithCorrect(topGenre, genreNames, 4); + for (const genre of genreNames.slice(0, METADATA_ENTITY_POOL_SIZE)) { + const genreOptions = buildOptionsWithCorrect(genre, genreNames, 4); if (genreOptions) { questions.push({ - key: `audio:genre:${topGenre}`, - subjectKey: `genre:${topGenre}`, + key: `audio:genre:${genre}`, + subjectKey: `genre:${genre}`, question: { type: "choice", - text: "Which genre appears most in the party analytics?", + text: + genre === genreNames[0] + ? "Which genre appears most in the party analytics?" + : "Which of these genres appears in the party analytics?", options: genreOptions.options, correct: genreOptions.correct, points: 10, @@ -120,8 +124,10 @@ export async function buildAudioMetadataQuestion( const topArtistEntities = getFairQuestionArtists(analytics, members, history); const topArtists = topArtistEntities.map((artist) => artist.name); - const topArtist = topArtistEntities[0]; - if (topArtist) { + for (const topArtist of topArtistEntities.slice( + 0, + METADATA_ENTITY_POOL_SIZE, + )) { const artistOptions = buildOptionsWithCorrect( topArtist.name, topArtists, @@ -134,7 +140,10 @@ export async function buildAudioMetadataQuestion( fairness: getArtistFairness(topArtist, members, history), question: { type: "choice", - text: "Which artist shows up most often in the shared audio data?", + text: + topArtist === topArtistEntities[0] + ? "Which artist shows up most often in the shared audio data?" + : "Which artist shows up in the shared audio data?", options: artistOptions.options, correct: artistOptions.correct, points: 10, @@ -144,8 +153,7 @@ export async function buildAudioMetadataQuestion( } } - if (topTracks.length > 0) { - const topTrack = topTracks[0]; + for (const topTrack of topTracks.slice(0, METADATA_ENTITY_POOL_SIZE)) { const topTrackName = topTrack?.name; const trackOptions = topTrackName ? buildOptionsWithCorrect(topTrackName, topTrackNames, 4) @@ -158,9 +166,10 @@ export async function buildAudioMetadataQuestion( question: { type: "choice", text: + topTrack === topTracks[0] && getTrackFairness(topTrack, members, history).memberCount > 1 ? "Which track looks most shared across the party?" - : "Which track stands out in the party analytics?", + : "Which track appears in the party analytics?", options: trackOptions.options, correct: trackOptions.correct, points: 10, diff --git a/api/src/party/numeric-question-generator.ts b/api/src/party/numeric-question-generator.ts index 12ae907..ac655c1 100644 --- a/api/src/party/numeric-question-generator.ts +++ b/api/src/party/numeric-question-generator.ts @@ -16,6 +16,7 @@ import { type PartyAnalytics, type PartyQuestionMember, pickQuestionCandidate, + pickRandomTop, type QuestionCandidate, resolveQuestionSong, } from "./question-utils"; @@ -84,7 +85,9 @@ async function getAlbumReleaseYear({ members, history, }: BuildNumericQuestionInput): Promise { - const topTrack = getFairQuestionTracks(analytics, members, history)[0]; + const topTrack = pickRandomTop( + getFairQuestionTracks(analytics, members, history), + ); const trackName = topTrack?.name; const track = trackName ? await db.query.track.findFirst({ @@ -117,7 +120,9 @@ async function getTrackReleaseYear( input: BuildNumericQuestionInput, ): Promise { const tracks = await getDetailedTopTracks(input); - const track = tracks.find((track) => track.album?.release_date && track.name); + const track = pickRandomTop( + tracks.filter((track) => track.album?.release_date && track.name), + ); if (!track?.name || !track.album?.release_date) return null; const song = await resolveQuestionSong(input.db, input.analytics, { trackName: track.name, @@ -153,8 +158,10 @@ async function getArtistFirstTrackReleaseYear( } } - const artistEntry = Array.from(tracksByArtist.entries()).find( - ([, artistTracks]) => artistTracks.length >= 2, + const artistEntry = pickRandomTop( + Array.from(tracksByArtist.entries()).filter( + ([, artistTracks]) => artistTracks.length >= 2, + ), ); if (!artistEntry) return null; const [artistName, artistTracks] = artistEntry; @@ -189,7 +196,9 @@ async function countTopTrackListeners({ members, history, }: BuildNumericQuestionInput): Promise { - const topTrack = getFairQuestionTracks(analytics, members, history)[0]; + const topTrack = pickRandomTop( + getFairQuestionTracks(analytics, members, history), + ); const trackName = topTrack?.name; if (!trackName || members.length === 0) return null; const dbTrack = await db.query.track.findFirst({ @@ -227,7 +236,9 @@ async function countFavouriteArtistListeners({ members, history, }: BuildNumericQuestionInput): Promise { - const topArtist = getFairQuestionArtists(analytics, members, history)[0]; + const topArtist = pickRandomTop( + getFairQuestionArtists(analytics, members, history), + ); const artistName = topArtist?.name; if (!artistName || members.length === 0) return null; const dbArtist = await db.query.artist.findFirst({ diff --git a/api/src/party/question-generator.ts b/api/src/party/question-generator.ts index e6dce51..81a4a5e 100644 --- a/api/src/party/question-generator.ts +++ b/api/src/party/question-generator.ts @@ -1,5 +1,5 @@ import type { db } from "../db"; -import type { Question, QuizState } from "../party-types"; +import type { Question, QuizRound, QuizState } from "../party-types"; import { buildAudioMetadataQuestion } from "./audio-question-generator"; import { buildNumericQuestion } from "./numeric-question-generator"; import { @@ -27,16 +27,10 @@ export async function generatePartyQuestion({ index, }: GenerateQuestionInput): Promise { const members = await fetchPartyMembers(dbClient, partyId); - const preferredOrder: PartyQuestionType[] = [ - "audio-metadata", - "social", - "numeric", - ]; - const rotation = index % preferredOrder.length; - const typeOrder = [ - ...preferredOrder.slice(rotation), - ...preferredOrder.slice(0, rotation), - ]; + const typeOrder = getRandomQuestionTypeOrder( + ["audio-metadata", "social", "numeric"], + quizState.history, + ); for (const type of typeOrder) { let question: Question | null = null; @@ -82,3 +76,38 @@ export async function generatePartyQuestion({ return null; } + +function getRandomQuestionTypeOrder( + types: PartyQuestionType[], + history: QuizRound[], +): PartyQuestionType[] { + const recentTypes = history + .slice(-3) + .map((round) => getQuestionTypeFromKey(round.question.questionKey)); + + return types + .map((type) => ({ + type, + score: + getQuestionTypeBaseWeight(type) + + Math.random() * 0.35 - + recentTypes.filter((recent) => recent === type).length * 0.45, + })) + .sort((a, b) => b.score - a.score) + .map((entry) => entry.type); +} + +function getQuestionTypeBaseWeight(type: PartyQuestionType): number { + if (type === "audio-metadata") return 1; + if (type === "numeric") return 0.55; + return 0.1; +} + +function getQuestionTypeFromKey( + questionKey: string | undefined, +): PartyQuestionType | null { + if (questionKey?.startsWith("audio:")) return "audio-metadata"; + if (questionKey?.startsWith("social:")) return "social"; + if (questionKey?.startsWith("numeric:")) return "numeric"; + return null; +} diff --git a/api/src/party/question-utils.ts b/api/src/party/question-utils.ts index e8f37c2..dee3167 100644 --- a/api/src/party/question-utils.ts +++ b/api/src/party/question-utils.ts @@ -66,6 +66,9 @@ export type QuestionCandidateFairness = { }; export type QuestionSong = InferSelectModel; +const RANDOM_TOP_TIER_SIZE = 5; +const MAX_RANDOM_WEIGHT = 100; + export const QUESTION_DURATION_MS = 60_000; export const MIN_PARTY_SIZE = 2; export const MAX_PARTY_SIZE = 4; @@ -178,23 +181,12 @@ export function pickQuestionCandidate( }); if (fresh.length === 0) return null; - const bestMemberCount = Math.max( - ...fresh.map((candidate) => candidate.fairness?.memberCount ?? 0), + const candidate = pickWeightedRandom( + fresh.map((candidate) => ({ + item: candidate, + weight: getQuestionCandidateWeight(candidate), + })), ); - const bestScore = Math.max( - ...fresh - .filter( - (candidate) => - (candidate.fairness?.memberCount ?? 0) === bestMemberCount, - ) - .map((candidate) => candidate.fairness?.score ?? 0), - ); - const pool = fresh.filter( - (candidate) => - (candidate.fairness?.memberCount ?? 0) === bestMemberCount && - (candidate.fairness?.score ?? 0) === bestScore, - ); - const candidate = pickRandom(pool); if (!candidate) return null; return { ...candidate.question, @@ -394,15 +386,22 @@ export async function selectQuestionSong({ ) : candidates; + const allFreshCandidates = candidates.filter( + (candidate) => + isUsableText(candidate.song.platform_id) && + !usedPlatformIds.has(candidate.song.platform_id), + ); const freshCandidates = adjacentCandidates.filter( (candidate) => isUsableText(candidate.song.platform_id) && !usedPlatformIds.has(candidate.song.platform_id), ); - const selected = - pickFairSongCandidate(freshCandidates) ?? - pickFairSongCandidate(adjacentCandidates) ?? - pickFairSongCandidate(candidates); + const selected = shouldPreferQuestionSubjectSong(question) + ? (pickRelevantSongCandidate(question, allFreshCandidates) ?? + pickRelevantSongCandidate(question, candidates)) + : (pickFairSongCandidate(freshCandidates) ?? + pickFairSongCandidate(adjacentCandidates) ?? + pickFairSongCandidate(candidates)); return selected?.song ?? question.song ?? null; } @@ -592,23 +591,29 @@ function pickFairSongCandidate( candidates: SongCandidate[], ): SongCandidate | null { if (candidates.length === 0) return null; - const bestMemberCount = Math.max( - ...candidates.map((candidate) => candidate.fairness?.memberCount ?? 0), + return pickWeightedRandom( + candidates.map((candidate) => ({ + item: candidate, + weight: getQuestionCandidateFairnessWeight(candidate.fairness), + })), ); - const bestScore = Math.max( - ...candidates - .filter( +} + +function pickRelevantSongCandidate( + question: Question, + candidates: SongCandidate[], +): SongCandidate | null { + if (!shouldPreferQuestionSubjectSong(question)) { + return pickFairSongCandidate(candidates); + } + + return ( + pickFairSongCandidate( + candidates.filter( (candidate) => - (candidate.fairness?.memberCount ?? 0) === bestMemberCount, - ) - .map((candidate) => candidate.fairness?.score ?? 0), - ); - return pickRandom( - candidates.filter( - (candidate) => - (candidate.fairness?.memberCount ?? 0) === bestMemberCount && - (candidate.fairness?.score ?? 0) === bestScore, - ), + candidate.source === "subject" || candidate.source === "question", + ), + ) ?? pickFairSongCandidate(candidates) ); } @@ -651,6 +656,13 @@ function shouldUseQuestionSubjectSong(question: Question): boolean { ); } +function shouldPreferQuestionSubjectSong(question: Question): boolean { + const key = question.questionKey?.toLowerCase() ?? ""; + const subjectKey = question.subjectKey?.toLowerCase() ?? ""; + if (key.startsWith("social:")) return false; + return subjectKey.startsWith("track:") || subjectKey.startsWith("artist:"); +} + function getMemberTrackScore( track: { memberScores?: { userId: string; score: number }[] }, userIds: string[], @@ -807,6 +819,45 @@ export function pickRandom(items: T[]): T | null { return items[index] ?? null; } +export function pickRandomTop( + items: T[], + limit = RANDOM_TOP_TIER_SIZE, +): T | null { + return pickRandom(items.slice(0, Math.max(1, limit))); +} + +function pickWeightedRandom( + items: Array<{ item: T; weight: number }>, +): T | null { + if (items.length === 0) return null; + const weightedItems = items + .slice() + .sort((a, b) => Math.max(1, b.weight) - Math.max(1, a.weight)); + const totalWeight = weightedItems.reduce( + (total, entry) => total + Math.max(1, entry.weight), + 0, + ); + let target = Math.random() * totalWeight; + for (const entry of weightedItems) { + target -= Math.max(1, entry.weight); + if (target <= 0) return entry.item; + } + return weightedItems.at(-1)?.item ?? null; +} + +function getQuestionCandidateWeight(candidate: QuestionCandidate): number { + return getQuestionCandidateFairnessWeight(candidate.fairness); +} + +function getQuestionCandidateFairnessWeight( + fairness: QuestionCandidateFairness | undefined, +): number { + if (!fairness) return 8; + const memberCoverageWeight = 8 + fairness.memberCount * 20; + const scoreWeight = Math.max(0, Math.min(MAX_RANDOM_WEIGHT, fairness.score)); + return memberCoverageWeight + scoreWeight / 20; +} + function getAvailableOptionCount( availableCount: number, desiredCount: number,