diff --git a/src/app/admin/models/page.tsx b/src/app/admin/models/page.tsx index bcf06f8..2a8e702 100644 --- a/src/app/admin/models/page.tsx +++ b/src/app/admin/models/page.tsx @@ -74,7 +74,7 @@ export default function AdminModelsPage() { const [selectedPlan, setSelectedPlan] = useState('') const [selectedModels, setSelectedModels] = useState([]) const [showAvailableModels, setShowAvailableModels] = useState(false) - const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate'>('openrouter') + const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter') useEffect(() => { loadInitialData() @@ -309,7 +309,7 @@ export default function AdminModelsPage() { )} @@ -362,7 +363,7 @@ export default function AdminModelsPage() { -
+
{availableModels.map(model => { const isSelected = selectedModels.includes(model.modelId) const isAlreadyAdded = models.some(m => @@ -372,7 +373,7 @@ export default function AdminModelsPage() { return (
-
-

{model.name}

- - {model.provider} - - - {model.outputType} - - - {model.serviceProvider} - +
+

{model.name}

+
+ + {model.provider} + + + {model.outputType} + + + {model.serviceProvider} + +
{model.maxTokens && ( diff --git a/src/app/api/admin/models/route.ts b/src/app/api/admin/models/route.ts index c10bfb8..8ea2fc4 100644 --- a/src/app/api/admin/models/route.ts +++ b/src/app/api/admin/models/route.ts @@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from 'next/server' import { prisma } from '@/lib/prisma' import { OpenRouterService } from '@/lib/openrouter' import { ReplicateService } from '@/lib/replicate' +import { UniAPIService } from '@/lib/uniapi' // GET /api/admin/models - 获取所有模型(按套餐分组) export async function GET() { @@ -95,6 +96,13 @@ export async function POST(request: NextRequest) { ...transformedVideoModels, ...transformedAudioModels ] + } else if (provider === 'uniapi') { + // 从 UniAPI 获取可用模型 + const uniAPIService = new UniAPIService() + const uniAPIModels = await uniAPIService.getAvailableModels() + availableModels = uniAPIModels + .map(model => uniAPIService.transformModelForDB(model)) + .filter(model => model !== null) // 过滤掉可能的 null 值 } return NextResponse.json({ @@ -113,6 +121,40 @@ export async function POST(request: NextRequest) { ) } + // 检查模型 ID 的唯一性 + const modelIds = selectedModels.map(model => model.modelId) + const existingModels = await prisma.model.findMany({ + where: { + modelId: { + in: modelIds + }, + subscriptionPlanId: { + not: planId // 排除当前套餐,允许同一套餐内更新 + } + }, + include: { + subscriptionPlan: { + select: { + displayName: true + } + } + } + }) + + if (existingModels.length > 0) { + const conflicts = existingModels.map(model => ({ + modelId: model.modelId, + existingPlan: model.subscriptionPlan.displayName, + existingServiceProvider: model.serviceProvider + })) + + return NextResponse.json({ + error: 'Model ID conflicts detected', + conflicts: conflicts, + message: `The following model IDs are already used by other service providers: ${conflicts.map(c => `${c.modelId} (used by ${c.existingServiceProvider} in ${c.existingPlan})`).join(', ')}` + }, { status: 409 }) + } + // 批量创建模型记录 const results = [] for (const 
modelData of selectedModels) { diff --git a/src/app/api/simulator/[id]/execute/route.ts b/src/app/api/simulator/[id]/execute/route.ts index dfe1fda..3ca5085 100644 --- a/src/app/api/simulator/[id]/execute/route.ts +++ b/src/app/api/simulator/[id]/execute/route.ts @@ -3,6 +3,7 @@ import { createServerSupabaseClient } from "@/lib/supabase-server"; import { prisma } from "@/lib/prisma"; import { getPromptContent, calculateCost } from "@/lib/simulator-utils"; import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit"; +import { UniAPIService } from "@/lib/uniapi"; export async function POST( request: NextRequest, @@ -58,40 +59,115 @@ export async function POST( const promptContent = getPromptContent(run); const finalPrompt = `${promptContent}\n\nUser Input: ${run.userInput}`; - const requestBody = { - model: run.model.modelId, - messages: [ - { - role: "user", - content: finalPrompt, - } - ], - temperature: run.temperature || 0.7, - ...(run.maxTokens && { max_tokens: run.maxTokens }), - ...(run.topP && { top_p: run.topP }), - ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }), - ...(run.presencePenalty && { presence_penalty: run.presencePenalty }), - stream: true, - }; + let apiResponse: Response; + + // 根据服务提供商选择不同的API + if (run.model.serviceProvider === 'openrouter') { + const requestBody = { + model: run.model.modelId, + messages: [ + { + role: "user", + content: finalPrompt, + } + ], + temperature: run.temperature || 0.7, + ...(run.maxTokens && { max_tokens: run.maxTokens }), + ...(run.topP && { top_p: run.topP }), + ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }), + ...(run.presencePenalty && { presence_penalty: run.presencePenalty }), + stream: true, + }; - const openRouterResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", { - method: "POST", - headers: { - "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`, - "Content-Type": "application/json", - "HTTP-Referer": 
process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000", - "X-Title": "Prmbr - AI Prompt Studio", - }, - body: JSON.stringify(requestBody), - }); + apiResponse = await fetch("https://openrouter.ai/api/v1/chat/completions", { + method: "POST", + headers: { + "Authorization": `Bearer ${process.env.OPENROUTER_API_KEY}`, + "Content-Type": "application/json", + "HTTP-Referer": process.env.NEXT_PUBLIC_APP_URL || "http://localhost:3000", + "X-Title": "Prmbr - AI Prompt Studio", + }, + body: JSON.stringify(requestBody), + }); + } else if (run.model.serviceProvider === 'uniapi') { + const uniAPIService = new UniAPIService(); + + if (run.model.outputType === 'text' || run.model.outputType === 'multimodal') { + // 使用聊天完成API + const requestBody = { + model: run.model.modelId, + messages: [ + { + role: "user", + content: finalPrompt, + } + ], + temperature: run.temperature || 0.7, + ...(run.maxTokens && { max_tokens: run.maxTokens }), + ...(run.topP && { top_p: run.topP }), + ...(run.frequencyPenalty && { frequency_penalty: run.frequencyPenalty }), + ...(run.presencePenalty && { presence_penalty: run.presencePenalty }), + }; - if (!openRouterResponse.ok) { - const errorText = await openRouterResponse.text(); + // 注意:UniAPI 可能不支持流式响应,这里需要调整 + const response = await uniAPIService.createChatCompletion(requestBody); + + // 创建模拟的流式响应 + const mockStream = new ReadableStream({ + start(controller) { + const content = response.choices?.[0]?.message?.content || ''; + const usage = response.usage || { prompt_tokens: 0, completion_tokens: 0 }; + + // 模拟流式数据 + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + choices: [{ + delta: { content: content } + }] + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + usage: usage + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`)); + controller.close(); + } + }); + + apiResponse = new Response(mockStream, { + headers: { + 'Content-Type': 'text/event-stream', 
+ } + }); + } else { + // 对于非文本模型,返回错误 + await prisma.simulatorRun.update({ + where: { id }, + data: { + status: "failed", + error: `Unsupported model type: ${run.model.outputType}`, + }, + }); + return NextResponse.json({ error: "Unsupported model type" }, { status: 400 }); + } + } else { await prisma.simulatorRun.update({ where: { id }, data: { status: "failed", - error: `OpenRouter API error: ${openRouterResponse.status} - ${errorText}`, + error: `Unsupported service provider: ${run.model.serviceProvider}`, + }, + }); + return NextResponse.json({ error: "Unsupported service provider" }, { status: 400 }); + } + + if (!apiResponse.ok) { + const errorText = await apiResponse.text(); + await prisma.simulatorRun.update({ + where: { id }, + data: { + status: "failed", + error: `API error: ${apiResponse.status} - ${errorText}`, }, }); return NextResponse.json({ error: "AI API request failed" }, { status: 500 }); @@ -100,7 +176,7 @@ export async function POST( // 创建流式响应 const stream = new ReadableStream({ async start(controller) { - const reader = openRouterResponse.body?.getReader(); + const reader = apiResponse.body?.getReader(); if (!reader) { controller.close(); return; diff --git a/src/app/api/simulator/models/route.ts b/src/app/api/simulator/models/route.ts index 9369e02..198d8b4 100644 --- a/src/app/api/simulator/models/route.ts +++ b/src/app/api/simulator/models/route.ts @@ -38,6 +38,8 @@ export async function GET() { modelId: model.modelId, name: model.name, provider: model.provider, + serviceProvider: model.serviceProvider, + outputType: model.outputType, description: model.description, maxTokens: model.maxTokens, inputCostPer1k: model.inputCostPer1k, diff --git a/src/lib/uniapi.ts b/src/lib/uniapi.ts new file mode 100644 index 0000000..d919368 --- /dev/null +++ b/src/lib/uniapi.ts @@ -0,0 +1,204 @@ +interface UniAPIModel { + id: string + name: string + description?: string + provider?: string + context_length?: number + pricing?: { + input: number + output: 
number + } + features?: string[] + type?: 'text' | 'image' | 'audio' | 'video' | 'multimodal' +} + +interface UniAPIResponse { + data: UniAPIModel[] + error?: string +} + +export class UniAPIService { + private apiKey: string + private baseUrl = 'https://api.uniapi.io/v1' + + constructor() { + this.apiKey = process.env.UNIAPI_API_KEY || '' + if (!this.apiKey) { + throw new Error('UNIAPI_API_KEY environment variable is required') + } + } + + async getAvailableModels(): Promise<UniAPIModel[]> { + try { + const response = await fetch(`${this.baseUrl}/models`, { + headers: { + 'Authorization': `Bearer ${this.apiKey}`, + 'Content-Type': 'application/json', + }, + }) + + if (!response.ok) { + const errorText = await response.text() + throw new Error(`UniAPI error: ${response.status} ${response.statusText}`) + } + + const data = await response.json() + + // 检查不同可能的响应格式 + let models: UniAPIModel[] = [] + + if (data.data && Array.isArray(data.data)) { + models = data.data + } else if (Array.isArray(data)) { + models = data + } else if (data.models && Array.isArray(data.models)) { + models = data.models + } else { + throw new Error('Unexpected response format from UniAPI') + } + + if (data.error) { + throw new Error(`UniAPI API error: ${data.error}`) + } + + return models + } catch (error) { + console.error('Error fetching UniAPI models:', error) + throw error + } + } + + // 将 UniAPI 模型转换为我们数据库的格式 + transformModelForDB(model: UniAPIModel) { + // 确保模型有基本的属性 + if (!model.id) { + return null + } + + return { + modelId: model.id, + name: model.name || model.id || 'Unnamed Model', // 如果没有名称,使用模型ID作为后备 + provider: model.provider || this.extractProvider(model.id), + serviceProvider: 'uniapi', + outputType: model.type || 'text', + description: model.description || null, + maxTokens: model.context_length || null, + inputCostPer1k: model.pricing?.input ? model.pricing.input * 1000 : null, + outputCostPer1k: model.pricing?.output ? 
model.pricing.output * 1000 : null, + supportedFeatures: { + type: model.type, + features: model.features || [], + }, + metadata: { + original: model, + }, + } + } + + private extractProvider(modelId: string): string { + // 从模型 ID 中提取提供商名称 + const providerMap: Record<string, string> = { + 'openai': 'OpenAI', + 'anthropic': 'Anthropic', + 'claude': 'Anthropic', + 'google': 'Google', + 'gemini': 'Google', + 'meta': 'Meta', + 'llama': 'Meta', + 'microsoft': 'Microsoft', + 'mistral': 'Mistral AI', + 'cohere': 'Cohere', + 'stability': 'Stability AI', + 'midjourney': 'Midjourney', + 'dall-e': 'OpenAI', + 'gpt': 'OpenAI', + } + + const modelLower = modelId.toLowerCase() + + // 尝试从 ID 中匹配提供商 + for (const [key, value] of Object.entries(providerMap)) { + if (modelLower.includes(key)) { + return value + } + } + + // 如果包含斜杠,取斜杠前的部分 + if (modelId.includes('/')) { + const provider = modelId.split('/')[0] + return providerMap[provider.toLowerCase()] || + provider.charAt(0).toUpperCase() + provider.slice(1) + } + + // 如果包含连字符,取第一个部分 + if (modelId.includes('-')) { + const provider = modelId.split('-')[0] + return providerMap[provider.toLowerCase()] || + provider.charAt(0).toUpperCase() + provider.slice(1) + } + + return 'UniAPI' + } + + // 执行聊天完成请求 + async createChatCompletion(params: { + model: string + messages: Array<{role: string, content: string}> + temperature?: number + max_tokens?: number + top_p?: number + frequency_penalty?: number + presence_penalty?: number + }) { + try { + const response = await fetch(`${this.baseUrl}/chat/completions`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${this.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify(params), + }) + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})) + throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`) + } + + return await response.json() + } catch (error) { + console.error('Error calling UniAPI chat 
completion:', error) + throw error + } + } + + // 执行图像生成请求 + async createImageGeneration(params: { + model: string + prompt: string + size?: string + quality?: string + n?: number + }) { + try { + const response = await fetch(`${this.baseUrl}/images/generations`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${this.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify(params), + }) + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})) + throw new Error(`UniAPI API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`) + } + + return await response.json() + } catch (error) { + console.error('Error calling UniAPI image generation:', error) + throw error + } + } +} \ No newline at end of file