Prmbr/src/app/api/simulator/route.ts
2025-08-31 01:40:52 +08:00

260 lines
7.6 KiB
TypeScript

import { NextRequest, NextResponse } from "next/server";
import { auth } from "@/lib/auth";
import { headers } from "next/headers";
import { prisma } from "@/lib/prisma";
/**
 * Model adapter: translates between the simulator's inputs and one model
 * provider's API payloads.
 *
 * - `prepareRequest` builds the provider request body from the user's input,
 *   the prompt text and a bag of extra generation parameters.
 * - `parseResponse` extracts the displayable result (and, optionally, its
 *   output type) from the raw provider response.
 */
interface ModelAdapter {
  /** Model identifier handled by this adapter (matches its IMAGE_MODEL_ADAPTERS key). */
  id: string;
  /** Human-readable model name. */
  name: string;
  prepareRequest: (
    userInput: string,
    promptContent: string,
    params: Record<string, unknown>
  ) => Record<string, unknown>;
  parseResponse: (
    response: Record<string, unknown>
  ) => { content: string; outputType?: string };
}
// Image-generation model adapters, keyed by the model's provider id (`model.modelId`).
// POST only accepts generationMode "image" for models that have an entry here.

/** Matches the first http(s) URL ending in a common image extension (optionally with a query string). */
const IMAGE_URL_PATTERN = /https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i;

/** Wraps base64-encoded image bytes in a data URL so they can be rendered like a normal image URL. */
function toImageDataUrl(base64: string, format = 'png'): string {
  return `data:image/${format};base64,${base64}`;
}

/**
 * Returns the first non-empty `image_url.url` from a chat-completion `message.images`
 * array, or undefined when `images` is absent or not an array.
 * Some OpenAI-compatible gateways (e.g. OpenRouter) return generated images there
 * rather than inside `message.content`.
 */
function firstAttachedImageUrl(images: unknown): string | undefined {
  if (!Array.isArray(images)) return undefined;
  for (const image of images) {
    const url = (image as { image_url?: { url?: unknown } } | null)?.image_url?.url;
    if (typeof url === 'string' && url) return url;
  }
  return undefined;
}

const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
  'gpt-image-1': {
    id: 'gpt-image-1',
    name: 'GPT Image 1',
    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
      model: 'gpt-image-1',
      prompt: `${promptContent}\n\nUser input: ${userInput}`,
      size: '1024x1024',
      // gpt-image-1 accepts quality 'low' | 'medium' | 'high' | 'auto';
      // 'standard' is a DALL·E value and is not valid for this model.
      quality: 'auto',
      // NOTE(review): `params` is spread last, so it can override the fields above,
      // and every key it carries is forwarded verbatim to the image API — confirm
      // callers only pass image-API parameters here.
      ...params
    }),
    parseResponse: (response: Record<string, unknown>) => {
      const first = (response as { data?: Array<{ url?: string; b64_json?: string }> }).data?.[0];
      // gpt-image-1 returns base64 image data (`b64_json`) rather than a hosted URL.
      const format = typeof response.output_format === 'string' ? response.output_format : 'png';
      const content =
        first?.url ||
        (first?.b64_json ? toImageDataUrl(first.b64_json, format) : '') ||
        (response as { url?: string }).url ||
        'Image generated successfully';
      return { content, outputType: 'image' };
    }
  },
  'google/gemini-2.5-flash-image-preview': {
    id: 'google/gemini-2.5-flash-image-preview',
    name: 'Gemini 2.5 Flash Image Preview',
    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
      model: 'google/gemini-2.5-flash-image-preview',
      messages: [{
        role: 'user',
        content: `${promptContent}\n\nUser input: ${userInput}`
      }],
      // `??` (not `||`) so an explicit temperature of 0 is honoured.
      temperature: params.temperature ?? 0.7,
      // Optional sampling params are only sent when truthy, mapped to snake_case names.
      ...(params.maxTokens ? { max_tokens: params.maxTokens } : {}),
      ...(params.topP ? { top_p: params.topP } : {}),
      ...(params.frequencyPenalty ? { frequency_penalty: params.frequencyPenalty } : {}),
      ...(params.presencePenalty ? { presence_penalty: params.presencePenalty } : {})
    }),
    parseResponse: (response: Record<string, unknown>) => {
      const message = (response as { choices?: Array<{ message?: { content?: unknown; images?: unknown } }> })
        .choices?.[0]?.message;
      // Prefer an image attached to the message, then an image URL embedded in the
      // text, then the raw text itself. Non-string content (e.g. an array of parts)
      // is treated as empty instead of crashing on `.match`.
      const attachedImage = firstAttachedImageUrl(message?.images);
      const content = typeof message?.content === 'string' ? message.content : '';
      const embeddedUrl = content.match(IMAGE_URL_PATTERN)?.[0] || '';
      return {
        content: attachedImage || embeddedUrl || content || 'Image generation completed',
        outputType: 'image'
      };
    }
  }
};
/** Largest page size GET will serve; larger `limit` values are clamped to this. */
const MAX_PAGE_LIMIT = 100;

/**
 * Parses a base-10 integer query-string value that must be at least 1.
 *
 * @param raw - The raw value, or `null` when the parameter is absent.
 * @param fallback - Used when `raw` is absent, not a base-10 integer, or < 1.
 * @param max - Optional inclusive upper bound; larger values are clamped to it.
 * @returns The parsed value, `fallback`, or `max`, whichever applies.
 */
function parsePositiveInt(raw: string | null, fallback: number, max?: number): number {
  const parsed = raw === null ? Number.NaN : Number.parseInt(raw, 10);
  const value = Number.isFinite(parsed) && parsed >= 1 ? parsed : fallback;
  return max === undefined ? value : Math.min(value, max);
}

/**
 * Lists the signed-in user's simulator runs, newest first, with pagination.
 *
 * Query parameters:
 * - `page`   1-based page number (default 1; invalid values fall back to 1).
 * - `limit`  page size (default 20; invalid values fall back to 20; capped at MAX_PAGE_LIMIT).
 * - `status` optional exact-match filter on the run's `status` field.
 *
 * Responds 401 without a session and 500 on unexpected errors.
 */
export async function GET(request: NextRequest) {
  try {
    const session = await auth.api.getSession({
      headers: await headers()
    });
    if (!session?.user) {
      return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
    }
    const user = session.user;
    const { searchParams } = new URL(request.url);
    // Validate paging input: otherwise NaN, zero or negative values flow straight
    // into Prisma's skip/take, and limit=0 yields a non-finite totalPages.
    const page = parsePositiveInt(searchParams.get("page"), 1);
    const limit = parsePositiveInt(searchParams.get("limit"), 20, MAX_PAGE_LIMIT);
    const status = searchParams.get("status");
    const skip = (page - 1) * limit;
    const where = {
      userId: user.id,
      // An absent or empty status means "no status filter".
      ...(status && { status }),
    };
    const [runs, total] = await Promise.all([
      prisma.simulatorRun.findMany({
        where,
        select: {
          id: true,
          name: true,
          status: true,
          userInput: true,
          output: true,
          error: true,
          createdAt: true,
          completedAt: true,
          inputTokens: true,
          outputTokens: true,
          totalCost: true,
          duration: true,
          prompt: {
            select: { id: true, name: true }
          },
          model: {
            select: { id: true, name: true, provider: true }
          }
        },
        orderBy: { createdAt: "desc" },
        skip,
        take: limit,
      }),
      prisma.simulatorRun.count({ where }),
    ]);
    return NextResponse.json({
      runs,
      pagination: {
        page,
        limit,
        total,
        totalPages: Math.ceil(total / limit),
      },
    });
  } catch (error) {
    console.error("Error fetching simulator runs:", error);
    return NextResponse.json(
      { error: "Internal server error" },
      { status: 500 }
    );
  }
}
/**
 * Creates a pending simulator run for the signed-in user.
 *
 * The run is bound either to an existing prompt the user owns (`promptId`), or
 * to a new private prompt built from `newPromptContent` when `createNewPrompt`
 * is set. The model must exist, be active, and have an `outputType` matching
 * `generationMode` ("text" by default). Image mode also requires an entry in
 * IMAGE_MODEL_ADAPTERS.
 *
 * All validation happens before any write. The optional new prompt and the
 * run are then created in a single transaction, so a rejected or failed
 * request never leaves an orphan prompt behind.
 *
 * Responses: 201 with the created run (including prompt/model summaries);
 * 400 for a malformed body, missing prompt source, missing modelId, or an
 * unavailable or incompatible model; 401 without a session; 404 when
 * `promptId` is not owned by the user; 500 otherwise.
 */
export async function POST(request: NextRequest) {
  try {
    const session = await auth.api.getSession({
      headers: await headers()
    });
    if (!session?.user) {
      return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
    }
    const user = session.user;

    // request.json() rejects on malformed JSON; treat that as a client error, not a 500.
    const body = await request.json().catch(() => null);
    if (typeof body !== "object" || body === null) {
      return NextResponse.json({ error: "Invalid JSON body" }, { status: 400 });
    }
    const {
      name,
      promptId,
      promptVersionId,
      modelId,
      userInput,
      promptContent,
      temperature = 0.7,
      maxTokens,
      topP,
      frequencyPenalty,
      presencePenalty,
      generationMode = 'text', // requested output modality ("text" or "image")
      // Fields used to create a new prompt instead of referencing an existing one
      createNewPrompt,
      newPromptName,
      newPromptContent,
    } = body;

    // 1. Resolve the prompt source (read-only; nothing is written yet).
    const shouldCreatePrompt = Boolean(createNewPrompt && newPromptContent);
    if (!shouldCreatePrompt) {
      if (!promptId) {
        return NextResponse.json({ error: "Either promptId or newPromptContent is required" }, { status: 400 });
      }
      // The referenced prompt must belong to the caller.
      const prompt = await prisma.prompt.findFirst({
        where: {
          id: promptId,
          userId: user.id,
        },
        select: { id: true },
      });
      if (!prompt) {
        return NextResponse.json({ error: "Prompt not found" }, { status: 404 });
      }
    }

    // 2. Validate the model. Guard first: a unique lookup with an undefined id is
    //    rejected by Prisma and would otherwise surface as a 500.
    if (modelId === undefined || modelId === null || modelId === "") {
      return NextResponse.json({ error: "modelId is required" }, { status: 400 });
    }
    const model = await prisma.model.findUnique({
      where: { id: modelId },
      include: { subscriptionPlan: true }
    });
    if (!model || !model.isActive) {
      return NextResponse.json({ error: "Model not available" }, { status: 400 });
    }
    // The generation mode must match the model's output type.
    if (generationMode === 'text' && model.outputType !== 'text') {
      return NextResponse.json({ error: "Selected model is not compatible with text generation mode" }, { status: 400 });
    }
    if (generationMode === 'image') {
      if (model.outputType !== 'image') {
        return NextResponse.json({ error: "Selected model is not compatible with image generation mode" }, { status: 400 });
      }
      // Image mode additionally needs a request/response adapter for this model.
      if (!IMAGE_MODEL_ADAPTERS[model.modelId]) {
        return NextResponse.json({
          error: `Image model ${model.modelId} is not supported yet. Supported models: ${Object.keys(IMAGE_MODEL_ADAPTERS).join(', ')}`
        }, { status: 400 });
      }
    }

    // 3. Every check passed: create the optional prompt and the run atomically.
    const run = await prisma.$transaction(async (tx) => {
      let finalPromptId = promptId;
      if (shouldCreatePrompt) {
        const newPrompt = await tx.prompt.create({
          data: {
            userId: user.id,
            name: newPromptName || name || "New Prompt",
            content: newPromptContent,
            visibility: "private",
          },
        });
        finalPromptId = newPrompt.id;
      }
      return tx.simulatorRun.create({
        data: {
          userId: user.id,
          name: name || "Simulation Run",
          promptId: finalPromptId,
          promptVersionId,
          modelId,
          userInput,
          promptContent,
          temperature,
          maxTokens,
          topP,
          frequencyPenalty,
          presencePenalty,
          status: "pending",
        },
        include: {
          prompt: {
            select: { id: true, name: true }
          },
          model: {
            select: { id: true, name: true, provider: true, modelId: true }
          }
        }
      });
    });
    return NextResponse.json(run, { status: 201 });
  } catch (error) {
    console.error("Error creating simulator run:", error);
    return NextResponse.json(
      { error: "Internal server error" },
      { status: 500 }
    );
  }
}