import { NextRequest, NextResponse } from "next/server";
import { auth } from "@/lib/auth";
import { headers } from "next/headers";
import { prisma } from "@/lib/prisma";
// 模型适配器类型
|
|
type ModelAdapter = {
|
|
id: string;
|
|
name: string;
|
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
|
|
parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string };
|
|
};
|
|
|
|
// 图像生成模型适配器
|
|
const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
|
|
'gpt-image-1': {
|
|
id: 'gpt-image-1',
|
|
name: 'GPT Image 1',
|
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
|
model: 'gpt-image-1',
|
|
prompt: `${promptContent}\n\nUser input: ${userInput}`,
|
|
size: '1024x1024',
|
|
quality: 'standard',
|
|
...params
|
|
}),
|
|
parseResponse: (response: Record<string, unknown>) => ({
|
|
content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
|
|
outputType: 'image'
|
|
})
|
|
},
|
|
'google/gemini-2.5-flash-image-preview': {
|
|
id: 'google/gemini-2.5-flash-image-preview',
|
|
name: 'Gemini 2.5 Flash Image Preview',
|
|
prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
|
|
model: 'google/gemini-2.5-flash-image-preview',
|
|
messages: [{
|
|
role: 'user',
|
|
content: `${promptContent}\n\nUser input: ${userInput}`
|
|
}],
|
|
temperature: params.temperature || 0.7,
|
|
...(params.maxTokens ? { max_tokens: params.maxTokens } : {}),
|
|
...(params.topP ? { top_p: params.topP } : {}),
|
|
...(params.frequencyPenalty ? { frequency_penalty: params.frequencyPenalty } : {}),
|
|
...(params.presencePenalty ? { presence_penalty: params.presencePenalty } : {})
|
|
}),
|
|
parseResponse: (response: Record<string, unknown>) => {
|
|
// 尝试从不同的响应格式中提取内容
|
|
const choices = (response as { choices?: Array<{ message?: { content?: string } }> }).choices
|
|
const content = choices?.[0]?.message?.content || ''
|
|
|
|
// 从内容中提取图像URL
|
|
const urlMatch = content.match(/https?:\/\/[^\s<>"']*\.(?:png|jpg|jpeg|gif|webp|bmp|svg)(?:\?[^\s<>"']*)?/i)
|
|
const imageUrl = urlMatch?.[0] || ''
|
|
|
|
return {
|
|
content: imageUrl || content || 'Image generation completed',
|
|
outputType: 'image'
|
|
}
|
|
}
|
|
}
|
|
};
export async function GET(request: NextRequest) {
|
|
try {
|
|
const session = await auth.api.getSession({
|
|
headers: await headers()
|
|
});
|
|
|
|
if (!session?.user) {
|
|
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
|
|
}
|
|
|
|
const user = session.user;
|
|
|
|
const { searchParams } = new URL(request.url);
|
|
const page = parseInt(searchParams.get("page") || "1");
|
|
const limit = parseInt(searchParams.get("limit") || "20");
|
|
const status = searchParams.get("status");
|
|
|
|
const skip = (page - 1) * limit;
|
|
|
|
const where = {
|
|
userId: user.id,
|
|
...(status && { status }),
|
|
};
|
|
|
|
const [runs, total] = await Promise.all([
|
|
prisma.simulatorRun.findMany({
|
|
where,
|
|
select: {
|
|
id: true,
|
|
name: true,
|
|
status: true,
|
|
userInput: true,
|
|
output: true,
|
|
error: true,
|
|
createdAt: true,
|
|
completedAt: true,
|
|
inputTokens: true,
|
|
outputTokens: true,
|
|
totalCost: true,
|
|
duration: true,
|
|
prompt: {
|
|
select: { id: true, name: true }
|
|
},
|
|
model: {
|
|
select: { id: true, name: true, provider: true }
|
|
}
|
|
},
|
|
orderBy: { createdAt: "desc" },
|
|
skip,
|
|
take: limit,
|
|
}),
|
|
prisma.simulatorRun.count({ where }),
|
|
]);
|
|
|
|
return NextResponse.json({
|
|
runs,
|
|
pagination: {
|
|
page,
|
|
limit,
|
|
total,
|
|
totalPages: Math.ceil(total / limit),
|
|
},
|
|
});
|
|
} catch (error) {
|
|
console.error("Error fetching simulator runs:", error);
|
|
return NextResponse.json(
|
|
{ error: "Internal server error" },
|
|
{ status: 500 }
|
|
);
|
|
}
|
|
}
export async function POST(request: NextRequest) {
|
|
try {
|
|
const session = await auth.api.getSession({
|
|
headers: await headers()
|
|
});
|
|
|
|
if (!session?.user) {
|
|
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
|
|
}
|
|
|
|
const user = session.user;
|
|
|
|
const body = await request.json();
|
|
const {
|
|
name,
|
|
promptId,
|
|
promptVersionId,
|
|
modelId,
|
|
userInput,
|
|
promptContent,
|
|
temperature = 0.7,
|
|
maxTokens,
|
|
topP,
|
|
frequencyPenalty,
|
|
presencePenalty,
|
|
generationMode = 'text', // 新增生成模式字段
|
|
// 用于创建新prompt的字段
|
|
createNewPrompt,
|
|
newPromptName,
|
|
newPromptContent,
|
|
} = body;
|
|
|
|
let finalPromptId = promptId;
|
|
|
|
// 如果是创建新prompt模式
|
|
if (createNewPrompt && newPromptContent) {
|
|
// 创建新的prompt
|
|
const newPrompt = await prisma.prompt.create({
|
|
data: {
|
|
userId: user.id,
|
|
name: newPromptName || name || "New Prompt",
|
|
content: newPromptContent,
|
|
visibility: "private",
|
|
},
|
|
});
|
|
finalPromptId = newPrompt.id;
|
|
} else if (promptId) {
|
|
// 验证用户是否拥有该prompt
|
|
const prompt = await prisma.prompt.findFirst({
|
|
where: {
|
|
id: promptId,
|
|
userId: user.id,
|
|
},
|
|
});
|
|
|
|
if (!prompt) {
|
|
return NextResponse.json({ error: "Prompt not found" }, { status: 404 });
|
|
}
|
|
} else {
|
|
return NextResponse.json({ error: "Either promptId or newPromptContent is required" }, { status: 400 });
|
|
}
|
|
|
|
// 验证模型是否可用
|
|
const model = await prisma.model.findUnique({
|
|
where: { id: modelId },
|
|
include: { subscriptionPlan: true }
|
|
});
|
|
|
|
if (!model || !model.isActive) {
|
|
return NextResponse.json({ error: "Model not available" }, { status: 400 });
|
|
}
|
|
|
|
// 验证生成模式与模型的兼容性
|
|
if (generationMode === 'text' && model.outputType !== 'text') {
|
|
return NextResponse.json({ error: "Selected model is not compatible with text generation mode" }, { status: 400 });
|
|
}
|
|
|
|
if (generationMode === 'image') {
|
|
if (model.outputType !== 'image') {
|
|
return NextResponse.json({ error: "Selected model is not compatible with image generation mode" }, { status: 400 });
|
|
}
|
|
|
|
// 检查是否有对应的适配器
|
|
if (!IMAGE_MODEL_ADAPTERS[model.modelId]) {
|
|
return NextResponse.json({
|
|
error: `Image model ${model.modelId} is not supported yet. Supported models: ${Object.keys(IMAGE_MODEL_ADAPTERS).join(', ')}`
|
|
}, { status: 400 });
|
|
}
|
|
}
|
|
|
|
// 创建运行记录
|
|
const run = await prisma.simulatorRun.create({
|
|
data: {
|
|
userId: user.id,
|
|
name: name || "Simulation Run",
|
|
promptId: finalPromptId,
|
|
promptVersionId,
|
|
modelId,
|
|
userInput,
|
|
promptContent,
|
|
temperature,
|
|
maxTokens,
|
|
topP,
|
|
frequencyPenalty,
|
|
presencePenalty,
|
|
status: "pending",
|
|
},
|
|
include: {
|
|
prompt: {
|
|
select: { id: true, name: true }
|
|
},
|
|
model: {
|
|
select: { id: true, name: true, provider: true, modelId: true }
|
|
}
|
|
}
|
|
});
|
|
|
|
return NextResponse.json(run, { status: 201 });
|
|
} catch (error) {
|
|
console.error("Error creating simulator run:", error);
|
|
return NextResponse.json(
|
|
{ error: "Internal server error" },
|
|
{ status: 500 }
|
|
);
|
|
}
|
|
} |