add fal.ai

This commit is contained in:
songtianlun 2025-08-27 23:37:14 +08:00
parent 947b0d08d3
commit 907f33a794
4 changed files with 401 additions and 2 deletions

View File

@ -74,7 +74,7 @@ export default function AdminModelsPage() {
const [selectedPlan, setSelectedPlan] = useState<string>('')
const [selectedModels, setSelectedModels] = useState<string[]>([])
const [showAvailableModels, setShowAvailableModels] = useState(false)
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter')
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi' | 'fal'>('openrouter')
useEffect(() => {
loadInitialData()
@ -309,7 +309,7 @@ export default function AdminModelsPage() {
<select
value={selectedServiceProvider}
onChange={(e) => {
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi')
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi' | 'fal')
setShowAvailableModels(false)
setSelectedModels([])
}}
@ -318,6 +318,7 @@ export default function AdminModelsPage() {
<option value="openrouter">OpenRouter (Text Models)</option>
<option value="replicate">Replicate (Image/Video/Audio Models)</option>
<option value="uniapi">UniAPI (Multi-modal Models)</option>
<option value="fal">Fal.ai (AI Generation Models)</option>
</select>
</div>
)}

View File

@ -3,6 +3,7 @@ import { prisma } from '@/lib/prisma'
import { OpenRouterService } from '@/lib/openrouter'
import { ReplicateService } from '@/lib/replicate'
import { UniAPIService } from '@/lib/uniapi'
import { FalService } from '@/lib/fal'
// GET /api/admin/models - 获取所有模型(按套餐分组)
export async function GET() {
@ -103,6 +104,13 @@ export async function POST(request: NextRequest) {
availableModels = uniAPIModels
.map(model => uniAPIService.transformModelForDB(model))
.filter(model => model !== null) // 过滤掉可能的 null 值
} else if (provider === 'fal') {
// 从 Fal.ai 获取可用模型
const falService = new FalService()
const falModels = await falService.getAvailableModels()
availableModels = falModels
.map(model => falService.transformModelForDB(model))
.filter(model => model !== null) // 过滤掉可能的 null 值
}
return NextResponse.json({

View File

@ -4,6 +4,7 @@ import { prisma } from "@/lib/prisma";
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
import { UniAPIService } from "@/lib/uniapi";
import { FalService } from "@/lib/fal";
export async function POST(
request: NextRequest,
@ -150,6 +151,96 @@ export async function POST(
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else if (run.model.serviceProvider === 'fal') {
const falService = new FalService();
if (run.model.outputType === 'image') {
// 使用图像生成API
const response = await falService.generateImage({
model: run.model.modelId,
prompt: finalPrompt,
num_images: 1,
});
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const imageUrl = response.images?.[0]?.url || '';
const result = {
images: response.images,
prompt: finalPrompt,
model: run.model.modelId
};
// 模拟流式数据
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else if (run.model.outputType === 'video') {
// 使用视频生成API
const response = await falService.generateVideo({
model: run.model.modelId,
prompt: finalPrompt,
});
// 创建模拟的流式响应
const mockStream = new ReadableStream({
start(controller) {
const videoUrl = response.video?.url || '';
const result = {
video: response.video,
prompt: finalPrompt,
model: run.model.modelId
};
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
choices: [{
delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
}]
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
usage: { prompt_tokens: 50, completion_tokens: 150 } // 估算的token使用量
})}\n\n`));
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
controller.close();
}
});
apiResponse = new Response(mockStream, {
headers: {
'Content-Type': 'text/event-stream',
}
});
} else {
// 对于其他类型,返回错误
await prisma.simulatorRun.update({
where: { id },
data: {
status: "failed",
error: `Unsupported Fal.ai model type: ${run.model.outputType}`,
},
});
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
}
} else {
await prisma.simulatorRun.update({
where: { id },

299
src/lib/fal.ts Normal file
View File

@ -0,0 +1,299 @@
interface FalModel {
id: string
name?: string
description?: string
category?: string
type?: 'image' | 'video' | 'audio' | 'text'
input_schema?: {
properties?: Record<string, any>
}
pricing?: {
per_request?: number
per_second?: number
per_image?: number
}
max_resolution?: string
supported_formats?: string[]
[key: string]: any
}
interface FalModelsResponse {
models: FalModel[]
error?: string
}
export class FalService {
private apiKey: string
private baseUrl = 'https://fal.run/api/v1'
constructor() {
this.apiKey = process.env.FAL_API_KEY || ''
if (!this.apiKey) {
throw new Error('FAL_API_KEY environment variable is required')
}
}
async getAvailableModels(): Promise<FalModel[]> {
try {
// Fal.ai 可能没有公开的模型列表 API我们可以定义一些常用的模型
// 这里基于 Fal.ai 的文档和常见模型创建一个静态列表
const knownModels: FalModel[] = [
{
id: 'fal-ai/stable-diffusion-xl',
name: 'Stable Diffusion XL',
description: 'High-quality text-to-image generation with SDXL',
category: 'image-generation',
type: 'image',
max_resolution: '1024x1024',
pricing: {
per_image: 0.012
}
},
{
id: 'fal-ai/flux',
name: 'FLUX.1 [schnell]',
description: 'Fast, high-quality image generation',
category: 'image-generation',
type: 'image',
max_resolution: '1024x1024',
pricing: {
per_image: 0.003
}
},
{
id: 'fal-ai/flux-pro',
name: 'FLUX.1 [pro]',
description: 'Professional-grade image generation',
category: 'image-generation',
type: 'image',
max_resolution: '2048x2048',
pricing: {
per_image: 0.055
}
},
{
id: 'fal-ai/stable-video-diffusion',
name: 'Stable Video Diffusion',
description: 'Generate videos from images or text',
category: 'video-generation',
type: 'video',
pricing: {
per_second: 0.1
}
},
{
id: 'fal-ai/aura-sr',
name: 'AuraSR',
description: 'AI-powered image super resolution',
category: 'image-enhancement',
type: 'image',
pricing: {
per_image: 0.015
}
},
{
id: 'fal-ai/face-swap',
name: 'Face Swap',
description: 'Swap faces in images using AI',
category: 'image-editing',
type: 'image',
pricing: {
per_image: 0.05
}
},
{
id: 'fal-ai/remove-background',
name: 'Background Removal',
description: 'Remove backgrounds from images automatically',
category: 'image-editing',
type: 'image',
pricing: {
per_image: 0.01
}
},
{
id: 'fal-ai/lora-image-generation',
name: 'LoRA Image Generation',
description: 'Fine-tuned image generation with LoRA models',
category: 'image-generation',
type: 'image',
pricing: {
per_image: 0.025
}
},
{
id: 'fal-ai/realtime-stable-diffusion',
name: 'Realtime Stable Diffusion',
description: 'Fast, real-time image generation',
category: 'image-generation',
type: 'image',
pricing: {
per_image: 0.008
}
},
{
id: 'fal-ai/photomaker',
name: 'PhotoMaker',
description: 'Generate personalized photos with AI',
category: 'image-generation',
type: 'image',
pricing: {
per_image: 0.04
}
}
]
console.log(`Fal.ai returned ${knownModels.length} predefined models`)
return knownModels
} catch (error) {
console.error('Error fetching Fal models:', error)
throw error
}
}
// 将 Fal.ai 模型转换为我们数据库的格式
transformModelForDB(model: FalModel) {
if (!model.id) {
return null
}
const modelName = model.name || model.id.split('/').pop() || model.id
const provider = this.extractProvider(model.id)
// 根据定价结构计算每1k的成本
let inputCostPer1k = null
let outputCostPer1k = null
if (model.pricing) {
if (model.pricing.per_image) {
// 图像生成假设1k token约等于1张图片
inputCostPer1k = model.pricing.per_image * 1000
outputCostPer1k = model.pricing.per_image * 1000
} else if (model.pricing.per_second) {
// 视频生成假设1k token约等于10秒视频
inputCostPer1k = model.pricing.per_second * 10000
outputCostPer1k = model.pricing.per_second * 10000
} else if (model.pricing.per_request) {
inputCostPer1k = model.pricing.per_request * 1000
outputCostPer1k = model.pricing.per_request * 1000
}
}
return {
modelId: model.id,
name: modelName,
provider: provider,
serviceProvider: 'fal',
outputType: model.type || 'image',
description: model.description || null,
maxTokens: null, // Fal.ai 模型通常不使用 token 限制
inputCostPer1k: inputCostPer1k,
outputCostPer1k: outputCostPer1k,
supportedFeatures: {
type: model.type,
category: model.category,
max_resolution: model.max_resolution,
supported_formats: model.supported_formats,
},
metadata: {
original: model,
pricing_model: model.pricing
},
}
}
private extractProvider(modelId: string): string {
// Fal.ai 模型ID格式通常是 fal-ai/model-name
const parts = modelId.split('/')
if (parts.length > 1) {
const provider = parts[0]
const providerMap: Record<string, string> = {
'fal-ai': 'Fal.ai',
'stabilityai': 'Stability AI',
'runwayml': 'Runway ML',
'openai': 'OpenAI',
}
return providerMap[provider] || provider.charAt(0).toUpperCase() + provider.slice(1)
}
return 'Fal.ai'
}
// 执行图像生成请求
async generateImage(params: {
model: string
prompt: string
image_size?: string
num_inference_steps?: number
guidance_scale?: number
num_images?: number
seed?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, {
method: 'POST',
headers: {
'Authorization': `Key ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
prompt: params.prompt,
image_size: params.image_size || '1024x1024',
num_inference_steps: params.num_inference_steps || 25,
guidance_scale: params.guidance_scale || 7.5,
num_images: params.num_images || 1,
...(params.seed && { seed: params.seed }),
}),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling Fal.ai image generation:', error)
throw error
}
}
// 执行视频生成请求
async generateVideo(params: {
model: string
prompt?: string
image_url?: string
motion_bucket_id?: number
fps?: number
num_frames?: number
}) {
try {
const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, {
method: 'POST',
headers: {
'Authorization': `Key ${this.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
prompt: params.prompt,
image_url: params.image_url,
motion_bucket_id: params.motion_bucket_id || 127,
fps: params.fps || 6,
num_frames: params.num_frames || 25,
}),
})
if (!response.ok) {
const errorData = await response.json().catch(() => ({}))
throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
}
return await response.json()
} catch (error) {
console.error('Error calling Fal.ai video generation:', error)
throw error
}
}
}