add uniapi

This commit is contained in:
songtianlun 2025-08-27 23:30:27 +08:00
parent 1b81639e92
commit 947b0d08d3
5 changed files with 371 additions and 44 deletions

View File

@ -74,7 +74,7 @@ export default function AdminModelsPage() {
const [selectedPlan, setSelectedPlan] = useState<string>('') const [selectedPlan, setSelectedPlan] = useState<string>('')
const [selectedModels, setSelectedModels] = useState<string[]>([]) const [selectedModels, setSelectedModels] = useState<string[]>([])
const [showAvailableModels, setShowAvailableModels] = useState(false) const [showAvailableModels, setShowAvailableModels] = useState(false)
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate'>('openrouter') const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter')
useEffect(() => { useEffect(() => {
loadInitialData() loadInitialData()
@ -309,7 +309,7 @@ export default function AdminModelsPage() {
<select <select
value={selectedServiceProvider} value={selectedServiceProvider}
onChange={(e) => { onChange={(e) => {
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate') setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi')
setShowAvailableModels(false) setShowAvailableModels(false)
setSelectedModels([]) setSelectedModels([])
}} }}
@ -317,6 +317,7 @@ export default function AdminModelsPage() {
> >
<option value="openrouter">OpenRouter (Text Models)</option> <option value="openrouter">OpenRouter (Text Models)</option>
<option value="replicate">Replicate (Image/Video/Audio Models)</option> <option value="replicate">Replicate (Image/Video/Audio Models)</option>
<option value="uniapi">UniAPI (Multi-modal Models)</option>
</select> </select>
</div> </div>
)} )}
@ -362,7 +363,7 @@ export default function AdminModelsPage() {
</Button> </Button>
</div> </div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3 max-h-96 overflow-y-auto"> <div className="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-4 max-h-96 overflow-y-auto">
{availableModels.map(model => { {availableModels.map(model => {
const isSelected = selectedModels.includes(model.modelId) const isSelected = selectedModels.includes(model.modelId)
const isAlreadyAdded = models.some(m => const isAlreadyAdded = models.some(m =>
@ -372,7 +373,7 @@ export default function AdminModelsPage() {
return ( return (
<div <div
key={model.modelId} key={model.modelId}
className={`border rounded-lg p-3 cursor-pointer transition-colors ${ className={`border rounded-lg p-4 cursor-pointer transition-colors ${
isAlreadyAdded isAlreadyAdded
? 'border-yellow-300 bg-yellow-50 dark:bg-yellow-900/20' ? 'border-yellow-300 bg-yellow-50 dark:bg-yellow-900/20'
: isSelected : isSelected
@ -383,17 +384,19 @@ export default function AdminModelsPage() {
> >
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
<div className="flex-1 min-w-0"> <div className="flex-1 min-w-0">
<div className="flex items-center space-x-2 mb-1"> <div className="mb-2">
<h4 className="font-medium text-sm truncate">{model.name}</h4> <h4 className="font-medium text-base truncate mb-2">{model.name}</h4>
<Badge variant="secondary" className="text-xs"> <div className="flex flex-wrap gap-1">
{model.provider} <Badge variant="secondary" className="text-xs">
</Badge> {model.provider}
<Badge variant="outline" className="text-xs"> </Badge>
{model.outputType} <Badge variant="outline" className="text-xs">
</Badge> {model.outputType}
<Badge variant="outline" className="text-xs"> </Badge>
{model.serviceProvider} <Badge variant="outline" className="text-xs">
</Badge> {model.serviceProvider}
</Badge>
</div>
</div> </div>
<div className="flex items-center space-x-2 text-xs text-muted-foreground"> <div className="flex items-center space-x-2 text-xs text-muted-foreground">
{model.maxTokens && ( {model.maxTokens && (

View File

@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from 'next/server'
import { prisma } from '@/lib/prisma' import { prisma } from '@/lib/prisma'
import { OpenRouterService } from '@/lib/openrouter' import { OpenRouterService } from '@/lib/openrouter'
import { ReplicateService } from '@/lib/replicate' import { ReplicateService } from '@/lib/replicate'
import { UniAPIService } from '@/lib/uniapi'
// GET /api/admin/models - 获取所有模型(按套餐分组) // GET /api/admin/models - 获取所有模型(按套餐分组)
export async function GET() { export async function GET() {
@ -95,6 +96,13 @@ export async function POST(request: NextRequest) {
...transformedVideoModels, ...transformedVideoModels,
...transformedAudioModels ...transformedAudioModels
] ]
} else if (provider === 'uniapi') {
// 从 UniAPI 获取可用模型
const uniAPIService = new UniAPIService()
const uniAPIModels = await uniAPIService.getAvailableModels()
availableModels = uniAPIModels
.map(model => uniAPIService.transformModelForDB(model))
.filter(model => model !== null) // 过滤掉可能的 null 值
} }
return NextResponse.json({ return NextResponse.json({
@ -113,6 +121,40 @@ export async function POST(request: NextRequest) {
) )
} }
// 检查模型 ID 的唯一性
const modelIds = selectedModels.map(model => model.modelId)
const existingModels = await prisma.model.findMany({
where: {
modelId: {
in: modelIds
},
subscriptionPlanId: {
not: planId // 排除当前套餐,允许同一套餐内更新
}
},
include: {
subscriptionPlan: {
select: {
displayName: true
}
}
}
})
if (existingModels.length > 0) {
const conflicts = existingModels.map(model => ({
modelId: model.modelId,
existingPlan: model.subscriptionPlan.displayName,
existingServiceProvider: model.serviceProvider
}))
return NextResponse.json({
error: 'Model ID conflicts detected',
conflicts: conflicts,
message: `The following model IDs are already used by other service providers: ${conflicts.map(c => `${c.modelId} (used by ${c.existingServiceProvider} in ${c.existingPlan})`).join(', ')}`
}, { status: 409 })
}
// 批量创建模型记录 // 批量创建模型记录
const results = [] const results = []
for (const modelData of selectedModels) { for (const modelData of selectedModels) {

View File

@ -3,6 +3,7 @@ import { createServerSupabaseClient } from "@/lib/supabase-server";
import { prisma } from "@/lib/prisma"; import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils"; import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit"; import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
export async function POST( export async function POST(
request: NextRequest, request: NextRequest,
@ -58,40 +59,115 @@ export async function POST(
const promptContent = getPromptContent(run); const promptContent = getPromptContent(run);
const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`; const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;
const requestBody = { let apiResponse: Response;
model: run.model.modelId,
messages: [ // 根据服务提供商选择不同的API
{ if (run.model.serviceProvider === 'openrouter') {
role: "user", const requestBody = {
content: finalPrompt, model: run.model.modelId,
} messages: [
], {
temperature: run.temperature || 0.7, role: "user",
...(run.maxTokens && { max_tokens: run.maxTokens }), content: finalPrompt,
...(run.topP && { top_p: run.topP }), }
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }), ],
...(run.presencePenalty && { presence_penalty: run.presencePenalty }), temperature: run.temperature || 0.7,
stream: true, ...(run.maxTokens && { max_tokens: run.maxTokens }),
}; ...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
stream: true,
};
const openRouterResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", { apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
method: "POST", method: "POST",
headers: { headers: {
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`, "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
"Content-Type": "application/json", "Content-Type": "application/json",
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000", "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
"X-Title": "Prmbr - AI Prompt Studio", "X-Title": "Prmbr - AI Prompt Studio",
}, },
body: JSON.stringify(requestBody), body: JSON.stringify(requestBody),
}); });
} else if (run.model.serviceProvider === 'uniapi') {
const uniAPIService = new UniAPIService();
if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
// 使用聊天完成API
const requestBody = {
model: run.model.modelId,
messages: [
{
role: "user",
content: finalPrompt,
}
],
temperature: run.temperature || 0.7,
...(run.maxTokens && { max_tokens: run.maxTokens }),
...(run.topP && { top_p: run.topP }),
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
};
if (!openRouterResponse.ok) { // 注意UniAPI 可能不支持流式响应,这里需要调整
const errorText = await openRouterResponse.text(); const response = await uniAPIService.createChatCompletion(requestBody);
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const content = response.choices?.[0]?.message?.content || '';
const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };
// 模拟流式数据
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: content }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: usage
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else {
// 对于非文本模型,返回错误
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported model type: ${run.model.outputType}`,
},
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else {
await prisma.simulatorRun.update({ await prisma.simulatorRun.update({
where: { id }, where: { id },
data: { data: {
status: "failed", status: "failed",
error: `OpenRouter API error: ${openRouterResponse.status} - ${errorText}`, error: `Unsupported service provider: ${run.model.serviceProvider}`,
},
});
return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
}
if (!apiResponse.ok) {
const errorText = await apiResponse.text();
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `API error: ${apiResponse.status} - ${errorText}`,
}, },
}); });
return NextResponse.json({ error: "AI API request failed" }, { status: 500 }); return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
@ -100,7 +176,7 @@ export async function POST(
// 创建流式响应 // 创建流式响应
const stream = new ReadableStream({ const stream = new ReadableStream({
async start(controller) { async start(controller) {
const reader = openRouterResponse.body?.getReader(); const reader = apiResponse.body?.getReader();
if (!reader) { if (!reader) {
controller.close(); controller.close();
return; return;

View File

@ -38,6 +38,8 @@ export async function GET() {
modelId: model.modelId, modelId: model.modelId,
name: model.name, name: model.name,
provider: model.provider, provider: model.provider,
serviceProvider: model.serviceProvider,
outputType: model.outputType,
description: model.description, description: model.description,
maxTokens: model.maxTokens, maxTokens: model.maxTokens,
inputCostPer1k: model.inputCostPer1k, inputCostPer1k: model.inputCostPer1k,

204
src/lib/uniapi.ts Normal file
View File

@ -0,0 +1,204 @@
interface UniAPIModel {
id: string
name: string
description?: string
provider?: string
context_length?: number
pricing?: {
input: number
output: number
}
features?: string[]
type?: 'text' | 'image' | 'audio' | 'video' | 'multimodal'
}
interface UniAPIResponse {
data: UniAPIModel[]
error?: string
}
export class UniAPIService {
private apiKey: string
private baseUrl = 'https://api.uniapi.io/v1'
constructor() {
this.apiKey = process.env.UNIAPI_API_KEY || ''
if (!this.apiKey) {
throw new Error('UNIAPI_API_KEY environment variable is required')
}
}
async getAvailableModels(): Promise<UniAPIModel[]> {
try {
const response = await fetch(`${this.baseUrl}/models`, {
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
})
if (!response.ok) {
const errorText = await response.text()
throw new Error(`UniAPI error: ${response.status} ${response.statusText}`)
}
const data = await response.json()
// 检查不同可能的响应格式
let models: UniAPIModel[] = []
if (data.data && Array.isArray(data.data)) {
models = data.data
} else if (Array.isArray(data)) {
models = data
} else if (data.models && Array.isArray(data.models)) {
models = data.models
} else {
throw new Error('Unexpected response format from UniAPI')
}
if (data.error) {
throw new Error(`UniAPI API error: ${data.error}`)
}
return models
} catch (error) {
console.error('Error fetching UniAPI models:', error)
throw error
}
}
// 将 UniAPI 模型转换为我们数据库的格式
transformModelForDB(model: UniAPIModel) {
// 确保模型有基本的属性
if (!model.id) {
return null
}
return {
modelId: model.id,
name: model.name || model.id || 'Unnamed Model', // 如果没有名称使用模型ID作为后备
provider: model.provider || this.extractProvider(model.id),
serviceProvider: 'uniapi',
outputType: model.type || 'text',
description: model.description || null,
maxTokens: model.context_length || null,
inputCostPer1k: model.pricing?.input ? model.pricing.input * 1000 : null,
outputCostPer1k: model.pricing?.output ? model.pricing.output * 1000 : null,
supportedFeatures: {
type: model.type,
features: model.features || [],
},
metadata: {
original: model,
},
}
}
private extractProvider(modelId: string): string {
// 从模型 ID 中提取提供商名称
const providerMap: Record<string, string> = {
'openai': 'OpenAI',
'anthropic': 'Anthropic',
'claude': 'Anthropic',
'google': 'Google',
'gemini': 'Google',
'meta': 'Meta',
'llama': 'Meta',
'microsoft': 'Microsoft',
'mistral': 'Mistral AI',
'cohere': 'Cohere',
'stability': 'Stability AI',
'midjourney': 'Midjourney',
'dall-e': 'OpenAI',
'gpt': 'OpenAI',
}
const modelLower = modelId.toLowerCase()
// 尝试从 ID 中匹配提供商
for (const [key, value] of Object.entries(providerMap)) {
if (modelLower.includes(key)) {
return value
}
}
// 如果包含斜杠,取斜杠前的部分
if (modelId.includes('/')) {
const provider = modelId.split('/')[0]
return providerMap[provider.toLowerCase()] ||
provider.charAt(0).toUpperCase() + provider.slice(1)
}
// 如果包含连字符,取第一个部分
if (modelId.includes('-')) {
const provider = modelId.split('-')[0]
return providerMap[provider.toLowerCase()] ||
provider.charAt(0).toUpperCase() + provider.slice(1)
}
return 'UniAPI'
}
// 执行聊天完成请求
async createChatCompletion(params: {
model: string
messages: Array<{role: string, content: string}>
temperature?: number
max_tokens?: number
top_p?: number
frequency_penalty?: number
presence_penalty?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/chat/completions`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling UniAPI chat completion:', error)
throw error
}
}
// 执行图像生成请求
async createImageGeneration(params: {
model: string
prompt: string
size?: string
quality?: string
n?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/images/generations`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling UniAPI image generation:', error)
throw error
}
}
}