From 907f33a794fc4b43be4aa73fb414339da0281d49 Mon Sep 17 00:00:00 2001 From: songtianlun Date: Wed, 27 Aug 2025 23:37:14 +0800 Subject: [PATCH] add fal.ai --- src/app/admin/models/page.tsx | 5 +- src/app/api/admin/models/route.ts | 8 + src/app/api/simulator/[id]/execute/route.ts | 91 ++++++ src/lib/fal.ts | 299 ++++++++++++++++++++ 4 files changed, 401 insertions(+), 2 deletions(-) create mode 100644 src/lib/fal.ts diff --git a/src/app/admin/models/page.tsx b/src/app/admin/models/page.tsx index 2a8e702..cc08a36 100644 --- a/src/app/admin/models/page.tsx +++ b/src/app/admin/models/page.tsx @@ -74,7 +74,7 @@ export default function AdminModelsPage() { const [selectedPlan, setSelectedPlan] = useState('') const [selectedModels, setSelectedModels] = useState([]) const [showAvailableModels, setShowAvailableModels] = useState(false) - const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter') + const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi' | 'fal'>('openrouter') useEffect(() => { loadInitialData() @@ -309,7 +309,7 @@ export default function AdminModelsPage() { )} diff --git a/src/app/api/admin/models/route.ts b/src/app/api/admin/models/route.ts index 8ea2fc4..b08322c 100644 --- a/src/app/api/admin/models/route.ts +++ b/src/app/api/admin/models/route.ts @@ -3,6 +3,7 @@ import { prisma } from '@/lib/prisma' import { OpenRouterService } from '@/lib/openrouter' import { ReplicateService } from '@/lib/replicate' import { UniAPIService } from '@/lib/uniapi' +import { FalService } from '@/lib/fal' // GET /api/admin/models - 获取所有模型(按套餐分组) export async function GET() { @@ -103,6 +104,13 @@ export async function POST(request: NextRequest) { availableModels = uniAPIModels .map(model => uniAPIService.transformModelForDB(model)) .filter(model => model !== null) // 过滤掉可能的 null 值 + } else if (provider === 'fal') { + // 从 Fal.ai 获取可用模型 + const falService = new FalService() + const falModels = await falService.getAvailableModels() + availableModels = falModels + .map(model => falService.transformModelForDB(model)) + .filter(model => model !== null) // 过滤掉可能的 null 值 } return NextResponse.json({ diff --git a/src/app/api/simulator/[id]/execute/route.ts b/src/app/api/simulator/[id]/execute/route.ts index 3ca5085..e2f6c87 100644 --- a/src/app/api/simulator/[id]/execute/route.ts +++ b/src/app/api/simulator/[id]/execute/route.ts @@ -4,6 +4,7 @@ import { prisma } from "@/lib/prisma"; import { getPromptContent, calculateCost } from "@/lib/simulator-utils"; import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit"; import { UniAPIService } from "@/lib/uniapi"; +import { FalService } from "@/lib/fal"; export async function POST( request: NextRequest, @@ -150,6 +151,96 @@ export async function POST( }); return NextResponse.json({ error: "Unsupported model type" }, { status: 400 }); } + } else if (run.model.serviceProvider === 'fal') { + const falService = new FalService(); + + if (run.model.outputType === 'image') { + // 使用图像生成API + const response = await falService.generateImage({ + model: run.model.modelId, + prompt: finalPrompt, + num_images: 1, + }); + + // 创建模拟的流式响应 + const mockStream = new ReadableStream({ + start(controller) { + const imageUrl = response.images?.[0]?.url || ''; + const result = { + images: response.images, + prompt: finalPrompt, + model: run.model.modelId + }; + + // 模拟流式数据 + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + choices: [{ + delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` } + }] + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量 + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`)); + controller.close(); + } + }); + + apiResponse = new Response(mockStream, { + headers: { + 'Content-Type': 'text/event-stream', + } + }); + } else if (run.model.outputType === 'video') { + // 使用视频生成API + const response = await falService.generateVideo({ + model: run.model.modelId, + prompt: finalPrompt, + }); + + // 创建模拟的流式响应 + const mockStream = new ReadableStream({ + start(controller) { + const videoUrl = response.video?.url || ''; + const result = { + video: response.video, + prompt: finalPrompt, + model: run.model.modelId + }; + + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + choices: [{ + delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` } + }] + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({ + usage: { prompt_tokens: 50, completion_tokens: 150 } // 估算的token使用量 + })}\n\n`)); + + controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`)); + controller.close(); + } + }); + + apiResponse = new Response(mockStream, { + headers: { + 'Content-Type': 'text/event-stream', + } + }); + } else { + // 对于其他类型,返回错误 + await prisma.simulatorRun.update({ + where: { id }, + data: { + status: "failed", + error: `Unsupported Fal.ai model type: ${run.model.outputType}`, + }, + }); + return NextResponse.json({ error: "Unsupported model type" }, { status: 400 }); + } } else { await prisma.simulatorRun.update({ where: { id }, diff --git a/src/lib/fal.ts b/src/lib/fal.ts new file mode 100644 index 0000000..f9fd810 --- /dev/null +++ b/src/lib/fal.ts @@ -0,0 +1,299 @@ +interface FalModel { + id: string + name?: string + description?: string + category?: string + type?: 'image' | 'video' | 'audio' | 'text' + input_schema?: { + properties?: Record + } + pricing?: { + per_request?: number + per_second?: number + per_image?: number + } + max_resolution?: string + supported_formats?: string[] + [key: string]: any +} + +interface FalModelsResponse { + models: FalModel[] + error?: string +} + +export class FalService { + private apiKey: string + private baseUrl = 'https://fal.run/api/v1' + + constructor() { + this.apiKey = process.env.FAL_API_KEY || '' + if (!this.apiKey) { + throw new Error('FAL_API_KEY environment variable is required') + } + } + + async getAvailableModels(): Promise { + try { + // Fal.ai 可能没有公开的模型列表 API,我们可以定义一些常用的模型 + // 这里基于 Fal.ai 的文档和常见模型创建一个静态列表 + const knownModels: FalModel[] = [ + { + id: 'fal-ai/stable-diffusion-xl', + name: 'Stable Diffusion XL', + description: 'High-quality text-to-image generation with SDXL', + category: 'image-generation', + type: 'image', + max_resolution: '1024x1024', + pricing: { + per_image: 0.012 + } + }, + { + id: 'fal-ai/flux', + name: 'FLUX.1 [schnell]', + description: 'Fast, high-quality image generation', + category: 'image-generation', + type: 'image', + max_resolution: '1024x1024', + pricing: { + per_image: 0.003 + } + }, + { + id: 'fal-ai/flux-pro', + name: 'FLUX.1 [pro]', + description: 'Professional-grade image generation', + category: 'image-generation', + type: 'image', + max_resolution: '2048x2048', + pricing: { + per_image: 0.055 + } + }, + { + id: 'fal-ai/stable-video-diffusion', + name: 'Stable Video Diffusion', + description: 'Generate videos from images or text', + category: 'video-generation', + type: 'video', + pricing: { + per_second: 0.1 + } + }, + { + id: 'fal-ai/aura-sr', + name: 'AuraSR', + description: 'AI-powered image super resolution', + category: 'image-enhancement', + type: 'image', + pricing: { + per_image: 0.015 + } + }, + { + id: 'fal-ai/face-swap', + name: 'Face Swap', + description: 'Swap faces in images using AI', + category: 'image-editing', + type: 'image', + pricing: { + per_image: 0.05 + } + }, + { + id: 'fal-ai/remove-background', + name: 'Background Removal', + description: 'Remove backgrounds from images automatically', + category: 'image-editing', + type: 'image', + pricing: { + per_image: 0.01 + } + }, + { + id: 'fal-ai/lora-image-generation', + name: 'LoRA Image Generation', + description: 'Fine-tuned image generation with LoRA models', + category: 'image-generation', + type: 'image', + pricing: { + per_image: 0.025 + } + }, + { + id: 'fal-ai/realtime-stable-diffusion', + name: 'Realtime Stable Diffusion', + description: 'Fast, real-time image generation', + category: 'image-generation', + type: 'image', + pricing: { + per_image: 0.008 + } + }, + { + id: 'fal-ai/photomaker', + name: 'PhotoMaker', + description: 'Generate personalized photos with AI', + category: 'image-generation', + type: 'image', + pricing: { + per_image: 0.04 + } + } + ] + + console.log(`Fal.ai returned ${knownModels.length} predefined models`) + return knownModels + } catch (error) { + console.error('Error fetching Fal models:', error) + throw error + } + } + + // 将 Fal.ai 模型转换为我们数据库的格式 + transformModelForDB(model: FalModel) { + if (!model.id) { + return null + } + + const modelName = model.name || model.id.split('/').pop() || model.id + const provider = this.extractProvider(model.id) + + // 根据定价结构计算每1k的成本 + let inputCostPer1k = null + let outputCostPer1k = null + + if (model.pricing) { + if (model.pricing.per_image) { + // 图像生成:假设1k token约等于1张图片 + inputCostPer1k = model.pricing.per_image * 1000 + outputCostPer1k = model.pricing.per_image * 1000 + } else if (model.pricing.per_second) { + // 视频生成:假设1k token约等于10秒视频 + inputCostPer1k = model.pricing.per_second * 10000 + outputCostPer1k = model.pricing.per_second * 10000 + } else if (model.pricing.per_request) { + inputCostPer1k = model.pricing.per_request * 1000 + outputCostPer1k = model.pricing.per_request * 1000 + } + } + + return { + modelId: model.id, + name: modelName, + provider: provider, + serviceProvider: 'fal', + outputType: model.type || 'image', + description: model.description || null, + maxTokens: null, // Fal.ai 模型通常不使用 token 限制 + inputCostPer1k: inputCostPer1k, + outputCostPer1k: outputCostPer1k, + supportedFeatures: { + type: model.type, + category: model.category, + max_resolution: model.max_resolution, + supported_formats: model.supported_formats, + }, + metadata: { + original: model, + pricing_model: model.pricing + }, + } + } + + private extractProvider(modelId: string): string { + // Fal.ai 模型ID格式通常是 fal-ai/model-name + const parts = modelId.split('/') + if (parts.length > 1) { + const provider = parts[0] + + const providerMap: Record = { + 'fal-ai': 'Fal.ai', + 'stabilityai': 'Stability AI', + 'runwayml': 'Runway ML', + 'openai': 'OpenAI', + } + + return providerMap[provider] || provider.charAt(0).toUpperCase() + provider.slice(1) + } + + return 'Fal.ai' + } + + // 执行图像生成请求 + async generateImage(params: { + model: string + prompt: string + image_size?: string + num_inference_steps?: number + guidance_scale?: number + num_images?: number + seed?: number + }) { + try { + const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, { + method: 'POST', + headers: { + 'Authorization': `Key ${this.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + prompt: params.prompt, + image_size: params.image_size || '1024x1024', + num_inference_steps: params.num_inference_steps || 25, + guidance_scale: params.guidance_scale || 7.5, + num_images: params.num_images || 1, + ...(params.seed && { seed: params.seed }), + }), + }) + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})) + throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`) + } + + return await response.json() + } catch (error) { + console.error('Error calling Fal.ai image generation:', error) + throw error + } + } + + // 执行视频生成请求 + async generateVideo(params: { + model: string + prompt?: string + image_url?: string + motion_bucket_id?: number + fps?: number + num_frames?: number + }) { + try { + const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, { + method: 'POST', + headers: { + 'Authorization': `Key ${this.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + prompt: params.prompt, + image_url: params.image_url, + motion_bucket_id: params.motion_bucket_id || 127, + fps: params.fps || 6, + num_frames: params.num_frames || 25, + }), + }) + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})) + throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`) + } + + return await response.json() + } catch (error) { + console.error('Error calling Fal.ai video generation:', error) + throw error + } + } +} \ No newline at end of file