From 13101edc6c34eb5b0ae6cf6315e69226b3ee6258 Mon Sep 17 00:00:00 2001 From: songtianlun Date: Sun, 31 Aug 2025 01:31:15 +0800 Subject: [PATCH] fix some api --- src/app/api/simulator/[id]/duplicate/route.ts | 12 +- src/app/api/simulator/[id]/route.ts | 30 +++-- src/app/api/simulator/prompts/route.ts | 12 +- src/app/api/simulator/route.ts | 87 +++++++++++- src/app/simulator/[id]/page.tsx | 4 +- src/app/simulator/new/page.tsx | 126 ++++++++++++++++-- src/hooks/useBetterAuth.ts | 13 +- 7 files changed, 243 insertions(+), 41 deletions(-) diff --git a/src/app/api/simulator/[id]/duplicate/route.ts b/src/app/api/simulator/[id]/duplicate/route.ts index 41d558e..049418c 100644 --- a/src/app/api/simulator/[id]/duplicate/route.ts +++ b/src/app/api/simulator/[id]/duplicate/route.ts @@ -1,5 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; -import { createServerSupabaseClient } from "@/lib/supabase-server"; +import { auth } from '@/lib/auth'; +import { headers } from 'next/headers'; import { prisma } from "@/lib/prisma"; export async function POST( @@ -8,12 +9,15 @@ export async function POST( ) { try { const { id } = await params; - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; // 获取原始运行记录 const originalRun = await prisma.simulatorRun.findFirst({ diff --git a/src/app/api/simulator/[id]/route.ts b/src/app/api/simulator/[id]/route.ts index 8d30fea..f73529e 100644 --- a/src/app/api/simulator/[id]/route.ts +++ b/src/app/api/simulator/[id]/route.ts @@ -1,5 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; -import { createServerSupabaseClient } from "@/lib/supabase-server"; +import { auth } from '@/lib/auth'; +import { headers } from 'next/headers'; import { prisma } from "@/lib/prisma"; export async function GET( @@ -8,12 +9,15 @@ export async function GET( ) { try { const { id } = await params; - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; const run = await prisma.simulatorRun.findFirst({ where: { @@ -75,12 +79,15 @@ export async function PATCH( ) { try { const { id } = await params; - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; const body = await request.json(); const { status, output, error, inputTokens, outputTokens, totalCost, duration } = body; @@ -144,12 +151,15 @@ export async function PUT( ) { try { const { id } = await params; - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; const body = await request.json(); const { diff --git a/src/app/api/simulator/prompts/route.ts b/src/app/api/simulator/prompts/route.ts index 79a5022..055d88e 100644 --- a/src/app/api/simulator/prompts/route.ts +++ b/src/app/api/simulator/prompts/route.ts @@ -1,17 +1,21 @@ import { NextRequest, NextResponse } from 'next/server' import { prisma } from '@/lib/prisma' -import { createServerSupabaseClient } from '@/lib/supabase-server' +import { auth } from '@/lib/auth' +import { headers } from 'next/headers' // GET /api/simulator/prompts - 获取用户的提示词列表,包含所有版本信息用于模拟器 export async function GET(request: NextRequest) { try { - const supabase = await createServerSupabaseClient() - const { data: { user } } = await supabase.auth.getUser() + const session = await auth.api.getSession({ + headers: await headers() + }) - if (!user) { + if (!session?.user) { return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) } + const user = session.user + const { searchParams } = new URL(request.url) const limit = parseInt(searchParams.get('limit') || '100') diff --git a/src/app/api/simulator/route.ts b/src/app/api/simulator/route.ts index ab632b9..9614565 100644 --- a/src/app/api/simulator/route.ts +++ b/src/app/api/simulator/route.ts @@ -1,15 +1,66 @@ import { NextRequest, NextResponse } from "next/server"; -import { createServerSupabaseClient } from "@/lib/supabase-server"; +import { auth } from "@/lib/auth"; +import { headers } from "next/headers"; import { prisma } from "@/lib/prisma"; +// 模型适配器类型 +type ModelAdapter = { + id: string; + name: string; + prepareRequest: (userInput: string, promptContent: string, params: Record) => Record; + parseResponse: (response: Record) => { content: string; outputType?: string }; +}; + +// 图像生成模型适配器 +const IMAGE_MODEL_ADAPTERS: Record = { + 'gpt-image-1': { + id: 'gpt-image-1', + name: 'GPT Image 1', + prepareRequest: (userInput: string, promptContent: string, params: Record) => ({ + model: 'gpt-image-1', + prompt: `${promptContent}\n\nUser input: ${userInput}`, + size: '1024x1024', + quality: 'standard', + ...params + }), + parseResponse: (response: Record) => ({ + content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully', + outputType: 'image' + }) + }, + 'google/gemini-2.5-flash-image-preview': { + id: 'google/gemini-2.5-flash-image-preview', + name: 'Gemini 2.5 Flash Image Preview', + prepareRequest: (userInput: string, promptContent: string, params: Record) => ({ + model: 'google/gemini-2.5-flash-image-preview', + contents: [{ + parts: [{ + text: `${promptContent}\n\nUser input: ${userInput}` + }] + }], + generationConfig: { + temperature: params.temperature || 0.7, + maxOutputTokens: params.maxTokens || 1024 + } + }), + parseResponse: (response: Record) => ({ + content: (response as { candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>; generated_image_url?: string }).candidates?.[0]?.content?.parts?.[0]?.text || (response as { generated_image_url?: string }).generated_image_url || 'Image generated successfully', + outputType: 'image' + }) + } +}; + export async function GET(request: NextRequest) { try { - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; const { searchParams } = new URL(request.url); const page = parseInt(searchParams.get("page") || "1"); @@ -73,12 +124,15 @@ export async function GET(request: NextRequest) { export async function POST(request: NextRequest) { try { - const supabase = await createServerSupabaseClient(); - const { data: { user }, error: authError } = await supabase.auth.getUser(); + const session = await auth.api.getSession({ + headers: await headers() + }); - if (authError || !user) { + if (!session?.user) { return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); } + + const user = session.user; const body = await request.json(); const { @@ -93,6 +147,7 @@ export async function POST(request: NextRequest) { topP, frequencyPenalty, presencePenalty, + generationMode = 'text', // 新增生成模式字段 // 用于创建新prompt的字段 createNewPrompt, newPromptName, @@ -139,6 +194,24 @@ export async function POST(request: NextRequest) { return NextResponse.json({ error: "Model not available" }, { status: 400 }); } + // 验证生成模式与模型的兼容性 + if (generationMode === 'text' && model.outputType !== 'text') { + return NextResponse.json({ error: "Selected model is not compatible with text generation mode" }, { status: 400 }); + } + + if (generationMode === 'image') { + if (model.outputType !== 'image') { + return NextResponse.json({ error: "Selected model is not compatible with image generation mode" }, { status: 400 }); + } + + // 检查是否有对应的适配器 + if (!IMAGE_MODEL_ADAPTERS[model.modelId]) { + return NextResponse.json({ + error: `Image model ${model.modelId} is not supported yet. Supported models: ${Object.keys(IMAGE_MODEL_ADAPTERS).join(', ')}` + }, { status: 400 }); + } + } + // 创建运行记录 const run = await prisma.simulatorRun.create({ data: { diff --git a/src/app/simulator/[id]/page.tsx b/src/app/simulator/[id]/page.tsx index e319ed6..ccb6e5e 100644 --- a/src/app/simulator/[id]/page.tsx +++ b/src/app/simulator/[id]/page.tsx @@ -2,7 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react' import { useTranslations } from 'next-intl' -import { useAuthUser } from '@/hooks/useAuthUser' +import { useBetterAuth } from '@/hooks/useBetterAuth' import { useRouter } from 'next/navigation' import { Header } from '@/components/layout/Header' import { Footer } from '@/components/layout/Footer' @@ -85,7 +85,7 @@ interface Model { } export default function SimulatorRunPage({ params }: { params: Promise<{ id: string }> }) { - const { user, loading: authLoading } = useAuthUser() + const { user, loading: authLoading } = useBetterAuth() const router = useRouter() const t = useTranslations('simulator') const locale = useLocale() diff --git a/src/app/simulator/new/page.tsx b/src/app/simulator/new/page.tsx index 9c61a49..c4cf510 100644 --- a/src/app/simulator/new/page.tsx +++ b/src/app/simulator/new/page.tsx @@ -2,7 +2,7 @@ import { useState, useEffect, useCallback } from 'react' import { useTranslations } from 'next-intl' -import { useAuthUser } from '@/hooks/useAuthUser' +import { useBetterAuth } from '@/hooks/useBetterAuth' import { useRouter } from 'next/navigation' import { Header } from '@/components/layout/Header' import { Footer } from '@/components/layout/Footer' @@ -36,11 +36,30 @@ interface Prompt { }> } +// 生成模式枚举 +type GenerationMode = 'text' | 'image' + +// 支持的图像生成模型ID映射 +const SUPPORTED_IMAGE_MODELS = { + 'gpt-image-1': { + id: 'gpt-image-1', + name: 'GPT Image 1', + adapter: 'gpt-image-1' + }, + 'google/gemini-2.5-flash-image-preview': { + id: 'google/gemini-2.5-flash-image-preview', + name: 'Gemini 2.5 Flash Image Preview', + adapter: 'gemini-image' + } +} as const + interface Model { id: string modelId: string name: string provider: string + serviceProvider: string + outputType: string description?: string maxTokens?: number inputCostPer1k?: number @@ -49,7 +68,7 @@ interface Model { } export default function NewSimulatorRunPage() { - const { user, loading: authLoading } = useAuthUser() + const { user, loading: authLoading } = useBetterAuth() const router = useRouter() const t = useTranslations('simulator') @@ -69,6 +88,10 @@ export default function NewSimulatorRunPage() { const [simulatorName, setSimulatorName] = useState('') const [promptInputMode, setPromptInputMode] = useState<'select' | 'create'>('select') // 选择模式:选择现有提示词 或 创建新提示词 + // 生成模式相关状态 + const [generationMode, setGenerationMode] = useState('text') + const [filteredModels, setFilteredModels] = useState([]) + // Advanced settings const [temperature, setTemperature] = useState('0.7') const [maxTokens, setMaxTokens] = useState('') @@ -76,6 +99,34 @@ export default function NewSimulatorRunPage() { const [frequencyPenalty, setFrequencyPenalty] = useState('0') const [presencePenalty, setPresencePenalty] = useState('0') + // 模型过滤逻辑 + const filterModelsByMode = useCallback((allModels: Model[], mode: GenerationMode) => { + if (mode === 'text') { + // 文本模式:显示所有outputType为text的模型 + return allModels.filter(model => model.outputType === 'text') + } else if (mode === 'image') { + // 图像模式:只显示已适配的图像生成模型 + return allModels.filter(model => + model.outputType === 'image' && + Object.keys(SUPPORTED_IMAGE_MODELS).includes(model.modelId) + ) + } + return [] + }, []) + + // 当生成模式改变时更新模型列表 + useEffect(() => { + const filtered = filterModelsByMode(models, generationMode) + setFilteredModels(filtered) + + // 重置已选择的模型,选择第一个可用模型 + if (filtered.length > 0) { + setSelectedModelId(filtered[0].id) + } else { + setSelectedModelId('') + } + }, [models, generationMode, filterModelsByMode]) + const fetchData = useCallback(async () => { if (!user) return @@ -95,10 +146,7 @@ export default function NewSimulatorRunPage() { if (modelsResponse.ok) { const modelsData = await modelsResponse.json() setModels(modelsData.models || []) - // Auto-select first model - if (modelsData.models?.length > 0) { - setSelectedModelId(modelsData.models[0].id) - } + // 模型选择逻辑现在由useEffect处理 } } catch (error) { console.error('Error fetching data:', error) @@ -222,6 +270,7 @@ export default function NewSimulatorRunPage() { topP: number; frequencyPenalty: number; presencePenalty: number; + generationMode: GenerationMode; createNewPrompt?: boolean; newPromptName?: string; newPromptContent?: string; @@ -237,6 +286,7 @@ export default function NewSimulatorRunPage() { topP: parseFloat(topP), frequencyPenalty: parseFloat(frequencyPenalty), presencePenalty: parseFloat(presencePenalty), + generationMode, } if (promptInputMode === 'create') { @@ -315,6 +365,52 @@ export default function NewSimulatorRunPage() { ) : (
+ {/* Generation Mode Selection */} + +
+

+ + 生成模式 +

+

+ 选择生成模式:文本生成或图像生成 +

+ +
+ + +
+ + {generationMode === 'image' && filteredModels.length === 0 && ( +
+

+ 当前没有可用的图像生成模型。支持的模型:{Object.values(SUPPORTED_IMAGE_MODELS).map(m => m.name).join(', ')} +

+
+ )} +
+
+ {/* Prompt Configuration */}
@@ -511,13 +607,16 @@ export default function NewSimulatorRunPage() {

{t('selectModel')}

- Choose the AI model to run your prompt + {generationMode === 'text' + ? '选择用于文本生成的AI模型' + : '选择用于图像生成的AI模型(仅显示已适配的模型)' + }

- {models.map(model => ( + {filteredModels.map(model => (

{model.name}

- - {model.provider} - +
+ + {model.outputType === 'text' ? '📝' : '🎨'} {model.outputType} + + + {model.provider} + +
{model.description && (

diff --git a/src/hooks/useBetterAuth.ts b/src/hooks/useBetterAuth.ts index d175ed3..a0fcee5 100644 --- a/src/hooks/useBetterAuth.ts +++ b/src/hooks/useBetterAuth.ts @@ -132,6 +132,8 @@ export function useBetterAuth() { ...globalAuthState, user, loading: false, + // 保留现有的 isAdmin 状态,除非用户登出 + isAdmin: user ? globalAuthState.isAdmin : false, } notifyStateChange(newState) @@ -232,9 +234,14 @@ export function useBetterAuth() { // 立即更新用户状态(不延迟) updateUserState(user) - // 异步进行数据同步(防抖延迟) + // 异步进行数据同步 if (user) { - debouncedUserDataSync(user, trigger) + // 初始加载时立即同步,其他情况使用防抖 + if (trigger === SyncTrigger.INITIAL_LOAD) { + globalHandleUserDataSync(user, trigger) + } else { + debouncedUserDataSync(user, trigger) + } } if (!isGlobalInitialized) { @@ -249,7 +256,7 @@ export function useBetterAuth() { return () => { mounted = false } - }, [session, isPending, updateUserState, debouncedUserDataSync]) + }, [session, isPending, updateUserState, debouncedUserDataSync, globalHandleUserDataSync]) // 设置loading状态 useEffect(() => {