Prmbr/src/app/api/simulator/[id]/execute/route.ts
2025-08-27 23:37:14 +08:00

398 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { NextRequest, NextResponse } from "next/server";
import { createServerSupabaseClient } from "@/lib/supabase-server";
import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
import { FalService } from "@/lib/fal";
export async function POST(
request: NextRequest,
{ params }: { params: Promise<{ id: string }> }
) {
const { id } = await params;
try {
const supabase = await createServerSupabaseClient();
const { data: { user }, error: authError } = await supabase.auth.getUser();
if (authError || !user) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}
const run = await prisma.simulatorRun.findFirst({
where: {
id,
userId: user.id,
},
include: {
prompt: true,
promptVersion: true, // 内部使用,用于获取内容
model: true,
}
});
if (!run) {
return NextResponse.json({ error: "Run not found" }, { status: 404 });
}
if (run.status !== "pending") {
return NextResponse.json({ error: "Run already executed" }, { status: 400 });
}
// Check user's credit balance before execution
const userBalance = await getUserBalance(user.id);
const estimatedCost = calculateCost(0, 100, run.model); // Rough estimate
if (userBalance < estimatedCost) {
return NextResponse.json(
{ error: "Insufficient credit balance", requiredCredit: estimatedCost, currentBalance: userBalance },
{ status: 402 } // Payment Required
);
}
// 更新状态为运行中
await prisma.simulatorRun.update({
where: { id },
data: { status: "running" },
});
// 准备AI API请求
const promptContent = getPromptContent(run);
const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;
let apiResponse: Response;
// 根据服务提供商选择不同的API
if (run.model.serviceProvider === 'openrouter') {
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
} else if (run.model.serviceProvider === 'uniapi') {
const uniAPIService = new UniAPIService();
if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
// 使用聊天完成API
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
};
// 注意UniAPI 可能不支持流式响应,这里需要调整
const response = await uniAPIService.createChatCompletion(requestBody);
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const content = response.choices?.[0]?.message?.content || '';
const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };
// 模拟流式数据
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: content }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: usage
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else {
// 对于非文本模型,返回错误
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported model type: ${run.model.outputType}`,
},
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else if (run.model.serviceProvider === 'fal') {
const falService = new FalService();
if (run.model.outputType === 'image') {
// 使用图像生成API
const response = await falService.generateImage({
model: run.model.modelId,
prompt: finalPrompt,
num_images: 1,
});
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const imageUrl = response.images?.[0]?.url || '';
const result = {
images: response.images,
prompt: finalPrompt,
model: run.model.modelId
};
// 模拟流式数据
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else if (run.model.outputType === 'video') {
// 使用视频生成API
const response = await falService.generateVideo({
model: run.model.modelId,
prompt: finalPrompt,
});
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const videoUrl = response.video?.url || '';
const result = {
video: response.video,
prompt: finalPrompt,
model: run.model.modelId
};
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: { prompt_tokens: 50, completion_tokens: 150 } // 估算的token使用量
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else {
// 对于其他类型,返回错误
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported Fal.ai model type: ${run.model.outputType}`,
},
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else {
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported service provider: ${run.model.serviceProvider}`,
},
});
return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
}
if (!apiResponse.ok) {
const errorText = await apiResponse.text();
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `API error: ${apiResponse.status} - ${errorText}`,
},
});
return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
}
// 创建流式响应
const stream = new ReadableStream({
async start(controller) {
const reader = apiResponse.body?.getReader();
if (!reader) {
controller.close();
return;
}
let fullResponse = "";
let inputTokens = 0;
let outputTokens = 0;
const startTime = Date.now();
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = new TextDecoder().decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') {
// 计算最终数据并更新数据库
const duration = Date.now() - startTime;
const actualCost = calculateCost(inputTokens, outputTokens, run.model);
// Consume credits for this simulation
let creditTransaction;
try {
creditTransaction = await consumeCreditForSimulation(
user.id,
actualCost,
id,
`${run.model.name} simulation: ${inputTokens} input + ${outputTokens} output tokens`
);
} catch (creditError) {
// If credit consumption fails, mark the run as failed
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Credit consumption failed: ${creditError}`,
},
});
controller.enqueue(new TextEncoder().encode(`data: {"error": "Credit consumption failed"}\n\n`));
controller.close();
return;
}
// Update the run with completion data and credit reference
await prisma.simulatorRun.update({
where: { id },
data: {
status: "completed",
output: fullResponse,
inputTokens,
outputTokens,
totalCost: actualCost,
duration,
creditId: creditTransaction.id,
completedAt: new Date(),
},
});
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
return;
}
try {
const parsed = JSON.parse(data);
if (parsed.choices?.[0]?.delta?.content) {
const content = parsed.choices[0].delta.content;
fullResponse += content;
}
// 估算token使用量简化版本
if (parsed.usage) {
inputTokens = parsed.usage.prompt_tokens || 0;
outputTokens = parsed.usage.completion_tokens || 0;
}
} catch {
// 忽略解析错误,继续处理其他数据
}
controller.enqueue(new TextEncoder().encode(`data: ${data}\n\n`));
}
}
}
} catch (error) {
console.error("Stream processing error:", error);
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Stream processing error: ${error}`,
},
});
controller.close();
}
},
});
return new NextResponse(stream, {
headers: {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
},
});
} catch (error) {
console.error("Error executing simulator run:", error);
// 更新运行状态为失败
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Execution error: ${error}`,
},
}).catch(() => {}); // 忽略更新失败
return NextResponse.json(
{ error: "Internal server error" },
{ status: 500 }
);
}
}