add uniapi
This commit is contained in:
parent
1b81639e92
commit
947b0d08d3
@ -74,7 +74,7 @@ export default function AdminModelsPage() {
|
||||
const [selectedPlan, setSelectedPlan] = useState<string>('')
|
||||
const [selectedModels, setSelectedModels] = useState<string[]>([])
|
||||
const [showAvailableModels, setShowAvailableModels] = useState(false)
|
||||
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate'>('openrouter')
|
||||
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter')
|
||||
|
||||
useEffect(() => {
|
||||
loadInitialData()
|
||||
@ -309,7 +309,7 @@ export default function AdminModelsPage() {
|
||||
<select
|
||||
value={selectedServiceProvider}
|
||||
onChange={(e) => {
|
||||
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate')
|
||||
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi')
|
||||
setShowAvailableModels(false)
|
||||
setSelectedModels([])
|
||||
}}
|
||||
@ -317,6 +317,7 @@ export default function AdminModelsPage() {
|
||||
>
|
||||
<option value="openrouter">OpenRouter (Text Models)</option>
|
||||
<option value="replicate">Replicate (Image/Video/Audio Models)</option>
|
||||
<option value="uniapi">UniAPI (Multi-modal Models)</option>
|
||||
</select>
|
||||
</div>
|
||||
)}
|
||||
@ -362,7 +363,7 @@ export default function AdminModelsPage() {
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3 max-h-96 overflow-y-auto">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-4 max-h-96 overflow-y-auto">
|
||||
{availableModels.map(model => {
|
||||
const isSelected = selectedModels.includes(model.modelId)
|
||||
const isAlreadyAdded = models.some(m =>
|
||||
@ -372,7 +373,7 @@ export default function AdminModelsPage() {
|
||||
return (
|
||||
<div
|
||||
key={model.modelId}
|
||||
className={`border rounded-lg p-3 cursor-pointer transition-colors ${
|
||||
className={`border rounded-lg p-4 cursor-pointer transition-colors ${
|
||||
isAlreadyAdded
|
||||
? 'border-yellow-300 bg-yellow-50 dark:bg-yellow-900/20'
|
||||
: isSelected
|
||||
@ -383,17 +384,19 @@ export default function AdminModelsPage() {
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center space-x-2 mb-1">
|
||||
<h4 className="font-medium text-sm truncate">{model.name}</h4>
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
{model.provider}
|
||||
</Badge>
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{model.outputType}
|
||||
</Badge>
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{model.serviceProvider}
|
||||
</Badge>
|
||||
<div className="mb-2">
|
||||
<h4 className="font-medium text-base truncate mb-2">{model.name}</h4>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
{model.provider}
|
||||
</Badge>
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{model.outputType}
|
||||
</Badge>
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{model.serviceProvider}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2 text-xs text-muted-foreground">
|
||||
{model.maxTokens && (
|
||||
|
@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from 'next/server'
|
||||
import { prisma } from '@/lib/prisma'
|
||||
import { OpenRouterService } from '@/lib/openrouter'
|
||||
import { ReplicateService } from '@/lib/replicate'
|
||||
import { UniAPIService } from '@/lib/uniapi'
|
||||
|
||||
// GET /api/admin/models - 获取所有模型(按套餐分组)
|
||||
export async function GET() {
|
||||
@ -95,6 +96,13 @@ export async function POST(request: NextRequest) {
|
||||
...transformedVideoModels,
|
||||
...transformedAudioModels
|
||||
]
|
||||
} else if (provider === 'uniapi') {
|
||||
// 从 UniAPI 获取可用模型
|
||||
const uniAPIService = new UniAPIService()
|
||||
const uniAPIModels = await uniAPIService.getAvailableModels()
|
||||
availableModels = uniAPIModels
|
||||
.map(model => uniAPIService.transformModelForDB(model))
|
||||
.filter(model => model !== null) // 过滤掉可能的 null 值
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
@ -113,6 +121,40 @@ export async function POST(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// 检查模型 ID 的唯一性
|
||||
const modelIds = selectedModels.map(model => model.modelId)
|
||||
const existingModels = await prisma.model.findMany({
|
||||
where: {
|
||||
modelId: {
|
||||
in: modelIds
|
||||
},
|
||||
subscriptionPlanId: {
|
||||
not: planId // 排除当前套餐,允许同一套餐内更新
|
||||
}
|
||||
},
|
||||
include: {
|
||||
subscriptionPlan: {
|
||||
select: {
|
||||
displayName: true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if (existingModels.length > 0) {
|
||||
const conflicts = existingModels.map(model => ({
|
||||
modelId: model.modelId,
|
||||
existingPlan: model.subscriptionPlan.displayName,
|
||||
existingServiceProvider: model.serviceProvider
|
||||
}))
|
||||
|
||||
return NextResponse.json({
|
||||
error: 'Model ID conflicts detected',
|
||||
conflicts: conflicts,
|
||||
message: `The following model IDs are already used by other service providers: ${conflicts.map(c => `${c.modelId} (used by ${c.existingServiceProvider} in ${c.existingPlan})`).join(', ')}`
|
||||
}, { status: 409 })
|
||||
}
|
||||
|
||||
// 批量创建模型记录
|
||||
const results = []
|
||||
for (const modelData of selectedModels) {
|
||||
|
@ -3,6 +3,7 @@ import { createServerSupabaseClient } from "@/lib/supabase-server";
|
||||
import { prisma } from "@/lib/prisma";
|
||||
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
|
||||
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
|
||||
import { UniAPIService } from "@/lib/uniapi";
|
||||
|
||||
export async function POST(
|
||||
request: NextRequest,
|
||||
@ -58,40 +59,115 @@ export async function POST(
|
||||
const promptContent = getPromptContent(run);
|
||||
const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`;
|
||||
|
||||
const requestBody = {
|
||||
model: run.model.modelId,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: finalPrompt,
|
||||
}
|
||||
],
|
||||
temperature: run.temperature || 0.7,
|
||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||
...(run.topP && { top_p: run.topP }),
|
||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||
stream: true,
|
||||
};
|
||||
let apiResponse: Response;
|
||||
|
||||
const openRouterResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||
"X-Title": "Prmbr - AI Prompt Studio",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
// 根据服务提供商选择不同的API
|
||||
if (run.model.serviceProvider === 'openrouter') {
|
||||
const requestBody = {
|
||||
model: run.model.modelId,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: finalPrompt,
|
||||
}
|
||||
],
|
||||
temperature: run.temperature || 0.7,
|
||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||
...(run.topP && { top_p: run.topP }),
|
||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (!openRouterResponse.ok) {
|
||||
const errorText = await openRouterResponse.text();
|
||||
apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`,
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000",
|
||||
"X-Title": "Prmbr - AI Prompt Studio",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
} else if (run.model.serviceProvider === 'uniapi') {
|
||||
const uniAPIService = new UniAPIService();
|
||||
|
||||
if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') {
|
||||
// 使用聊天完成API
|
||||
const requestBody = {
|
||||
model: run.model.modelId,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: finalPrompt,
|
||||
}
|
||||
],
|
||||
temperature: run.temperature || 0.7,
|
||||
...(run.maxTokens && { max_tokens: run.maxTokens }),
|
||||
...(run.topP && { top_p: run.topP }),
|
||||
...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }),
|
||||
...(run.presencePenalty && { presence_penalty: run.presencePenalty }),
|
||||
};
|
||||
|
||||
// 注意:UniAPI 可能不支持流式响应,这里需要调整
|
||||
const response = await uniAPIService.createChatCompletion(requestBody);
|
||||
|
||||
// 创建模拟的流式响应
|
||||
const mockStream = new ReadableStream({
|
||||
start(controller) {
|
||||
const content = response.choices?.[0]?.message?.content || '';
|
||||
const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 };
|
||||
|
||||
// 模拟流式数据
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: content }
|
||||
}]
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
usage: usage
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
|
||||
apiResponse = new Response(mockStream, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 对于非文本模型,返回错误
|
||||
await prisma.simulatorRun.update({
|
||||
where: { id },
|
||||
data: {
|
||||
status: "failed",
|
||||
error: `Unsupported model type: ${run.model.outputType}`,
|
||||
},
|
||||
});
|
||||
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
|
||||
}
|
||||
} else {
|
||||
await prisma.simulatorRun.update({
|
||||
where: { id },
|
||||
data: {
|
||||
status: "failed",
|
||||
error: `OpenRouter API error: ${openRouterResponse.status} - ${errorText}`,
|
||||
error: `Unsupported service provider: ${run.model.serviceProvider}`,
|
||||
},
|
||||
});
|
||||
return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 });
|
||||
}
|
||||
|
||||
if (!apiResponse.ok) {
|
||||
const errorText = await apiResponse.text();
|
||||
await prisma.simulatorRun.update({
|
||||
where: { id },
|
||||
data: {
|
||||
status: "failed",
|
||||
error: `API error: ${apiResponse.status} - ${errorText}`,
|
||||
},
|
||||
});
|
||||
return NextResponse.json({ error: "AI API request failed" }, { status: 500 });
|
||||
@ -100,7 +176,7 @@ export async function POST(
|
||||
// 创建流式响应
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
const reader = openRouterResponse.body?.getReader();
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
controller.close();
|
||||
return;
|
||||
|
@ -38,6 +38,8 @@ export async function GET() {
|
||||
modelId: model.modelId,
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
serviceProvider: model.serviceProvider,
|
||||
outputType: model.outputType,
|
||||
description: model.description,
|
||||
maxTokens: model.maxTokens,
|
||||
inputCostPer1k: model.inputCostPer1k,
|
||||
|
204
src/lib/uniapi.ts
Normal file
204
src/lib/uniapi.ts
Normal file
@ -0,0 +1,204 @@
|
||||
interface UniAPIModel {
|
||||
id: string
|
||||
name: string
|
||||
description?: string
|
||||
provider?: string
|
||||
context_length?: number
|
||||
pricing?: {
|
||||
input: number
|
||||
output: number
|
||||
}
|
||||
features?: string[]
|
||||
type?: 'text' | 'image' | 'audio' | 'video' | 'multimodal'
|
||||
}
|
||||
|
||||
interface UniAPIResponse {
|
||||
data: UniAPIModel[]
|
||||
error?: string
|
||||
}
|
||||
|
||||
export class UniAPIService {
|
||||
private apiKey: string
|
||||
private baseUrl = 'https://api.uniapi.io/v1'
|
||||
|
||||
constructor() {
|
||||
this.apiKey = process.env.UNIAPI_API_KEY || ''
|
||||
if (!this.apiKey) {
|
||||
throw new Error('UNIAPI_API_KEY environment variable is required')
|
||||
}
|
||||
}
|
||||
|
||||
async getAvailableModels(): Promise<UniAPIModel[]> {
|
||||
try {
|
||||
const response = await fetch(`${this.baseUrl}/models`, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
throw new Error(`UniAPI error: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
// 检查不同可能的响应格式
|
||||
let models: UniAPIModel[] = []
|
||||
|
||||
if (data.data && Array.isArray(data.data)) {
|
||||
models = data.data
|
||||
} else if (Array.isArray(data)) {
|
||||
models = data
|
||||
} else if (data.models && Array.isArray(data.models)) {
|
||||
models = data.models
|
||||
} else {
|
||||
throw new Error('Unexpected response format from UniAPI')
|
||||
}
|
||||
|
||||
if (data.error) {
|
||||
throw new Error(`UniAPI API error: ${data.error}`)
|
||||
}
|
||||
|
||||
return models
|
||||
} catch (error) {
|
||||
console.error('Error fetching UniAPI models:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 将 UniAPI 模型转换为我们数据库的格式
|
||||
transformModelForDB(model: UniAPIModel) {
|
||||
// 确保模型有基本的属性
|
||||
if (!model.id) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
modelId: model.id,
|
||||
name: model.name || model.id || 'Unnamed Model', // 如果没有名称,使用模型ID作为后备
|
||||
provider: model.provider || this.extractProvider(model.id),
|
||||
serviceProvider: 'uniapi',
|
||||
outputType: model.type || 'text',
|
||||
description: model.description || null,
|
||||
maxTokens: model.context_length || null,
|
||||
inputCostPer1k: model.pricing?.input ? model.pricing.input * 1000 : null,
|
||||
outputCostPer1k: model.pricing?.output ? model.pricing.output * 1000 : null,
|
||||
supportedFeatures: {
|
||||
type: model.type,
|
||||
features: model.features || [],
|
||||
},
|
||||
metadata: {
|
||||
original: model,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
private extractProvider(modelId: string): string {
|
||||
// 从模型 ID 中提取提供商名称
|
||||
const providerMap: Record<string, string> = {
|
||||
'openai': 'OpenAI',
|
||||
'anthropic': 'Anthropic',
|
||||
'claude': 'Anthropic',
|
||||
'google': 'Google',
|
||||
'gemini': 'Google',
|
||||
'meta': 'Meta',
|
||||
'llama': 'Meta',
|
||||
'microsoft': 'Microsoft',
|
||||
'mistral': 'Mistral AI',
|
||||
'cohere': 'Cohere',
|
||||
'stability': 'Stability AI',
|
||||
'midjourney': 'Midjourney',
|
||||
'dall-e': 'OpenAI',
|
||||
'gpt': 'OpenAI',
|
||||
}
|
||||
|
||||
const modelLower = modelId.toLowerCase()
|
||||
|
||||
// 尝试从 ID 中匹配提供商
|
||||
for (const [key, value] of Object.entries(providerMap)) {
|
||||
if (modelLower.includes(key)) {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// 如果包含斜杠,取斜杠前的部分
|
||||
if (modelId.includes('/')) {
|
||||
const provider = modelId.split('/')[0]
|
||||
return providerMap[provider.toLowerCase()] ||
|
||||
provider.charAt(0).toUpperCase() + provider.slice(1)
|
||||
}
|
||||
|
||||
// 如果包含连字符,取第一个部分
|
||||
if (modelId.includes('-')) {
|
||||
const provider = modelId.split('-')[0]
|
||||
return providerMap[provider.toLowerCase()] ||
|
||||
provider.charAt(0).toUpperCase() + provider.slice(1)
|
||||
}
|
||||
|
||||
return 'UniAPI'
|
||||
}
|
||||
|
||||
// 执行聊天完成请求
|
||||
async createChatCompletion(params: {
|
||||
model: string
|
||||
messages: Array<{role: string, content: string}>
|
||||
temperature?: number
|
||||
max_tokens?: number
|
||||
top_p?: number
|
||||
frequency_penalty?: number
|
||||
presence_penalty?: number
|
||||
}) {
|
||||
try {
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(params),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}))
|
||||
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
console.error('Error calling UniAPI chat completion:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 执行图像生成请求
|
||||
async createImageGeneration(params: {
|
||||
model: string
|
||||
prompt: string
|
||||
size?: string
|
||||
quality?: string
|
||||
n?: number
|
||||
}) {
|
||||
try {
|
||||
const response = await fetch(`${this.baseUrl}/images/generations`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(params),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}))
|
||||
throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
console.error('Error calling UniAPI image generation:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user