add model type
This commit is contained in:
parent
be53af4638
commit
ff67d48bcc
@ -7,6 +7,8 @@ interface OpenRouterModel {
|
|||||||
modality: string
|
modality: string
|
||||||
tokenizer: string
|
tokenizer: string
|
||||||
instruct_type?: string
|
instruct_type?: string
|
||||||
|
input_modalities?: string[]
|
||||||
|
output_modalities?: string[]
|
||||||
}
|
}
|
||||||
pricing: {
|
pricing: {
|
||||||
prompt: string
|
prompt: string
|
||||||
@ -63,12 +65,15 @@ export class OpenRouterService {
|
|||||||
|
|
||||||
// 将 OpenRouter 模型转换为我们数据库的格式
|
// 将 OpenRouter 模型转换为我们数据库的格式
|
||||||
transformModelForDB(model: OpenRouterModel) {
|
transformModelForDB(model: OpenRouterModel) {
|
||||||
|
const allOutputTypes = this.determineAllOutputTypes(model)
|
||||||
|
const primaryOutputType = this.determinePrimaryOutputType(allOutputTypes)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
modelId: model.id,
|
modelId: model.id,
|
||||||
name: model.name,
|
name: model.name,
|
||||||
provider: this.extractProvider(model.id),
|
provider: this.extractProvider(model.id),
|
||||||
serviceProvider: 'openrouter',
|
serviceProvider: 'openrouter',
|
||||||
outputType: 'text',
|
outputType: primaryOutputType, // 主要输出类型
|
||||||
description: model.description || null,
|
description: model.description || null,
|
||||||
maxTokens: model.context_length || null,
|
maxTokens: model.context_length || null,
|
||||||
inputCostPer1k: parseFloat(model.pricing.prompt) * 1000 || null,
|
inputCostPer1k: parseFloat(model.pricing.prompt) * 1000 || null,
|
||||||
@ -77,6 +82,9 @@ export class OpenRouterService {
|
|||||||
modality: model.architecture.modality,
|
modality: model.architecture.modality,
|
||||||
tokenizer: model.architecture.tokenizer,
|
tokenizer: model.architecture.tokenizer,
|
||||||
instruct_type: model.architecture.instruct_type,
|
instruct_type: model.architecture.instruct_type,
|
||||||
|
input_modalities: model.architecture.input_modalities,
|
||||||
|
output_modalities: model.architecture.output_modalities,
|
||||||
|
all_output_types: allOutputTypes, // 记录所有支持的输出类型
|
||||||
is_moderated: model.top_provider.is_moderated,
|
is_moderated: model.top_provider.is_moderated,
|
||||||
max_completion_tokens: model.top_provider.max_completion_tokens,
|
max_completion_tokens: model.top_provider.max_completion_tokens,
|
||||||
},
|
},
|
||||||
@ -88,6 +96,83 @@ export class OpenRouterService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据模型的架构信息确定所有支持的输出类型
|
||||||
|
private determineAllOutputTypes(model: OpenRouterModel): string[] {
|
||||||
|
const supportedTypes: Set<string> = new Set()
|
||||||
|
|
||||||
|
// 优先使用 output_modalities 字段
|
||||||
|
if (model.architecture.output_modalities && Array.isArray(model.architecture.output_modalities)) {
|
||||||
|
model.architecture.output_modalities.forEach(modality => {
|
||||||
|
supportedTypes.add(modality)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 output_modalities,则通过 modality 字段推断
|
||||||
|
if (model.architecture.modality) {
|
||||||
|
const modality = model.architecture.modality.toLowerCase()
|
||||||
|
|
||||||
|
// 检查 modality 字段中的输出类型指示
|
||||||
|
if (modality.includes('->image') || (modality.includes('image') && modality.includes('->'))) {
|
||||||
|
supportedTypes.add('image')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modality.includes('->video') || (modality.includes('video') && modality.includes('->'))) {
|
||||||
|
supportedTypes.add('video')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modality.includes('->audio') || (modality.includes('audio') && modality.includes('->'))) {
|
||||||
|
supportedTypes.add('audio')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modality.includes('->text') || (modality.includes('text') && modality.includes('->'))) {
|
||||||
|
supportedTypes.add('text')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通过模型名称进行启发式判断
|
||||||
|
const modelName = model.name?.toLowerCase() || ''
|
||||||
|
const modelId = model.id?.toLowerCase() || ''
|
||||||
|
|
||||||
|
// 图像生成模型关键词
|
||||||
|
const imageKeywords = ['dall-e', 'dalle', 'midjourney', 'stable-diffusion', 'flux', 'imagen', 'ideogram']
|
||||||
|
if (imageKeywords.some(keyword => modelName.includes(keyword) || modelId.includes(keyword))) {
|
||||||
|
supportedTypes.add('image')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 视频生成模型关键词
|
||||||
|
const videoKeywords = ['runway', 'gen-2', 'pika', 'video', 'sora']
|
||||||
|
if (videoKeywords.some(keyword => modelName.includes(keyword) || modelId.includes(keyword))) {
|
||||||
|
supportedTypes.add('video')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 音频生成模型关键词
|
||||||
|
const audioKeywords = ['whisper', 'musicgen', 'bark', 'audio', 'speech', 'tts']
|
||||||
|
if (audioKeywords.some(keyword => modelName.includes(keyword) || modelId.includes(keyword))) {
|
||||||
|
supportedTypes.add('audio')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有找到任何输出类型,默认为 text
|
||||||
|
if (supportedTypes.size === 0) {
|
||||||
|
supportedTypes.add('text')
|
||||||
|
}
|
||||||
|
|
||||||
|
return Array.from(supportedTypes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确定主要输出类型(用于数据库的 outputType 字段)
|
||||||
|
private determinePrimaryOutputType(supportedTypes: string[]): string {
|
||||||
|
// 按优先级返回主要类型
|
||||||
|
const priority = ['image', 'video', 'audio', 'text']
|
||||||
|
|
||||||
|
for (const type of priority) {
|
||||||
|
if (supportedTypes.includes(type)) {
|
||||||
|
return type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 'text' // 默认返回 text
|
||||||
|
}
|
||||||
|
|
||||||
private extractProvider(modelId: string): string {
|
private extractProvider(modelId: string): string {
|
||||||
// 从模型 ID 中提取提供商名称,如 "openai/gpt-4" -> "OpenAI"
|
// 从模型 ID 中提取提供商名称,如 "openai/gpt-4" -> "OpenAI"
|
||||||
const providerMap: Record<string, string> = {
|
const providerMap: Record<string, string> = {
|
||||||
|
Loading…
Reference in New Issue
Block a user