add uniapi

songtianlun 2025-08-27 23:30:27 +08:00
parent 1b81639e92
commit 947b0d08d3
5 changed files with 371 additions and 44 deletions

View File

@@ -74,7 +74,7 @@ export default function AdminModelsPage() {
const [selectedPlan, setSelectedPlan] = useState<string>('')
const [selectedModels, setSelectedModels] = useState<string[]>([])
const [showAvailableModels, setShowAvailableModels] = useState(false)
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate'>('openrouter')
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter')
useEffect(() => {
loadInitialData()
@@ -309,7 +309,7 @@ export default function AdminModelsPage() {
<select
value={selectedServiceProvider}
onChange={(e) => {
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate')
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi')
setShowAvailableModels(false)
setSelectedModels([])
}}
@@ -317,6 +317,7 @@ export default function AdminModelsPage() {
>
<option value="openrouter">OpenRouter (Text Models)</option>
<option value="replicate">Replicate (Image/Video/Audio Models)</option>
<option value="uniapi">UniAPI (Multi-modal Models)</option>
</select>
</div>
)}
@@ -362,7 +363,7 @@ export default function AdminModelsPage() {
</Button>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3 max-h-96 overflow-y-auto">
<div className="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-4 max-h-96 overflow-y-auto">
{availableModels.map(model => {
const isSelected = selectedModels.includes(model.modelId)
const isAlreadyAdded = models.some(m =>
@@ -372,7 +373,7 @@ export default function AdminModelsPage() {
return (
<div
key={model.modelId}
className={`border rounded-lg p-3 cursor-pointer transition-colors ${
className={`border rounded-lg p-4 cursor-pointer transition-colors ${
isAlreadyAdded
? 'border-yellow-300 bg-yellow-50 dark:bg-yellow-900/20'
: isSelected
@@ -383,17 +384,19 @@ export default function AdminModelsPage() {
>
<div className="flex items-center justify-between">
<div className="flex-1 min-w-0">
<div className="flex items-center space-x-2 mb-1">
<h4 className="font-medium text-sm truncate">{model.name}</h4>
<Badge variant="secondary" className="text-xs">
{model.provider}
</Badge>
<Badge variant="outline" className="text-xs">
{model.outputType}
</Badge>
<Badge variant="outline" className="text-xs">
{model.serviceProvider}
</Badge>
<div className="mb-2">
<h4 className="font-medium text-base truncate mb-2">{model.name}</h4>
<div className="flex flex-wrap gap-1">
<Badge variant="secondary" className="text-xs">
{model.provider}
</Badge>
<Badge variant="outline" className="text-xs">
{model.outputType}
</Badge>
<Badge variant="outline" className="text-xs">
{model.serviceProvider}
</Badge>
</div>
</div>
<div className="flex items-center space-x-2 text-xs text-muted-foreground">
{model.maxTokens && (

View File

@@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from 'next/server'
import { prisma } from '@/lib/prisma'
import { OpenRouterService } from '@/lib/openrouter'
import { ReplicateService } from '@/lib/replicate'
import { UniAPIService } from '@/lib/uniapi'
// GET /api/admin/models - fetch all models (grouped by subscription plan)
export async function GET() {
@@ -95,6 +96,13 @@ export async function POST(request: NextRequest) {
...transformedVideoModels,
...transformedAudioModels
]
} else if (provider === 'uniapi') {
// Fetch the available models from UniAPI
const uniAPIService = new UniAPIService()
const uniAPIModels = await uniAPIService.getAvailableModels()
availableModels = uniAPIModels
.map(model => uniAPIService.transformModelForDB(model))
.filter(model => model !== null) // filter out any null entries
}
return NextResponse.json({
@@ -113,6 +121,40 @@ export async function POST(request: NextRequest) {
)
}
// Check model ID uniqueness across subscription plans
const modelIds = selectedModels.map(model => model.modelId)
const existingModels = await prisma.model.findMany({
where: {
modelId: {
in: modelIds
},
subscriptionPlanId: {
not: planId // exclude the current plan so models in the same plan can be updated
}
},
include: {
subscriptionPlan: {
select: {
displayName: true
}
}
}
})
if (existingModels.length > 0) {
const conflicts = existingModels.map(model => ({
modelId: model.modelId,
existingPlan: model.subscriptionPlan.displayName,
existingServiceProvider: model.serviceProvider
}))
return NextResponse.json({
error: 'Model ID conflicts detected',
conflicts: conflicts,
message: `The following model IDs are already in use by other subscription plans: ${conflicts.map(c => `${c.modelId} (used by ${c.existingServiceProvider} in ${c.existingPlan})`).join(', ')}`
}, { status: 409 })
}
// Batch-create the model records
const results = []
for (const modelData of selectedModels) {

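For reference, a minimal sketch (not part of this commit) of how a caller could handle the new 409 conflict response. Only the { error, conflicts, message } payload comes from the hunk above; the request body shape and the addModelsToPlan helper name are assumptions for illustration.

// Illustrative only: the request body shape is an assumption; the 409 payload mirrors the route above.
interface ModelConflict {
  modelId: string
  existingPlan: string
  existingServiceProvider: string
}

async function addModelsToPlan(planId: string, selectedModels: Array<{ modelId: string }>) {
  const res = await fetch('/api/admin/models', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ planId, selectedModels }),
  })

  if (res.status === 409) {
    // Duplicate model IDs: surface the conflicts instead of treating them as a hard failure.
    const body: { error: string; conflicts: ModelConflict[]; message: string } = await res.json()
    console.warn(body.message)
    return { ok: false as const, conflicts: body.conflicts }
  }

  if (!res.ok) {
    throw new Error(`Failed to add models: ${res.status}`)
  }

  return { ok: true as const, conflicts: [] as ModelConflict[] }
}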
View File

@@ -3,6 +3,7 @@ import { createServerSupabaseClient } from "@/lib/supabase-server";
import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
export async function POST(
request: NextRequest,
@@ -58,40 +59,115 @@ export async function POST(
const promptContent = getPromptContent(run);
const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
let apiResponse: Response;
// Pick the API based on the service provider
if (run.model.serviceProvider === 'openrouter') {
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
const openRouterResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST",
headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio",
},
body: JSON.stringify(requestBody),
});
} else if (run.model.serviceProvider === 'uniapi') {
const uniAPIService = new UniAPIService();
if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
// Use the chat completions API
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
};
if (!openRouterResponse.ok) {
const errorText = await openRouterResponse.text();
// Note: UniAPI may not support streaming responses, so the response is adapted here
const response = await uniAPIService.createChatCompletion(requestBody);
// Build a simulated streaming response
const mockStream = new ReadableStream({
start(controller) {
const content = response.choices?.[0]?.message?.content || '';
const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };
// Emit simulated streaming chunks
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: content }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: usage
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else {
// Return an error for non-text models
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported model type: ${run.model.outputType}`,
},
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else {
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `OpenRouter API error: ${openRouterResponse.status} - ${errorText}`,
error: `Unsupported service provider: ${run.model.serviceProvider}`,
},
});
return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
}
if (!apiResponse.ok) {
const errorText = await apiResponse.text();
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `API error: ${apiResponse.status} - ${errorText}`,
},
});
return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
@@ -100,7 +176,7 @@ export async function POST(
// Create the streaming response
const stream = new ReadableStream({
async start(controller) {
const reader = openRouterResponse.body?.getReader();
const reader = apiResponse.body?.getReader();
if (!reader) {
controller.close();
return;

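For reference, a minimal sketch (not part of this commit) of consuming the simulated data: frames emitted by the mock stream above. The readSimulatedStream helper is illustrative, not the repo's actual reader loop; the field names match only what the mock stream enqueues.

// Illustrative only: parse the SSE-style frames produced by the UniAPI mock stream.
async function readSimulatedStream(apiResponse: Response) {
  const reader = apiResponse.body?.getReader()
  if (!reader) return { content: '', usage: null as unknown }

  const decoder = new TextDecoder()
  let buffer = ''
  let content = ''
  let usage: unknown = null

  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    buffer += decoder.decode(value, { stream: true })

    // Frames are separated by a blank line; keep any trailing partial frame in the buffer.
    const frames = buffer.split('\n\n')
    buffer = frames.pop() ?? ''

    for (const frame of frames) {
      if (!frame.startsWith('data: ')) continue
      const payload = frame.slice('data: '.length)
      if (payload === '[DONE]') continue
      const parsed = JSON.parse(payload)
      if (parsed.choices?.[0]?.delta?.content) content += parsed.choices[0].delta.content
      if (parsed.usage) usage = parsed.usage
    }
  }

  return { content, usage }
}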
View File

@@ -38,6 +38,8 @@ export async function GET() {
modelId: model.modelId,
name: model.name,
provider: model.provider,
serviceProvider: model.serviceProvider,
outputType: model.outputType,
description: model.description,
maxTokens: model.maxTokens,
inputCostPer1k: model.inputCostPer1k,

src/lib/uniapi.ts (new file, 204 additions)
View File

@@ -0,0 +1,204 @@
interface UniAPIModel {
id: string
name: string
description?: string
provider?: string
context_length?: number
pricing?: {
input: number
output: number
}
features?: string[]
type?: 'text' | 'image' | 'audio' | 'video' | 'multimodal'
}
interface UniAPIResponse {
data: UniAPIModel[]
error?: string
}
export class UniAPIService {
private apiKey: string
private baseUrl = 'https://api.uniapi.io/v1'
constructor() {
this.apiKey = process.env.UNIAPI_API_KEY || ''
if (!this.apiKey) {
throw new Error('UNIAPI_API_KEY environment variable is required')
}
}
async getAvailableModels(): Promise<UniAPIModel[]> {
try {
const response = await fetch(`${this.baseUrl}/models`, {
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
})
if (!response.ok) {
const errorText = await response.text()
throw new Error(`UniAPI error: ${response.status} ${response.statusText} - ${errorText}`)
}
const data = await response.json()
// Handle the different response formats the API may return
let models: UniAPIModel[] = []
if (data.data && Array.isArray(data.data)) {
models = data.data
} else if (Array.isArray(data)) {
models = data
} else if (data.models && Array.isArray(data.models)) {
models = data.models
} else {
throw new Error('Unexpected response format from UniAPI')
}
if (data.error) {
throw new Error(`UniAPI API error: ${data.error}`)
}
return models
} catch (error) {
console.error('Error fetching UniAPI models:', error)
throw error
}
}
// Convert a UniAPI model into our database format
transformModelForDB(model: UniAPIModel) {
// Make sure the model has the required base fields
if (!model.id) {
return null
}
return {
modelId: model.id,
name: model.name || model.id || 'Unnamed Model', // fall back to the model ID if no name is provided
provider: model.provider || this.extractProvider(model.id),
serviceProvider: 'uniapi',
outputType: model.type || 'text',
description: model.description || null,
maxTokens: model.context_length || null,
inputCostPer1k: model.pricing?.input ? model.pricing.input * 1000 : null,
outputCostPer1k: model.pricing?.output ? model.pricing.output * 1000 : null,
supportedFeatures: {
type: model.type,
features: model.features || [],
},
metadata: {
original: model,
},
}
}
private extractProvider(modelId: string): string {
// Extract the provider name from the model ID
const providerMap: Record<string, string> = {
'openai': 'OpenAI',
'anthropic': 'Anthropic',
'claude': 'Anthropic',
'google': 'Google',
'gemini': 'Google',
'meta': 'Meta',
'llama': 'Meta',
'microsoft': 'Microsoft',
'mistral': 'Mistral AI',
'cohere': 'Cohere',
'stability': 'Stability AI',
'midjourney': 'Midjourney',
'dall-e': 'OpenAI',
'gpt': 'OpenAI',
}
const modelLower = modelId.toLowerCase()
// Try to match a known provider within the ID
for (const [key, value] of Object.entries(providerMap)) {
if (modelLower.includes(key)) {
return value
}
}
// If the ID contains a slash, use the segment before it
if (modelId.includes('/')) {
const provider = modelId.split('/')[0]
return providerMap[provider.toLowerCase()] ||
provider.charAt(0).toUpperCase() + provider.slice(1)
}
// If the ID contains a hyphen, use the first segment
if (modelId.includes('-')) {
const provider = modelId.split('-')[0]
return providerMap[provider.toLowerCase()] ||
provider.charAt(0).toUpperCase() + provider.slice(1)
}
return 'UniAPI'
}
// Execute a chat completion request
async createChatCompletion(params: {
model: string
messages: Array<{role: string, content: string}>
temperature?: number
max_tokens?: number
top_p?: number
frequency_penalty?: number
presence_penalty?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/chat/completions`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling UniAPI chat completion:', error)
throw error
}
}
// Execute an image generation request
async createImageGeneration(params: {
model: string
prompt: string
size?: string
quality?: string
n?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/images/generations`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling UniAPI image generation:', error)
throw error
}
}
}
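
For reference, a minimal usage sketch (not part of this commit) of the new UniAPIService, assuming UNIAPI_API_KEY is set in the environment; the model ID in the chat call is a hypothetical placeholder.

// Illustrative only: list models, transform them for storage, and run one chat completion.
import { UniAPIService } from '@/lib/uniapi'

async function demo() {
  // The constructor throws if UNIAPI_API_KEY is not set.
  const uniapi = new UniAPIService()

  const models = await uniapi.getAvailableModels()
  const dbReady = models
    .map(model => uniapi.transformModelForDB(model))
    .filter(model => model !== null)
  console.log(`Fetched ${dbReady.length} usable models from UniAPI`)

  const completion = await uniapi.createChatCompletion({
    model: 'gpt-4o-mini', // hypothetical ID; use one returned by getAvailableModels()
    messages: [{ role: 'user', content: 'Say hello in one sentence.' }],
    temperature: 0.7,
  })
  console.log(completion.choices?.[0]?.message?.content)
}

demo().catch(console.error)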