Prmbr/src/app/api/simulator/[id]/execute/route.ts
2025-08-31 23:12:55 +08:00

663 lines
24 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { NextRequest, NextResponse } from "next/server";
import { auth } from "@/lib/auth";
import { headers } from "next/headers";
import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
import { FalService } from "@/lib/fal";
import { uploadBase64Image } from "@/lib/storage";
import type { Prisma } from "@prisma/client";
/**
 * Adapter interface for image-generation models: maps the simulator's generic
 * (userInput, promptContent, params) into a provider-specific request body,
 * and extracts displayable content / raw image data from the raw response.
 */
interface ModelAdapter {
  id: string;
  name: string;
  prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => Record<string, unknown>;
  parseResponse: (response: Record<string, unknown>) => { content: string; outputType?: string; imageData?: string };
}
// Registry of image-generation model adapters, keyed by model id.
const IMAGE_MODEL_ADAPTERS: Record<string, ModelAdapter> = {
  'gpt-image-1': {
    id: 'gpt-image-1',
    name: 'GPT Image 1',
    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => {
      // Only forward parameters the Images API understands. Callers pass
      // chat-sampling params (temperature, maxTokens, topP, penalties);
      // blindly spreading those into an image request would send invalid
      // fields, so pick out the recognized overrides instead.
      const imageOverrides: Record<string, unknown> = {};
      if (params.size != null) imageOverrides.size = params.size;
      if (params.quality != null) imageOverrides.quality = params.quality;
      if (params.n != null) imageOverrides.n = params.n;
      return {
        model: 'gpt-image-1',
        prompt: `${promptContent}\n\nUser input: ${userInput}`,
        size: '1024x1024',
        quality: 'standard',
        ...imageOverrides
      };
    },
    parseResponse: (response: Record<string, unknown>) => ({
      // Images API returns { data: [{ url }] }; fall back to a bare `url`
      // field, then to a generic success message.
      content: (response as { data?: { url: string }[]; url?: string }).data?.[0]?.url || (response as { url?: string }).url || 'Image generated successfully',
      outputType: 'image'
    })
  },
  'google/gemini-2.5-flash-image-preview': {
    id: 'google/gemini-2.5-flash-image-preview',
    name: 'Gemini 2.5 Flash Image Preview',
    prepareRequest: (userInput: string, promptContent: string, params: Record<string, unknown>) => ({
      model: 'google/gemini-2.5-flash-image-preview',
      messages: [{
        role: 'user',
        content: `${promptContent}\n\nUser input: ${userInput}`
      }],
      // `??` (not `||`) so an explicit temperature of 0 is respected.
      temperature: params.temperature ?? 0.7,
      // `!= null` (not truthiness) so valid 0 values are forwarded rather than dropped.
      ...(params.maxTokens != null ? { max_tokens: params.maxTokens } : {}),
      ...(params.topP != null ? { top_p: params.topP } : {}),
      ...(params.frequencyPenalty != null ? { frequency_penalty: params.frequencyPenalty } : {}),
      ...(params.presencePenalty != null ? { presence_penalty: params.presencePenalty } : {})
    }),
    parseResponse: (response: Record<string, unknown>) => {
      // Try to extract content from the different response shapes the
      // provider may return (text content and/or inline images).
      const choices = (response as { choices?: Array<{ message?: { content?: string; images?: Array<{ image_url?: { url?: string } }> } }> }).choices
      const choice = choices?.[0]
      const content = choice?.message?.content || ''
      // Extract base64 image data (data: URI) if present
      const images = choice?.message?.images
      let imageData = ''
      if (images && images.length > 0) {
        const imageUrl = images[0]?.image_url?.url
        if (imageUrl && imageUrl.startsWith('data:image/')) {
          imageData = imageUrl
        }
      }
      // Prefer the image data when present, otherwise fall back to text content
      return {
        content: imageData || content || 'No image data found in response',
        outputType: 'image',
        imageData: imageData
      }
    }
  }
};
// Define the type for the simulator run with included relations:
// the prompt (and the specific version used to resolve content), the model
// record, and the owning user with their subscription plan (used for the
// per-plan cost multiplier).
type SimulatorRunWithRelations = Prisma.SimulatorRunGetPayload<{
  include: {
    prompt: true;
    promptVersion: true;
    model: true;
    user: {
      include: {
        subscriptionPlan: true;
      };
    };
  };
}>;
/**
 * POST /api/simulator/[id]/execute
 *
 * Executes a pending simulator run: authenticates the caller, verifies run
 * ownership and credit balance, dispatches the prompt to the configured
 * service provider (OpenRouter / UniAPI / Fal.ai), streams the result back
 * to the client as Server-Sent Events, and persists token usage, cost and
 * final status on the run record.
 *
 * Responses:
 *  - 200 text/event-stream on success
 *  - 401 unauthenticated, 404 run not found / not owned by caller,
 *    400 already executed or unsupported provider/model type,
 *    402 credit balance below the rough cost estimate,
 *    500 upstream API failure or unexpected error.
 */
export async function POST(
  request: NextRequest,
  { params }: { params: Promise<{ id: string }> }
) {
  const { id } = await params;
  try {
    const session = await auth.api.getSession({ headers: await headers() });
    if (!session?.user) {
      return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
    }
    const user = session.user;
    // Load the run scoped to the authenticated user — the userId filter doubles
    // as the ownership check (foreign runs simply resolve to null → 404).
    const run: SimulatorRunWithRelations | null = await prisma.simulatorRun.findFirst({
      where: {
        id,
        userId: user.id,
      },
      include: {
        prompt: true,
        promptVersion: true, // internal use: needed to resolve the prompt content
        model: true,
        user: {
          include: {
            subscriptionPlan: true
          }
        }
      }
    });
    if (!run) {
      return NextResponse.json({ error: "Run not found" }, { status: 404 });
    }
    // Each run is single-use: anything other than "pending" has already executed.
    if (run.status !== "pending") {
      return NextResponse.json({ error: "Run already executed" }, { status: 400 });
    }
    // Check user's credit balance before execution
    const userBalance = await getUserBalance(user.id);
    // Subscription plans can scale cost; default multiplier is 1.0.
    const costMultiplier = (run.user.subscriptionPlan as { costMultiplier?: number })?.costMultiplier || 1.0;
    const estimatedCost = calculateCost(0, 100, run.model, costMultiplier); // Rough estimate
    if (userBalance < estimatedCost) {
      return NextResponse.json(
        { error: "Insufficient credit balance", requiredCredit: estimatedCost, currentBalance: userBalance },
        { status: 402 } // Payment Required
      );
    }
    // Mark the run as running before contacting the upstream provider
    await prisma.simulatorRun.update({
      where: { id },
      data: { status: "running" },
    });
    // Prepare the AI API request
    const promptContent = getPromptContent(run);
    const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;
    let apiResponse: Response;
    // Debug request/response payloads are only captured in development and
    // persisted onto the run record at completion.
    let debugRequest: Record<string, unknown> | null = null;
    let debugResponse: Record<string, unknown> | null = null;
    const isDevelopment = process.env.NODE_ENV === 'development';
    // Choose the upstream API based on service provider and model output type
    if (run.model.serviceProvider === 'openrouter') {
      // Check whether this is a supported image-generation model
      if (run.model.outputType === 'image' && IMAGE_MODEL_ADAPTERS[run.model.modelId]) {
        // Use the image-generation model adapter
        const adapter = IMAGE_MODEL_ADAPTERS[run.model.modelId];
        const requestBody = adapter.prepareRequest(
          run.userInput,
          promptContent,
          {
            temperature: run.temperature,
            maxTokens: run.maxTokens,
            topP: run.topP,
            frequencyPenalty: run.frequencyPenalty,
            presencePenalty: run.presencePenalty,
          }
        );
        // Capture debug request info (Authorization header intentionally omitted)
        if (isDevelopment) {
          debugRequest = {
            url: "https://openrouter.ai/api/v1/chat/completions",
            method: "POST",
            headers: {
              "Content-Type": "application/json",
              "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
              "X-Title": "Prmbr - AI Prompt Studio",
            },
            body: requestBody
          };
        }
        // For image generation, use a non-streaming request
        apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
          method: "POST",
          headers: {
            "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
            "Content-Type": "application/json",
            "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
            "X-Title": "Prmbr - AI Prompt Studio",
          },
          body: JSON.stringify(requestBody),
        });
        // For image generation we must handle the non-streaming response here
        if (apiResponse.ok) {
          const responseData = await apiResponse.json();
          // Capture the debug response
          if (isDevelopment) {
            debugResponse = {
              status: apiResponse.status,
              headers: Object.fromEntries(apiResponse.headers.entries()),
              body: responseData
            };
          }
          const parsedResult = adapter.parseResponse(responseData);
          // Wrap the one-shot result in a mock SSE stream so the client can
          // consume image and text runs through the same streaming protocol.
          const mockStream = new ReadableStream({
            start(controller) {
              // Emit the content as one simulated streaming chunk
              controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
                choices: [{
                  delta: { content: parsedResult.content }
                }]
              })}\n\n`));
              controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
                usage: { prompt_tokens: 50, completion_tokens: 100 } // estimated token usage
              })}\n\n`));
              // Fire-and-forget persistence of the image-generation result.
              // NOTE(review): the generic [DONE] handler further below also
              // consumes credit and updates this same run when it reads this
              // mock stream — that looks like a double charge plus a
              // conflicting status update racing this IIFE; confirm which
              // side should own completion for image runs.
              (async () => {
                try {
                  // Persist the image data and finish the run
                  const startTime = Date.now();
                  const duration = Date.now() - startTime; // NOTE(review): always ~0ms — startTime is taken here, not when the request began
                  const actualCost = calculateCost(50, 100, run.model, costMultiplier);
                  // Consume credits for this simulation
                  let creditTransaction;
                  try {
                    creditTransaction = await consumeCreditForSimulation(
                      user.id,
                      actualCost,
                      id,
                      `${run.model.name} image generation`
                    );
                  } catch (creditError) {
                    await prisma.simulatorRun.update({
                      where: { id },
                      data: {
                        status: "failed",
                        error: `Credit consumption failed: ${creditError}`,
                      },
                    });
                    return;
                  }
                  // Update run status and store the image data.
                  // If base64 image data is present, upload it to S3 first.
                  let generatedFilePath = null;
                  if (parsedResult.imageData && parsedResult.imageData.startsWith('data:image/')) {
                    try {
                      const fileName = `run-${id}-${Date.now()}.png`;
                      generatedFilePath = await uploadBase64Image(parsedResult.imageData, fileName);
                      console.log('Image uploaded to S3:', generatedFilePath);
                    } catch (error) {
                      // Best-effort upload: the run still completes without a stored file
                      console.error('Error uploading image to S3:', error);
                    }
                  }
                  await prisma.simulatorRun.update({
                    where: { id },
                    data: {
                      status: "completed",
                      output: parsedResult.content,
                      inputTokens: 50,
                      outputTokens: 100,
                      totalCost: actualCost,
                      duration,
                      creditId: creditTransaction.id,
                      completedAt: new Date(),
                      generatedFilePath,
                      ...(isDevelopment && debugRequest ? { debugRequest: debugRequest as Prisma.InputJsonValue } : {}),
                      ...(isDevelopment && debugResponse ? { debugResponse: debugResponse as Prisma.InputJsonValue } : {}),
                    },
                  });
                } catch (error) {
                  console.error("Error storing image data:", error);
                }
              })();
              controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
              controller.close();
            }
          });
          apiResponse = new Response(mockStream, {
            headers: {
              'Content-Type': 'text/event-stream',
            }
          });
        } else {
          const errorData = await apiResponse.text();
          // Capture the error response for debugging.
          // NOTE(review): text() consumes the body here, and the generic
          // !apiResponse.ok handler below calls text() on the same response
          // again — confirm that second read does not throw.
          if (isDevelopment) {
            debugResponse = {
              status: apiResponse.status,
              headers: Object.fromEntries(apiResponse.headers.entries()),
              body: { error: errorData }
            };
          }
        }
      } else {
        // Use the standard text chat-completions API (streaming)
        const requestBody = {
          model: run.model.modelId,
          messages: [
            {
              role: "user",
              content: finalPrompt,
            }
          ],
          // NOTE(review): `||` and the truthiness spreads below drop explicit
          // 0 values (temperature/top_p/penalties) — confirm intended.
          temperature: run.temperature || 0.7,
          ...(run.maxTokens && { max_tokens: run.maxTokens }),
          ...(run.topP && { top_p: run.topP }),
          ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
          ...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
          stream: true,
        };
        // Capture debug request info (Authorization header intentionally omitted)
        if (isDevelopment) {
          debugRequest = {
            url: "https://openrouter.ai/api/v1/chat/completions",
            method: "POST",
            headers: {
              "Content-Type": "application/json",
              "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
              "X-Title": "Prmbr - AI Prompt Studio",
            },
            body: requestBody
          };
        }
        apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
          method: "POST",
          headers: {
            "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
            "Content-Type": "application/json",
            "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
            "X-Title": "Prmbr - AI Prompt Studio",
          },
          body: JSON.stringify(requestBody),
        });
      }
    } else if (run.model.serviceProvider === 'uniapi') {
      const uniAPIService = new UniAPIService();
      if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
        // Use the chat-completions API
        const requestBody = {
          model: run.model.modelId,
          messages: [
            {
              role: "user",
              content: finalPrompt,
            }
          ],
          temperature: run.temperature || 0.7,
          ...(run.maxTokens && { max_tokens: run.maxTokens }),
          ...(run.topP && { top_p: run.topP }),
          ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
          ...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
        };
        // Note: UniAPI may not support streaming responses, so the one-shot
        // completion is re-wrapped as a mock SSE stream below.
        const response = await uniAPIService.createChatCompletion(requestBody);
        // Build a mock streaming response
        const mockStream = new ReadableStream({
          start(controller) {
            const content = response.choices?.[0]?.message?.content || '';
            const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };
            // Simulated streaming chunks: content, then usage, then [DONE]
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: content }
              }]
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: usage
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });
        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else {
        // Non-text model types are not supported for UniAPI — fail the run
        await prisma.simulatorRun.update({
          where: { id },
          data: {
            status: "failed",
            error: `Unsupported model type: ${run.model.outputType}`,
          },
        });
        return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
      }
    } else if (run.model.serviceProvider === 'fal') {
      const falService = new FalService();
      if (run.model.outputType === 'image') {
        // Use the image-generation API
        const response = await falService.generateImage({
          model: run.model.modelId,
          prompt: finalPrompt,
          num_images: 1,
        });
        // Build a mock streaming response.
        // NOTE(review): unlike the OpenRouter image path, the generated file
        // is not uploaded to storage and generatedFilePath is never set here
        // — confirm whether that is intentional.
        const mockStream = new ReadableStream({
          start(controller) {
            const imageUrl = response.images?.[0]?.url || '';
            const result = {
              images: response.images,
              prompt: finalPrompt,
              model: run.model.modelId
            };
            // Simulated streaming chunks: content, then usage, then [DONE]
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
              }]
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: { prompt_tokens: 50, completion_tokens: 100 } // estimated token usage
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });
        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else if (run.model.outputType === 'video') {
        // Use the video-generation API
        const response = await falService.generateVideo({
          model: run.model.modelId,
          prompt: finalPrompt,
        });
        // Build a mock streaming response
        const mockStream = new ReadableStream({
          start(controller) {
            const videoUrl = response.video?.url || '';
            const result = {
              video: response.video,
              prompt: finalPrompt,
              model: run.model.modelId
            };
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
              }]
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: { prompt_tokens: 50, completion_tokens: 150 } // estimated token usage
            })}\n\n`));
            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });
        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else {
        // Other Fal.ai output types are not supported — fail the run
        await prisma.simulatorRun.update({
          where: { id },
          data: {
            status: "failed",
            error: `Unsupported Fal.ai model type: ${run.model.outputType}`,
          },
        });
        return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
      }
    } else {
      // Unknown provider — fail the run and report to the client
      await prisma.simulatorRun.update({
        where: { id },
        data: {
          status: "failed",
          error: `Unsupported service provider: ${run.model.serviceProvider}`,
        },
      });
      return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
    }
    // Generic upstream-failure handling (mock streams built above are always ok)
    if (!apiResponse.ok) {
      const errorText = await apiResponse.text();
      await prisma.simulatorRun.update({
        where: { id },
        data: {
          status: "failed",
          error: `API error: ${apiResponse.status} - ${errorText}`,
        },
      });
      return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
    }
    // Build the client-facing SSE stream: proxy each upstream `data:` line,
    // accumulate the full text and usage, and on [DONE] consume credit and
    // persist completion to the run record.
    const stream = new ReadableStream({
      async start(controller) {
        const reader = apiResponse.body?.getReader();
        if (!reader) {
          controller.close();
          return;
        }
        let fullResponse = "";
        let inputTokens = 0;
        let outputTokens = 0;
        const startTime = Date.now();
        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) break;
            // NOTE(review): chunks are split on '\n' without buffering — an
            // SSE line split across two chunks would be dropped by the JSON
            // parse below; confirm acceptable for these providers.
            const chunk = new TextDecoder().decode(value);
            const lines = chunk.split('\n');
            for (const line of lines) {
              if (line.startsWith('data: ')) {
                const data = line.slice(6);
                if (data === '[DONE]') {
                  // Compute final metrics and persist them
                  const duration = Date.now() - startTime;
                  const actualCost = calculateCost(inputTokens, outputTokens, run.model, costMultiplier);
                  // Consume credits for this simulation
                  let creditTransaction;
                  try {
                    creditTransaction = await consumeCreditForSimulation(
                      user.id,
                      actualCost,
                      id,
                      `${run.model.name} simulation: ${inputTokens} input + ${outputTokens} output tokens`
                    );
                  } catch (creditError) {
                    // If credit consumption fails, mark the run as failed
                    await prisma.simulatorRun.update({
                      where: { id },
                      data: {
                        status: "failed",
                        error: `Credit consumption failed: ${creditError}`,
                      },
                    });
                    controller.enqueue(new TextEncoder().encode(`data: {"error": "Credit consumption failed"}\n\n`));
                    controller.close();
                    return;
                  }
                  // Update the run with completion data and credit reference
                  await prisma.simulatorRun.update({
                    where: { id },
                    data: {
                      status: "completed",
                      output: fullResponse,
                      inputTokens,
                      outputTokens,
                      totalCost: actualCost,
                      duration,
                      creditId: creditTransaction.id,
                      completedAt: new Date(),
                      ...(isDevelopment && debugRequest ? { debugRequest: debugRequest as Prisma.InputJsonValue } : {}),
                      ...(isDevelopment && debugResponse ? { debugResponse: debugResponse as Prisma.InputJsonValue } : {}),
                    },
                  });
                  controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
                  controller.close();
                  return;
                }
                try {
                  const parsed = JSON.parse(data);
                  if (parsed.choices?.[0]?.delta?.content) {
                    const content = parsed.choices[0].delta.content;
                    fullResponse += content;
                  }
                  // Token usage (simplified: trust the provider's usage payload)
                  if (parsed.usage) {
                    inputTokens = parsed.usage.prompt_tokens || 0;
                    outputTokens = parsed.usage.completion_tokens || 0;
                  }
                } catch {
                  // Ignore parse errors and keep processing remaining data
                }
                // Forward the raw data line to the client unchanged
                controller.enqueue(new TextEncoder().encode(`data: ${data}\n\n`));
              }
            }
          }
        } catch (error) {
          console.error("Stream processing error:", error);
          await prisma.simulatorRun.update({
            where: { id },
            data: {
              status: "failed",
              error: `Stream processing error: ${error}`,
            },
          });
          controller.close();
        }
      },
    });
    return new NextResponse(stream, {
      headers: {
        'Content-Type': 'text/event-stream',
        'Cache-Control': 'no-cache',
        'Connection': 'keep-alive',
      },
    });
  } catch (error) {
    console.error("Error executing simulator run:", error);
    // Mark the run as failed; best-effort since we may fail before/after updates
    await prisma.simulatorRun.update({
      where: { id },
      data: {
        status: "failed",
        error: `Execution error: ${error}`,
      },
    }).catch(() => {}); // ignore update failure
    return NextResponse.json(
      { error: "Internal server error" },
      { status: 500 }
    );
  }
}