try to fix image generator

This commit is contained in:
songtianlun 2025-08-31 01:38:59 +08:00
parent 13101edc6c
commit 4a7149b4bb
2 changed files with 153 additions and 29 deletions

View File

@ -7,6 +7,53 @@ import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credi
import { UniAPIService } from "@/lib/uniapi";
import { FalService } from "@/lib/fal";
// 图像生成模型适配器接口
interface ModelAdapter {
id: string;
name: string;
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string };
}
// 图像生成模型适配器
const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
'gpt-image-1': {
id: 'gpt-image-1',
name: 'GPT Image 1',
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
model: 'gpt-image-1',
prompt: `${promptContent}\n\nUser input: ${userInput}`,
size: '1024x1024',
quality: 'standard',
...params
}),
parseResponse: (response: Record<string, unknown>) => ({
content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
outputType: 'image'
})
},
'google/gemini-2.5-flash-image-preview': {
id: 'google/gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image Preview',
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
model: 'google/gemini-2.5-flash-image-preview',
contents: [{
parts: [{
text: `${promptContent}\n\nUser input: ${userInput}`
}]
}],
generationConfig: {
temperature: params.temperature || 0.7,
maxOutputTokens: params.maxTokens || 1024
}
}),
parseResponse: (response: Record<string, unknown>) => ({
content: (response as { candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>; generated_image_url?: string }).candidates?.[0]?.content?.parts?.[0]?.text || (response as { generated_image_url?: string }).generated_image_url || 'Image generated successfully',
outputType: 'image'
})
}
};
export async function POST(
request: NextRequest,
{ params }: { params: Promise<{ id: string }> }
@ -73,34 +120,95 @@ export async function POST(
let apiResponse: Response;
// 根据服务提供商选择不同的API
// 根据服务提供商和模型类型选择不同的API
if (run.model.serviceProvider === 'openrouter') {
const requestBody = {
model: run.model.modelId,
messages: [
// 检查是否是图像生成模型
if (run.model.outputType === 'image' && IMAGE_MODEL_ADAPTERS[run.model.modelId]) {
// 使用图像生成模型适配器
const adapter = IMAGE_MODEL_ADAPTERS[run.model.modelId];
const requestBody = adapter.prepareRequest(
run.userInput,
promptContent,
{
role: "user",
content: finalPrompt,
temperature: run.temperature,
maxTokens: run.maxTokens,
topP: run.topP,
frequencyPenalty: run.frequencyPenalty,
presencePenalty: run.presencePenalty,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
);
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
// 对于图像生成,使用非流式请求
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
// 对于图像生成,我们需要处理非流式响应
if (apiResponse.ok) {
const responseData = await apiResponse.json();
const parsedResult = adapter.parseResponse(responseData);
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
// 模拟流式数据
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: parsedResult.content }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
}
} else {
// 使用标准的文本聊天完成API
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
}
} else if (run.model.serviceProvider === 'uniapi') {
const uniAPIService = new UniAPIService();

View File

@ -68,6 +68,7 @@ interface SimulatorRun {
name: string
provider: string
modelId: string
outputType?: string
description?: string
maxTokens?: number
}
@ -78,6 +79,7 @@ interface Model {
modelId: string
name: string
provider: string
outputType?: string
description?: string
maxTokens?: number
inputCostPer1k?: number
@ -158,10 +160,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
const data = await response.json()
setRun(data)
// 如果是图片类型且有生成的文件路径,获取临时 URL
if (data.outputType === 'image' && data.generatedFilePath) {
fetchImageUrl(runId)
} else {
// 如果是图片类型从输出中提取图像URL
if (data.outputType === 'image' && data.output) {
// 尝试从输出中提取图像URL
const urlMatch = data.output.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
if (urlMatch) {
setGeneratedImageUrl(urlMatch[0])
}
setStreamOutput(data.output || '')
} else if (data.outputType !== 'image') {
// 只有非图片类型才设置文本输出
setStreamOutput(data.output || '')
}
@ -173,7 +180,7 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
} finally {
setIsLoading(false)
}
}, [runId, router, fetchImageUrl])
}, [runId, router])
const fetchModels = useCallback(async () => {
try {
@ -315,6 +322,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
const content = parsed.choices[0].delta.content
setStreamOutput(prev => prev + content)
// 对于图像生成模型尝试从内容中提取图像URL
if (run?.model?.outputType === 'image') {
// 匹配各种图像URL格式
const urlMatch = content.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
if (urlMatch) {
setGeneratedImageUrl(urlMatch[0])
}
}
// Auto scroll to bottom
if (outputRef.current) {
outputRef.current.scrollTop = outputRef.current.scrollHeight