import { NextRequest, NextResponse } from "next/server";
import { createServerSupabaseClient } from "@/lib/supabase-server";
import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
import { FalService } from "@/lib/fal";

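// Executes a pending SimulatorRun for the authenticated user: checks credit,
// dispatches the prompt to the run's configured provider (OpenRouter, UniAPI,
// or Fal.ai), relays the result to the client as a Server-Sent Events stream,
// and settles tokens, cost, and status on the run when the stream finishes.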
export async function POST(
  request: NextRequest,
  { params }: { params: Promise<{ id: string }> }
) {
  const { id } = await params;
  try {
    const supabase = await createServerSupabaseClient();
    const { data: { user }, error: authError } = await supabase.auth.getUser();

    if (authError || !user) {
      return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
    }

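    // Scope the lookup to the authenticated user so one user can never
    // execute another user's run.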
    const run = await prisma.simulatorRun.findFirst({
      where: {
        id,
        userId: user.id,
      },
      include: {
        prompt: true,
        promptVersion: true, // internal use: resolves the prompt content below
        model: true,
      }
    });

    if (!run) {
      return NextResponse.json({ error: "Run not found" }, { status: 404 });
    }

    if (run.status !== "pending") {
      return NextResponse.json({ error: "Run already executed" }, { status: 400 });
    }

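    // Check the user's credit balance before execution. This is only a
    // pre-flight estimate; the real cost is recalculated from actual token
    // usage and settled after the stream completes.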
    const userBalance = await getUserBalance(user.id);
    const estimatedCost = calculateCost(0, 100, run.model); // rough estimate

    if (userBalance < estimatedCost) {
      return NextResponse.json(
        { error: "Insufficient credit balance", requiredCredit: estimatedCost, currentBalance: userBalance },
        { status: 402 } // Payment Required
      );
    }

    // Mark the run as in progress
    await prisma.simulatorRun.update({
      where: { id },
      data: { status: "running" },
    });

    // Prepare the AI API request
    const promptContent = getPromptContent(run);
    const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;

    let apiResponse: Response;

    // Call a different API depending on the service provider.
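    // Whichever branch runs, apiResponse must end up holding a Response whose
    // body is an OpenAI-style SSE stream, e.g.
    //   data: {"choices":[{"delta":{"content":"..."}}]}
    //   data: {"usage":{"prompt_tokens":12,"completion_tokens":34}}
    //   data: [DONE]
    // so the relay loop below can treat every provider uniformly.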
    if (run.model.serviceProvider === 'openrouter') {
      const requestBody = {
        model: run.model.modelId,
        messages: [
          {
            role: "user",
            content: finalPrompt,
          }
        ],
        // ?? rather than || so an explicit temperature of 0 is honored
        temperature: run.temperature ?? 0.7,
        // null checks rather than truthiness so explicit zeros are still sent
        ...(run.maxTokens != null && { max_tokens: run.maxTokens }),
        ...(run.topP != null && { top_p: run.topP }),
        ...(run.frequencyPenalty != null && { frequency_penalty: run.frequencyPenalty }),
        ...(run.presencePenalty != null && { presence_penalty: run.presencePenalty }),
        stream: true,
      };

      apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
          "Content-Type": "application/json",
          // OpenRouter's optional app-attribution headers
          "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
          "X-Title": "Prmbr - AI Prompt Studio",
        },
        body: JSON.stringify(requestBody),
      });
    } else if (run.model.serviceProvider === 'uniapi') {
      const uniAPIService = new UniAPIService();

      if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
        // Use the chat completions API
        const requestBody = {
          model: run.model.modelId,
          messages: [
            {
              role: "user",
              content: finalPrompt,
            }
          ],
          temperature: run.temperature ?? 0.7,
          ...(run.maxTokens != null && { max_tokens: run.maxTokens }),
          ...(run.topP != null && { top_p: run.topP }),
          ...(run.frequencyPenalty != null && { frequency_penalty: run.frequencyPenalty }),
          ...(run.presencePenalty != null && { presence_penalty: run.presencePenalty }),
        };

        // Note: UniAPI may not support streaming responses, so the completion
        // is fetched in full here
        const response = await uniAPIService.createChatCompletion(requestBody);

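        // UniAPI returned the whole completion at once; re-emit it as a mock
        // SSE stream in the same shape a real streaming provider would
        // produce, so the relay loop below needs no provider-specific path.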
        const mockStream = new ReadableStream({
          start(controller) {
            const content = response.choices?.[0]?.message?.content || '';
            const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };

            // Emit the whole completion as one SSE data chunk
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: content }
              }]
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: usage
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });

        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else {
        // Non-text models are not supported on this provider
        await prisma.simulatorRun.update({
          where: { id },
          data: {
            status: "failed",
            error: `Unsupported model type: ${run.model.outputType}`,
          },
        });
        return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
      }
    } else if (run.model.serviceProvider === 'fal') {
      const falService = new FalService();

      if (run.model.outputType === 'image') {
        // Use the image generation API
        const response = await falService.generateImage({
          model: run.model.modelId,
          prompt: finalPrompt,
          num_images: 1,
        });

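        // Fal.ai returns the finished image rather than a stream; wrap the
        // result in a mock SSE stream so it flows through the same relay
        // path as text completions.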
        const mockStream = new ReadableStream({
          start(controller) {
            const imageUrl = response.images?.[0]?.url || '';
            const result = {
              images: response.images,
              prompt: finalPrompt,
              model: run.model.modelId
            };

            // Emit the result as a single SSE data chunk
            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
              }]
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: { prompt_tokens: 50, completion_tokens: 100 } // estimated token usage
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });

        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else if (run.model.outputType === 'video') {
        // Use the video generation API
        const response = await falService.generateVideo({
          model: run.model.modelId,
          prompt: finalPrompt,
        });

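        // Same approach as images: wrap the finished video result in a mock
        // SSE stream.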
        const mockStream = new ReadableStream({
          start(controller) {
            const videoUrl = response.video?.url || '';
            const result = {
              video: response.video,
              prompt: finalPrompt,
              model: run.model.modelId
            };

            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              choices: [{
                delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
              }]
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
              usage: { prompt_tokens: 50, completion_tokens: 150 } // estimated token usage
            })}\n\n`));

            controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
            controller.close();
          }
        });

        apiResponse = new Response(mockStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          }
        });
      } else {
        // Other output types are not supported
        await prisma.simulatorRun.update({
          where: { id },
          data: {
            status: "failed",
            error: `Unsupported Fal.ai model type: ${run.model.outputType}`,
          },
        });
        return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
      }
    } else {
      await prisma.simulatorRun.update({
        where: { id },
        data: {
          status: "failed",
          error: `Unsupported service provider: ${run.model.serviceProvider}`,
        },
      });
      return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
    }

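    // Surfaces HTTP-level failures from a real provider before streaming
    // begins; the mock Response objects built above always report ok.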
    if (!apiResponse.ok) {
      const errorText = await apiResponse.text();
      await prisma.simulatorRun.update({
        where: { id },
        data: {
          status: "failed",
          error: `API error: ${apiResponse.status} - ${errorText}`,
        },
      });
      return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
    }

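    // Relay loop: forward each SSE chunk to the client while accumulating the
    // full response text and any reported token usage, so the run can be
    // settled (cost, credits, status) when the [DONE] marker arrives.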
    const stream = new ReadableStream({
      async start(controller) {
        const reader = apiResponse.body?.getReader();
        if (!reader) {
          controller.close();
          return;
        }

        let fullResponse = "";
        let inputTokens = 0;
        let outputTokens = 0;
        const startTime = Date.now();
        // Stream-aware decoding plus a line buffer, so multi-byte characters
        // and SSE events split across network chunks are reassembled correctly
        const decoder = new TextDecoder();
        let buffer = "";

        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) break;

            buffer += decoder.decode(value, { stream: true });
            const lines = buffer.split('\n');
            buffer = lines.pop() ?? ""; // keep the trailing partial line for the next chunk

            for (const line of lines) {
              if (line.startsWith('data: ')) {
                const data = line.slice(6);
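                // End of stream: settle the run. Credits are consumed before
                // the run is marked completed, so a completed run always has
                // a matching credit transaction.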
                if (data === '[DONE]') {
                  // Compute the final figures and persist them
                  const duration = Date.now() - startTime;
                  const actualCost = calculateCost(inputTokens, outputTokens, run.model);

                  // Consume credits for this simulation
                  let creditTransaction;
                  try {
                    creditTransaction = await consumeCreditForSimulation(
                      user.id,
                      actualCost,
                      id,
                      `${run.model.name} simulation: ${inputTokens} input + ${outputTokens} output tokens`
                    );
                  } catch (creditError) {
                    // If credit consumption fails, mark the run as failed
                    await prisma.simulatorRun.update({
                      where: { id },
                      data: {
                        status: "failed",
                        error: `Credit consumption failed: ${creditError}`,
                      },
                    });
                    controller.enqueue(new TextEncoder().encode(`data: {"error": "Credit consumption failed"}\n\n`));
                    controller.close();
                    return;
                  }

                  // Update the run with completion data and the credit reference
                  await prisma.simulatorRun.update({
                    where: { id },
                    data: {
                      status: "completed",
                      output: fullResponse,
                      inputTokens,
                      outputTokens,
                      totalCost: actualCost,
                      duration,
                      creditId: creditTransaction.id,
                      completedAt: new Date(),
                    },
                  });

                  controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
                  controller.close();
                  return;
                }

                try {
                  const parsed = JSON.parse(data);
                  if (parsed.choices?.[0]?.delta?.content) {
                    const content = parsed.choices[0].delta.content;
                    fullResponse += content;
                  }

                  // Record token usage when the provider reports it
                  if (parsed.usage) {
                    inputTokens = parsed.usage.prompt_tokens || 0;
                    outputTokens = parsed.usage.completion_tokens || 0;
                  }
                } catch {
                  // Ignore parse errors and keep processing subsequent chunks
                }

                controller.enqueue(new TextEncoder().encode(`data: ${data}\n\n`));
              }
            }
          }

          // The provider stream ended without a [DONE] marker; close the
          // client stream anyway (the run stays "running" in that case)
          controller.close();
        } catch (error) {
          console.error("Stream processing error:", error);
          await prisma.simulatorRun.update({
            where: { id },
            data: {
              status: "failed",
              error: `Stream processing error: ${error}`,
            },
          });
          controller.close();
        }
      },
    });

    return new NextResponse(stream, {
      headers: {
        'Content-Type': 'text/event-stream',
        'Cache-Control': 'no-cache',
        'Connection': 'keep-alive',
      },
    });
  } catch (error) {
    console.error("Error executing simulator run:", error);

    // Mark the run as failed
    await prisma.simulatorRun.update({
      where: { id },
      data: {
        status: "failed",
        error: `Execution error: ${error}`,
      },
    }).catch(() => {}); // ignore failures while recording the failure itself

    return NextResponse.json(
      { error: "Internal server error" },
      { status: 500 }
    );
  }
}
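
// Client usage sketch. The route path below is hypothetical (it depends on
// where this file sits in the app router); the response is consumed as SSE:
//
//   const res = await fetch(`/api/simulator/runs/${runId}/execute`, { method: "POST" });
//   const reader = res.body!.getReader();
//   const decoder = new TextDecoder();
//   while (true) {
//     const { done, value } = await reader.read();
//     if (done) break;
//     // each "data: ..." line is a JSON delta chunk, until "data: [DONE]"
//     console.log(decoder.decode(value, { stream: true }));
//   }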