try to fix image generator
This commit is contained in:
parent
13101edc6c
commit
4a7149b4bb
@ -7,6 +7,53 @@ import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credi
|
||||
import { UniAPIService } from "@/lib/uniapi";
|
||||
import { FalService } from "@/lib/fal";
|
||||
|
||||
// 图像生成模型适配器接口
|
||||
interface ModelAdapter {
|
||||
id: string;
|
||||
name: string;
|
||||
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
|
||||
parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string };
|
||||
}
|
||||
|
||||
// 图像生成模型适配器
|
||||
const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
|
||||
'gpt-image-1': {
|
||||
id: 'gpt-image-1',
|
||||
name: 'GPT Image 1',
|
||||
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
||||
model: 'gpt-image-1',
|
||||
prompt: `${promptContent}\n\nUser input: ${userInput}`,
|
||||
size: '1024x1024',
|
||||
quality: 'standard',
|
||||
...params
|
||||
}),
|
||||
parseResponse: (response: Record<string, unknown>) => ({
|
||||
content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
|
||||
outputType: 'image'
|
||||
})
|
||||
},
|
||||
'google/gemini-2.5-flash-image-preview': {
|
||||
id: 'google/gemini-2.5-flash-image-preview',
|
||||
name: 'Gemini 2.5 Flash Image Preview',
|
||||
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
||||
model: 'google/gemini-2.5-flash-image-preview',
|
||||
contents: [{
|
||||
parts: [{
|
||||
text: `${promptContent}\n\nUser input: ${userInput}`
|
||||
}]
|
||||
}],
|
||||
generationConfig: {
|
||||
temperature: params.temperature || 0.7,
|
||||
maxOutputTokens: params.maxTokens || 1024
|
||||
}
|
||||
}),
|
||||
parseResponse: (response: Record<string, unknown>) => ({
|
||||
content: (response as { candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>; generated_image_url?: string }).candidates?.[0]?.content?.parts?.[0]?.text || (response as { generated_image_url?: string }).generated_image_url || 'Image generated successfully',
|
||||
outputType: 'image'
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
export async function POST(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ id: string }> }
|
||||
@ -73,34 +120,95 @@ export async function POST(
|
||||
|
||||
let apiResponse: Response;
|
||||
|
||||
// 根据服务提供商选择不同的API
|
||||
// 根据服务提供商和模型类型选择不同的API
|
||||
if (run.model.serviceProvider === 'openrouter') {
|
||||
const requestBody = {
|
||||
model: run.model.modelId,
|
||||
messages: [
|
||||
// 检查是否是图像生成模型
|
||||
if (run.model.outputType === 'image' && IMAGE_MODEL_ADAPTERS[run.model.modelId]) {
|
||||
// 使用图像生成模型适配器
|
||||
const adapter = IMAGE_MODEL_ADAPTERS[run.model.modelId];
|
||||
const requestBody = adapter.prepareRequest(
|
||||
run.userInput,
|
||||
promptContent,
|
||||
{
|
||||
role: "user",
|
||||
content: finalPrompt,
|
||||
temperature: run.temperature,
|
||||
maxTokens: run.maxTokens,
|
||||
topP: run.topP,
|
||||
frequencyPenalty: run.frequencyPenalty,
|
||||
presencePenalty: run.presencePenalty,
|
||||
}
|
||||
],
|
||||
temperature: run.temperature || 0.7,
|
||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||
...(run.topP && { top_p: run.topP }),
|
||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||
stream: true,
|
||||
};
|
||||
);
|
||||
|
||||
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||
"X-Title": "Prmbr - AI Prompt Studio",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
// 对于图像生成,使用非流式请求
|
||||
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||
"X-Title": "Prmbr - AI Prompt Studio",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
// 对于图像生成,我们需要处理非流式响应
|
||||
if (apiResponse.ok) {
|
||||
const responseData = await apiResponse.json();
|
||||
const parsedResult = adapter.parseResponse(responseData);
|
||||
|
||||
// 创建模拟的流式响应
|
||||
const mockStream = new ReadableStream({
|
||||
start(controller) {
|
||||
// 模拟流式数据
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: parsedResult.content }
|
||||
}]
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
|
||||
apiResponse = new Response(mockStream, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// 使用标准的文本聊天完成API
|
||||
const requestBody = {
|
||||
model: run.model.modelId,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: finalPrompt,
|
||||
}
|
||||
],
|
||||
temperature: run.temperature || 0.7,
|
||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||
...(run.topP && { top_p: run.topP }),
|
||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||
"X-Title": "Prmbr - AI Prompt Studio",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
}
|
||||
} else if (run.model.serviceProvider === 'uniapi') {
|
||||
const uniAPIService = new UniAPIService();
|
||||
|
||||
|
@ -68,6 +68,7 @@ interface SimulatorRun {
|
||||
name: string
|
||||
provider: string
|
||||
modelId: string
|
||||
outputType?: string
|
||||
description?: string
|
||||
maxTokens?: number
|
||||
}
|
||||
@ -78,6 +79,7 @@ interface Model {
|
||||
modelId: string
|
||||
name: string
|
||||
provider: string
|
||||
outputType?: string
|
||||
description?: string
|
||||
maxTokens?: number
|
||||
inputCostPer1k?: number
|
||||
@ -158,10 +160,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
||||
const data = await response.json()
|
||||
setRun(data)
|
||||
|
||||
// 如果是图片类型且有生成的文件路径,获取临时 URL
|
||||
if (data.outputType === 'image' && data.generatedFilePath) {
|
||||
fetchImageUrl(runId)
|
||||
} else {
|
||||
// 如果是图片类型,从输出中提取图像URL
|
||||
if (data.outputType === 'image' && data.output) {
|
||||
// 尝试从输出中提取图像URL
|
||||
const urlMatch = data.output.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
|
||||
if (urlMatch) {
|
||||
setGeneratedImageUrl(urlMatch[0])
|
||||
}
|
||||
setStreamOutput(data.output || '')
|
||||
} else if (data.outputType !== 'image') {
|
||||
// 只有非图片类型才设置文本输出
|
||||
setStreamOutput(data.output || '')
|
||||
}
|
||||
@ -173,7 +180,7 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [runId, router, fetchImageUrl])
|
||||
}, [runId, router])
|
||||
|
||||
const fetchModels = useCallback(async () => {
|
||||
try {
|
||||
@ -315,6 +322,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
||||
const content = parsed.choices[0].delta.content
|
||||
setStreamOutput(prev => prev + content)
|
||||
|
||||
// 对于图像生成模型,尝试从内容中提取图像URL
|
||||
if (run?.model?.outputType === 'image') {
|
||||
// 匹配各种图像URL格式
|
||||
const urlMatch = content.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
|
||||
if (urlMatch) {
|
||||
setGeneratedImageUrl(urlMatch[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Auto scroll to bottom
|
||||
if (outputRef.current) {
|
||||
outputRef.current.scrollTop = outputRef.current.scrollHeight
|
||||
|
Loading…
Reference in New Issue
Block a user