try to fix image generator
This commit is contained in:
parent
13101edc6c
commit
4a7149b4bb
@ -7,6 +7,53 @@ import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credi
|
|||||||
import { UniAPIService } from "@/lib/uniapi";
|
import { UniAPIService } from "@/lib/uniapi";
|
||||||
import { FalService } from "@/lib/fal";
|
import { FalService } from "@/lib/fal";
|
||||||
|
|
||||||
|
// 图像生成模型适配器接口
|
||||||
|
interface ModelAdapter {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
|
||||||
|
parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string };
|
||||||
|
}
|
||||||
|
|
||||||
|
// 图像生成模型适配器
|
||||||
|
const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
|
||||||
|
'gpt-image-1': {
|
||||||
|
id: 'gpt-image-1',
|
||||||
|
name: 'GPT Image 1',
|
||||||
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
||||||
|
model: 'gpt-image-1',
|
||||||
|
prompt: `${promptContent}\n\nUser input: ${userInput}`,
|
||||||
|
size: '1024x1024',
|
||||||
|
quality: 'standard',
|
||||||
|
...params
|
||||||
|
}),
|
||||||
|
parseResponse: (response: Record<string, unknown>) => ({
|
||||||
|
content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
|
||||||
|
outputType: 'image'
|
||||||
|
})
|
||||||
|
},
|
||||||
|
'google/gemini-2.5-flash-image-preview': {
|
||||||
|
id: 'google/gemini-2.5-flash-image-preview',
|
||||||
|
name: 'Gemini 2.5 Flash Image Preview',
|
||||||
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
||||||
|
model: 'google/gemini-2.5-flash-image-preview',
|
||||||
|
contents: [{
|
||||||
|
parts: [{
|
||||||
|
text: `${promptContent}\n\nUser input: ${userInput}`
|
||||||
|
}]
|
||||||
|
}],
|
||||||
|
generationConfig: {
|
||||||
|
temperature: params.temperature || 0.7,
|
||||||
|
maxOutputTokens: params.maxTokens || 1024
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
parseResponse: (response: Record<string, unknown>) => ({
|
||||||
|
content: (response as { candidates?: Array<{ content?: { parts?: Array<{ text?: string }> } }>; generated_image_url?: string }).candidates?.[0]?.content?.parts?.[0]?.text || (response as { generated_image_url?: string }).generated_image_url || 'Image generated successfully',
|
||||||
|
outputType: 'image'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
export async function POST(
|
export async function POST(
|
||||||
request: NextRequest,
|
request: NextRequest,
|
||||||
{ params }: { params: Promise<{ id: string }> }
|
{ params }: { params: Promise<{ id: string }> }
|
||||||
@ -73,34 +120,95 @@ export async function POST(
|
|||||||
|
|
||||||
let apiResponse: Response;
|
let apiResponse: Response;
|
||||||
|
|
||||||
// 根据服务提供商选择不同的API
|
// 根据服务提供商和模型类型选择不同的API
|
||||||
if (run.model.serviceProvider === 'openrouter') {
|
if (run.model.serviceProvider === 'openrouter') {
|
||||||
const requestBody = {
|
// 检查是否是图像生成模型
|
||||||
model: run.model.modelId,
|
if (run.model.outputType === 'image' && IMAGE_MODEL_ADAPTERS[run.model.modelId]) {
|
||||||
messages: [
|
// 使用图像生成模型适配器
|
||||||
|
const adapter = IMAGE_MODEL_ADAPTERS[run.model.modelId];
|
||||||
|
const requestBody = adapter.prepareRequest(
|
||||||
|
run.userInput,
|
||||||
|
promptContent,
|
||||||
{
|
{
|
||||||
role: "user",
|
temperature: run.temperature,
|
||||||
content: finalPrompt,
|
maxTokens: run.maxTokens,
|
||||||
|
topP: run.topP,
|
||||||
|
frequencyPenalty: run.frequencyPenalty,
|
||||||
|
presencePenalty: run.presencePenalty,
|
||||||
}
|
}
|
||||||
],
|
);
|
||||||
temperature: run.temperature || 0.7,
|
|
||||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
|
||||||
...(run.topP && { top_p: run.topP }),
|
|
||||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
|
||||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
|
||||||
stream: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
// 对于图像生成,使用非流式请求
|
||||||
method: "POST",
|
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||||
headers: {
|
method: "POST",
|
||||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
"Content-Type": "application/json",
|
||||||
"X-Title": "Prmbr - AI Prompt Studio",
|
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||||
},
|
"X-Title": "Prmbr - AI Prompt Studio",
|
||||||
body: JSON.stringify(requestBody),
|
},
|
||||||
});
|
body: JSON.stringify(requestBody),
|
||||||
|
});
|
||||||
|
|
||||||
|
// 对于图像生成,我们需要处理非流式响应
|
||||||
|
if (apiResponse.ok) {
|
||||||
|
const responseData = await apiResponse.json();
|
||||||
|
const parsedResult = adapter.parseResponse(responseData);
|
||||||
|
|
||||||
|
// 创建模拟的流式响应
|
||||||
|
const mockStream = new ReadableStream({
|
||||||
|
start(controller) {
|
||||||
|
// 模拟流式数据
|
||||||
|
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||||
|
choices: [{
|
||||||
|
delta: { content: parsedResult.content }
|
||||||
|
}]
|
||||||
|
})}\n\n`));
|
||||||
|
|
||||||
|
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||||
|
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
|
||||||
|
})}\n\n`));
|
||||||
|
|
||||||
|
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
|
||||||
|
controller.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
apiResponse = new Response(mockStream, {
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'text/event-stream',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 使用标准的文本聊天完成API
|
||||||
|
const requestBody = {
|
||||||
|
model: run.model.modelId,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: finalPrompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
temperature: run.temperature || 0.7,
|
||||||
|
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||||
|
...(run.topP && { top_p: run.topP }),
|
||||||
|
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||||
|
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||||
|
"X-Title": "Prmbr - AI Prompt Studio",
|
||||||
|
},
|
||||||
|
body: JSON.stringify(requestBody),
|
||||||
|
});
|
||||||
|
}
|
||||||
} else if (run.model.serviceProvider === 'uniapi') {
|
} else if (run.model.serviceProvider === 'uniapi') {
|
||||||
const uniAPIService = new UniAPIService();
|
const uniAPIService = new UniAPIService();
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ interface SimulatorRun {
|
|||||||
name: string
|
name: string
|
||||||
provider: string
|
provider: string
|
||||||
modelId: string
|
modelId: string
|
||||||
|
outputType?: string
|
||||||
description?: string
|
description?: string
|
||||||
maxTokens?: number
|
maxTokens?: number
|
||||||
}
|
}
|
||||||
@ -78,6 +79,7 @@ interface Model {
|
|||||||
modelId: string
|
modelId: string
|
||||||
name: string
|
name: string
|
||||||
provider: string
|
provider: string
|
||||||
|
outputType?: string
|
||||||
description?: string
|
description?: string
|
||||||
maxTokens?: number
|
maxTokens?: number
|
||||||
inputCostPer1k?: number
|
inputCostPer1k?: number
|
||||||
@ -158,10 +160,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
|||||||
const data = await response.json()
|
const data = await response.json()
|
||||||
setRun(data)
|
setRun(data)
|
||||||
|
|
||||||
// 如果是图片类型且有生成的文件路径,获取临时 URL
|
// 如果是图片类型,从输出中提取图像URL
|
||||||
if (data.outputType === 'image' && data.generatedFilePath) {
|
if (data.outputType === 'image' && data.output) {
|
||||||
fetchImageUrl(runId)
|
// 尝试从输出中提取图像URL
|
||||||
} else {
|
const urlMatch = data.output.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
|
||||||
|
if (urlMatch) {
|
||||||
|
setGeneratedImageUrl(urlMatch[0])
|
||||||
|
}
|
||||||
|
setStreamOutput(data.output || '')
|
||||||
|
} else if (data.outputType !== 'image') {
|
||||||
// 只有非图片类型才设置文本输出
|
// 只有非图片类型才设置文本输出
|
||||||
setStreamOutput(data.output || '')
|
setStreamOutput(data.output || '')
|
||||||
}
|
}
|
||||||
@ -173,7 +180,7 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
|||||||
} finally {
|
} finally {
|
||||||
setIsLoading(false)
|
setIsLoading(false)
|
||||||
}
|
}
|
||||||
}, [runId, router, fetchImageUrl])
|
}, [runId, router])
|
||||||
|
|
||||||
const fetchModels = useCallback(async () => {
|
const fetchModels = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
@ -315,6 +322,15 @@ export default function SimulatorRunPage({ params }: { params: Promise<{ id: str
|
|||||||
const content = parsed.choices[0].delta.content
|
const content = parsed.choices[0].delta.content
|
||||||
setStreamOutput(prev => prev + content)
|
setStreamOutput(prev => prev + content)
|
||||||
|
|
||||||
|
// 对于图像生成模型,尝试从内容中提取图像URL
|
||||||
|
if (run?.model?.outputType === 'image') {
|
||||||
|
// 匹配各种图像URL格式
|
||||||
|
const urlMatch = content.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
|
||||||
|
if (urlMatch) {
|
||||||
|
setGeneratedImageUrl(urlMatch[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Auto scroll to bottom
|
// Auto scroll to bottom
|
||||||
if (outputRef.current) {
|
if (outputRef.current) {
|
||||||
outputRef.current.scrollTop = outputRef.current.scrollHeight
|
outputRef.current.scrollTop = outputRef.current.scrollHeight
|
||||||
|
Loading…
Reference in New Issue
Block a user