diff --git a/src/app/api/simulator/[id]/execute/route.ts b/src/app/api/simulator/[id]/execute/route.ts
index 31f307e..f70134b 100644
--- a/src/app/api/simulator/[id]/execute/route.ts
+++ b/src/app/api/simulator/[id]/execute/route.ts
@@ -7,6 +7,53 @@ import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credi
 import { UniAPIService } from "@/lib/uniapi";
 import { FalService } from "@/lib/fal";

+// Adapter interface for image generation models
+interface ModelAdapter {
+  id: string;
+  name: string;
+  prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
+  parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string };
+}
+
+// Adapters for the supported image generation models
+const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
+  'gpt-image-1': {
+    id: 'gpt-image-1',
+    name: 'GPT Image 1',
+    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
+      model: 'gpt-image-1',
+      prompt: `${promptContent}\n\nUser input: ${userInput}`,
+      size: '1024x1024',
+      quality: 'standard',
+      ...params
+    }),
+    parseResponse: (response: Record<string, unknown>) => ({
+      content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
+      outputType: 'image'
+    })
+  },
+  'google/gemini-2.5-flash-image-preview': {
+    id: 'google/gemini-2.5-flash-image-preview',
+    name: 'Gemini 2.5 Flash Image Preview',
+    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
+      model: 'google/gemini-2.5-flash-image-preview',
+      contents: [{
+        parts: [{
+          text: `${promptContent}\n\nUser input: ${userInput}`
+        }]
+      }],
+      generationConfig: {
+        temperature: params.temperature || 0.7,
+        maxOutputTokens: params.maxTokens || 1024
+      }
+    }),
+    parseResponse: (response: Record<string, unknown>) => ({
+      content: (response as { candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>; generated_image_url?: string }).candidates?.[0]?.content?.parts?.[0]?.text || (response as { generated_image_url?: string }).generated_image_url || 'Image generated successfully',
+      outputType: 'image'
+    })
+  }
+};
+
 export async function POST(
   request: NextRequest,
   { params }: { params: Promise<{ id: string }> }
@@ -73,34 +120,95 @@ export async function POST(

     let apiResponse: Response;

-    // Choose the API based on the service provider
+    // Choose the API based on the service provider and model type
     if (run.model.serviceProvider === 'openrouter') {
-      const requestBody = {
-        model: run.model.modelId,
-        messages: [
+      // Check whether this is an image generation model
+      if (run.model.outputType === 'image' && IMAGE_MODEL_ADAPTERS[run.model.modelId]) {
+        // Use the matching image generation adapter
+        const adapter = IMAGE_MODEL_ADAPTERS[run.model.modelId];
+        const requestBody = adapter.prepareRequest(
+          run.userInput,
+          promptContent,
           {
-            role: "user",
-            content: finalPrompt,
+            temperature: run.temperature,
+            maxTokens: run.maxTokens,
+            topP: run.topP,
+            frequencyPenalty: run.frequencyPenalty,
+            presencePenalty: run.presencePenalty,
           }
-        ],
-        temperature: run.temperature || 0.7,
-        ...(run.maxTokens && { max_tokens: run.maxTokens }),
-        ...(run.topP && { top_p: run.topP }),
-        ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
-        ...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
-        stream: true,
-      };
+        );

-      apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
-        method: "POST",
-        headers: {
-          "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
-          "Content-Type": "application/json",
-          "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
-          "X-Title": "Prmbr - AI Prompt Studio",
-        },
-        body: JSON.stringify(requestBody),
-      });
+        // Image generation uses a non-streaming request
+        apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
+          method: "POST",
+          headers: {
+            "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
+            "Content-Type": "application/json",
+            "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
+            "X-Title": "Prmbr - AI Prompt Studio",
+          },
+          body: JSON.stringify(requestBody),
+        });
+
+        // For image generation we have to handle the non-streaming response ourselves
+        if (apiResponse.ok) {
+          const responseData = await apiResponse.json();
+          const parsedResult = adapter.parseResponse(responseData);
+
+          // Wrap the result in a simulated streaming response
+          const mockStream = new ReadableStream({
+            start(controller) {
+              // Emit the result as a single simulated stream chunk
+              controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
+                choices: [{
+                  delta: { content: parsedResult.content }
+                }]
+              })}\n\n`));
+
+              controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
+                usage: { prompt_tokens: 50, completion_tokens: 100 } // estimated token usage
+              })}\n\n`));
+
+              controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
+              controller.close();
+            }
+          });
+
+          apiResponse = new Response(mockStream, {
+            headers: {
+              'Content-Type': 'text/event-stream',
+            }
+          });
+        }
+      } else {
+        // Use the standard text chat completions API
+        const requestBody = {
+          model: run.model.modelId,
+          messages: [
+            {
+              role: "user",
+              content: finalPrompt,
+            }
+          ],
+          temperature: run.temperature || 0.7,
+          ...(run.maxTokens && { max_tokens: run.maxTokens }),
+          ...(run.topP && { top_p: run.topP }),
+          ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
+          ...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
+          stream: true,
+        };
+
+        apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
+          method: "POST",
+          headers: {
+            "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
+            "Content-Type": "application/json",
+            "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
+            "X-Title": "Prmbr - AI Prompt Studio",
+          },
+          body: JSON.stringify(requestBody),
+        });
+      }
     } else if (run.model.serviceProvider === 'uniapi') {
       const uniAPIService = new UniAPIService();
diff --git a/src/app/simulator/[id]/page.tsx b/src/app/simulator/[id]/page.tsx
index ccb6e5e..c744416 100644
--- a/src/app/simulator/[id]/page.tsx
+++ b/src/app/simulator/[id]/page.tsx
@@ -68,6 +68,7 @@ interface SimulatorRun {
     name: string
     provider: string
     modelId: string
+    outputType?: string
     description?: string
     maxTokens?: number
   }
@@ -78,6 +79,7 @@ interface Model {
   modelId: string
   name: string
   provider: string
+  outputType?: string
   description?: string
   maxTokens?: number
   inputCostPer1k?: number
@@ -158,10 +160,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
       const data = await response.json()
       setRun(data)

-      // If the output is an image and a generated file path exists, fetch a temporary URL
-      if (data.outputType === 'image' && data.generatedFilePath) {
-        fetchImageUrl(runId)
-      } else {
+      // If the output is an image, extract the image URL from the output
+      if (data.outputType === 'image' && data.output) {
+        // Try to pull an image URL out of the output text
+        const urlMatch = data.output.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
+        if (urlMatch) {
+          setGeneratedImageUrl(urlMatch[0])
+        }
+        setStreamOutput(data.output || '')
+      } else if (data.outputType !== 'image') {
+        // Only set text output for non-image runs
         setStreamOutput(data.output || '')
       }
@@ -173,7 +180,7 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
     } finally {
       setIsLoading(false)
     }
-  }, [runId, router, fetchImageUrl])
+  }, [runId, router])

   const fetchModels = useCallback(async () => {
     try {
@@ -315,6 +322,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
               const content = parsed.choices[0].delta.content
               setStreamOutput(prev => prev + content)

+              // For image generation models, try to extract an image URL from the streamed content
+              if (run?.model?.outputType === 'image') {
+                // Match common image URL formats
+                const urlMatch = content.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
+                if (urlMatch) {
+                  setGeneratedImageUrl(urlMatch[0])
+                }
+              }
+
               // Auto scroll to bottom
               if (outputRef.current) {
                 outputRef.current.scrollTop = outputRef.current.scrollHeight
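Note on the route.ts change above: the image branch issues a one-shot request, parses the JSON result through the adapter, and then re-wraps it as an OpenAI-style SSE stream so the existing client-side stream parser can consume it unchanged. A minimal, self-contained sketch of that re-wrapping technique (illustration only; toMockSseResponse is a hypothetical helper, not part of this patch):

function toMockSseResponse(content: string): Response {
  const encoder = new TextEncoder();
  const stream = new ReadableStream<Uint8Array>({
    start(controller) {
      // One delta chunk carrying the entire (non-streamed) result
      controller.enqueue(encoder.encode(
        `data: ${JSON.stringify({ choices: [{ delta: { content } }] })}\n\n`
      ));
      // Terminator that SSE consumers of chat-completions streams look for
      controller.enqueue(encoder.encode("data: [DONE]\n\n"));
      controller.close();
    },
  });
  return new Response(stream, { headers: { "Content-Type": "text/event-stream" } });
}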
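Likewise, the page.tsx changes rely on one regex to pull the first image URL out of either the stored output or the streamed content. A small sketch of that extraction with an example input (illustration only; extractImageUrl and the sample URL are hypothetical):

const IMAGE_URL_RE = /https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i;

function extractImageUrl(text: string): string | null {
  const match = text.match(IMAGE_URL_RE);
  return match ? match[0] : null;
}

// extractImageUrl("Done: https://cdn.example.com/gen/42.png?sig=abc and more text")
//   -> "https://cdn.example.com/gen/42.png?sig=abc"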