add fal.ai
This commit is contained in:
parent
947b0d08d3
commit
907f33a794
@ -74,7 +74,7 @@ export default function AdminModelsPage() {
|
||||
const [selectedPlan, setSelectedPlan] = useState<string>('')
|
||||
const [selectedModels, setSelectedModels] = useState<string[]>([])
|
||||
const [showAvailableModels, setShowAvailableModels] = useState(false)
|
||||
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi'>('openrouter')
|
||||
const [selectedServiceProvider, setSelectedServiceProvider] = useState<'openrouter' | 'replicate' | 'uniapi' | 'fal'>('openrouter')
|
||||
|
||||
useEffect(() => {
|
||||
loadInitialData()
|
||||
@ -309,7 +309,7 @@ export default function AdminModelsPage() {
|
||||
<select
|
||||
value={selectedServiceProvider}
|
||||
onChange={(e) => {
|
||||
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi')
|
||||
setSelectedServiceProvider(e.target.value as 'openrouter' | 'replicate' | 'uniapi' | 'fal')
|
||||
setShowAvailableModels(false)
|
||||
setSelectedModels([])
|
||||
}}
|
||||
@ -318,6 +318,7 @@ export default function AdminModelsPage() {
|
||||
<option value="openrouter">OpenRouter (Text Models)</option>
|
||||
<option value="replicate">Replicate (Image/Video/Audio Models)</option>
|
||||
<option value="uniapi">UniAPI (Multi-modal Models)</option>
|
||||
<option value="fal">Fal.ai (AI Generation Models)</option>
|
||||
</select>
|
||||
</div>
|
||||
)}
|
||||
|
@ -3,6 +3,7 @@ import { prisma } from '@/lib/prisma'
|
||||
import { OpenRouterService } from '@/lib/openrouter'
|
||||
import { ReplicateService } from '@/lib/replicate'
|
||||
import { UniAPIService } from '@/lib/uniapi'
|
||||
import { FalService } from '@/lib/fal'
|
||||
|
||||
// GET /api/admin/models - 获取所有模型(按套餐分组)
|
||||
export async function GET() {
|
||||
@ -103,6 +104,13 @@ export async function POST(request: NextRequest) {
|
||||
availableModels = uniAPIModels
|
||||
.map(model => uniAPIService.transformModelForDB(model))
|
||||
.filter(model => model !== null) // 过滤掉可能的 null 值
|
||||
} else if (provider === 'fal') {
|
||||
// 从 Fal.ai 获取可用模型
|
||||
const falService = new FalService()
|
||||
const falModels = await falService.getAvailableModels()
|
||||
availableModels = falModels
|
||||
.map(model => falService.transformModelForDB(model))
|
||||
.filter(model => model !== null) // 过滤掉可能的 null 值
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
|
@ -4,6 +4,7 @@ import { prisma } from "@/lib/prisma";
|
||||
import { getPromptContent, calculateCost } from "@/lib/simulator-utils";
|
||||
import { consumeCreditForSimulation, getUserBalance } from "@/lib/services/credit";
|
||||
import { UniAPIService } from "@/lib/uniapi";
|
||||
import { FalService } from "@/lib/fal";
|
||||
|
||||
export async function POST(
|
||||
request: NextRequest,
|
||||
@ -150,6 +151,96 @@ export async function POST(
|
||||
});
|
||||
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
|
||||
}
|
||||
} else if (run.model.serviceProvider === 'fal') {
|
||||
const falService = new FalService();
|
||||
|
||||
if (run.model.outputType === 'image') {
|
||||
// 使用图像生成API
|
||||
const response = await falService.generateImage({
|
||||
model: run.model.modelId,
|
||||
prompt: finalPrompt,
|
||||
num_images: 1,
|
||||
});
|
||||
|
||||
// 创建模拟的流式响应
|
||||
const mockStream = new ReadableStream({
|
||||
start(controller) {
|
||||
const imageUrl = response.images?.[0]?.url || '';
|
||||
const result = {
|
||||
images: response.images,
|
||||
prompt: finalPrompt,
|
||||
model: run.model.modelId
|
||||
};
|
||||
|
||||
// 模拟流式数据
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: `Generated image: ${imageUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
|
||||
}]
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
usage: { prompt_tokens: 50, completion_tokens: 100 } // 估算的token使用量
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
|
||||
apiResponse = new Response(mockStream, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
}
|
||||
});
|
||||
} else if (run.model.outputType === 'video') {
|
||||
// 使用视频生成API
|
||||
const response = await falService.generateVideo({
|
||||
model: run.model.modelId,
|
||||
prompt: finalPrompt,
|
||||
});
|
||||
|
||||
// 创建模拟的流式响应
|
||||
const mockStream = new ReadableStream({
|
||||
start(controller) {
|
||||
const videoUrl = response.video?.url || '';
|
||||
const result = {
|
||||
video: response.video,
|
||||
prompt: finalPrompt,
|
||||
model: run.model.modelId
|
||||
};
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
choices: [{
|
||||
delta: { content: `Generated video: ${videoUrl}\n\nResult: ${JSON.stringify(result, null, 2)}` }
|
||||
}]
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: ${JSON.stringify({
|
||||
usage: { prompt_tokens: 50, completion_tokens: 150 } // 估算的token使用量
|
||||
})}\n\n`));
|
||||
|
||||
controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`));
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
|
||||
apiResponse = new Response(mockStream, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 对于其他类型,返回错误
|
||||
await prisma.simulatorRun.update({
|
||||
where: { id },
|
||||
data: {
|
||||
status: "failed",
|
||||
error: `Unsupported Fal.ai model type: ${run.model.outputType}`,
|
||||
},
|
||||
});
|
||||
return NextResponse.json({ error: "Unsupported model type" }, { status: 400 });
|
||||
}
|
||||
} else {
|
||||
await prisma.simulatorRun.update({
|
||||
where: { id },
|
||||
|
299
src/lib/fal.ts
Normal file
299
src/lib/fal.ts
Normal file
@ -0,0 +1,299 @@
|
||||
interface FalModel {
|
||||
id: string
|
||||
name?: string
|
||||
description?: string
|
||||
category?: string
|
||||
type?: 'image' | 'video' | 'audio' | 'text'
|
||||
input_schema?: {
|
||||
properties?: Record<string, any>
|
||||
}
|
||||
pricing?: {
|
||||
per_request?: number
|
||||
per_second?: number
|
||||
per_image?: number
|
||||
}
|
||||
max_resolution?: string
|
||||
supported_formats?: string[]
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
interface FalModelsResponse {
|
||||
models: FalModel[]
|
||||
error?: string
|
||||
}
|
||||
|
||||
export class FalService {
|
||||
private apiKey: string
|
||||
private baseUrl = 'https://fal.run/api/v1'
|
||||
|
||||
constructor() {
|
||||
this.apiKey = process.env.FAL_API_KEY || ''
|
||||
if (!this.apiKey) {
|
||||
throw new Error('FAL_API_KEY environment variable is required')
|
||||
}
|
||||
}
|
||||
|
||||
async getAvailableModels(): Promise<FalModel[]> {
|
||||
try {
|
||||
// Fal.ai 可能没有公开的模型列表 API,我们可以定义一些常用的模型
|
||||
// 这里基于 Fal.ai 的文档和常见模型创建一个静态列表
|
||||
const knownModels: FalModel[] = [
|
||||
{
|
||||
id: 'fal-ai/stable-diffusion-xl',
|
||||
name: 'Stable Diffusion XL',
|
||||
description: 'High-quality text-to-image generation with SDXL',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
max_resolution: '1024x1024',
|
||||
pricing: {
|
||||
per_image: 0.012
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/flux',
|
||||
name: 'FLUX.1 [schnell]',
|
||||
description: 'Fast, high-quality image generation',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
max_resolution: '1024x1024',
|
||||
pricing: {
|
||||
per_image: 0.003
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/flux-pro',
|
||||
name: 'FLUX.1 [pro]',
|
||||
description: 'Professional-grade image generation',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
max_resolution: '2048x2048',
|
||||
pricing: {
|
||||
per_image: 0.055
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/stable-video-diffusion',
|
||||
name: 'Stable Video Diffusion',
|
||||
description: 'Generate videos from images or text',
|
||||
category: 'video-generation',
|
||||
type: 'video',
|
||||
pricing: {
|
||||
per_second: 0.1
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/aura-sr',
|
||||
name: 'AuraSR',
|
||||
description: 'AI-powered image super resolution',
|
||||
category: 'image-enhancement',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.015
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/face-swap',
|
||||
name: 'Face Swap',
|
||||
description: 'Swap faces in images using AI',
|
||||
category: 'image-editing',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.05
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/remove-background',
|
||||
name: 'Background Removal',
|
||||
description: 'Remove backgrounds from images automatically',
|
||||
category: 'image-editing',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.01
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/lora-image-generation',
|
||||
name: 'LoRA Image Generation',
|
||||
description: 'Fine-tuned image generation with LoRA models',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.025
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/realtime-stable-diffusion',
|
||||
name: 'Realtime Stable Diffusion',
|
||||
description: 'Fast, real-time image generation',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.008
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'fal-ai/photomaker',
|
||||
name: 'PhotoMaker',
|
||||
description: 'Generate personalized photos with AI',
|
||||
category: 'image-generation',
|
||||
type: 'image',
|
||||
pricing: {
|
||||
per_image: 0.04
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
console.log(`Fal.ai returned ${knownModels.length} predefined models`)
|
||||
return knownModels
|
||||
} catch (error) {
|
||||
console.error('Error fetching Fal models:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 将 Fal.ai 模型转换为我们数据库的格式
|
||||
transformModelForDB(model: FalModel) {
|
||||
if (!model.id) {
|
||||
return null
|
||||
}
|
||||
|
||||
const modelName = model.name || model.id.split('/').pop() || model.id
|
||||
const provider = this.extractProvider(model.id)
|
||||
|
||||
// 根据定价结构计算每1k的成本
|
||||
let inputCostPer1k = null
|
||||
let outputCostPer1k = null
|
||||
|
||||
if (model.pricing) {
|
||||
if (model.pricing.per_image) {
|
||||
// 图像生成:假设1k token约等于1张图片
|
||||
inputCostPer1k = model.pricing.per_image * 1000
|
||||
outputCostPer1k = model.pricing.per_image * 1000
|
||||
} else if (model.pricing.per_second) {
|
||||
// 视频生成:假设1k token约等于10秒视频
|
||||
inputCostPer1k = model.pricing.per_second * 10000
|
||||
outputCostPer1k = model.pricing.per_second * 10000
|
||||
} else if (model.pricing.per_request) {
|
||||
inputCostPer1k = model.pricing.per_request * 1000
|
||||
outputCostPer1k = model.pricing.per_request * 1000
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
modelId: model.id,
|
||||
name: modelName,
|
||||
provider: provider,
|
||||
serviceProvider: 'fal',
|
||||
outputType: model.type || 'image',
|
||||
description: model.description || null,
|
||||
maxTokens: null, // Fal.ai 模型通常不使用 token 限制
|
||||
inputCostPer1k: inputCostPer1k,
|
||||
outputCostPer1k: outputCostPer1k,
|
||||
supportedFeatures: {
|
||||
type: model.type,
|
||||
category: model.category,
|
||||
max_resolution: model.max_resolution,
|
||||
supported_formats: model.supported_formats,
|
||||
},
|
||||
metadata: {
|
||||
original: model,
|
||||
pricing_model: model.pricing
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
private extractProvider(modelId: string): string {
|
||||
// Fal.ai 模型ID格式通常是 fal-ai/model-name
|
||||
const parts = modelId.split('/')
|
||||
if (parts.length > 1) {
|
||||
const provider = parts[0]
|
||||
|
||||
const providerMap: Record<string, string> = {
|
||||
'fal-ai': 'Fal.ai',
|
||||
'stabilityai': 'Stability AI',
|
||||
'runwayml': 'Runway ML',
|
||||
'openai': 'OpenAI',
|
||||
}
|
||||
|
||||
return providerMap[provider] || provider.charAt(0).toUpperCase() + provider.slice(1)
|
||||
}
|
||||
|
||||
return 'Fal.ai'
|
||||
}
|
||||
|
||||
// 执行图像生成请求
|
||||
async generateImage(params: {
|
||||
model: string
|
||||
prompt: string
|
||||
image_size?: string
|
||||
num_inference_steps?: number
|
||||
guidance_scale?: number
|
||||
num_images?: number
|
||||
seed?: number
|
||||
}) {
|
||||
try {
|
||||
const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Key ${this.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt: params.prompt,
|
||||
image_size: params.image_size || '1024x1024',
|
||||
num_inference_steps: params.num_inference_steps || 25,
|
||||
guidance_scale: params.guidance_scale || 7.5,
|
||||
num_images: params.num_images || 1,
|
||||
...(params.seed && { seed: params.seed }),
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}))
|
||||
throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
console.error('Error calling Fal.ai image generation:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 执行视频生成请求
|
||||
async generateVideo(params: {
|
||||
model: string
|
||||
prompt?: string
|
||||
image_url?: string
|
||||
motion_bucket_id?: number
|
||||
fps?: number
|
||||
num_frames?: number
|
||||
}) {
|
||||
try {
|
||||
const response = await fetch(`${this.baseUrl}/fal-ai/${params.model.replace('fal-ai/', '')}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Key ${this.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt: params.prompt,
|
||||
image_url: params.image_url,
|
||||
motion_bucket_id: params.motion_bucket_id || 127,
|
||||
fps: params.fps || 6,
|
||||
num_frames: params.num_frames || 25,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}))
|
||||
throw new Error(`Fal.ai API error: ${response.status} ${response.statusText} - ${errorData.error || ''}`)
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
console.error('Error calling Fal.ai video generation:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user