feat: update provider and provider source configuration handling

This commit is contained in:
Soulter
2025-12-15 12:31:29 +08:00
parent a70088b799
commit 45110200ea
9 changed files with 237 additions and 79 deletions
+4 -16
View File
@@ -1974,22 +1974,10 @@ CONFIG_METADATA_2 = {
"description": "API Base URL",
"type": "string",
},
"model_config": {
"description": "模型配置",
"type": "object",
"items": {
"model": {
"description": "模型名称",
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_tokens": {
"description": "模型最大输出长度(tokens",
"type": "int",
},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
},
"model": {
"description": "模型 ID",
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"dify_api_key": {
"description": "API Key",
+4
View File
@@ -37,6 +37,8 @@ class ProviderManager:
config = acm.confs["default"]
self.providers_config: list = config["provider"]
self.provider_sources_config: list = config.get("provider_sources", [])
self.merged_provider_config: dict = {}
"""合并 provider 和 provider_sources 配置后的结果"""
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -270,6 +272,8 @@ class ProviderManager:
merged_config["id"] = provider_config["id"]
provider_config = merged_config
self.merged_provider_config[provider_config["id"]] = provider_config
if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return
@@ -45,7 +45,7 @@ class ProviderAnthropic(Provider):
base_url=self.base_url,
)
self.set_model(provider_config["model_config"]["model"])
self.set_model(provider_config.get("model", "unknown"))
def _prepare_payload(self, messages: list[dict]):
"""准备 Anthropic API 的请求 payload
@@ -285,10 +285,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:
@@ -298,7 +297,6 @@ class ProviderAnthropic(Provider):
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
@@ -340,10 +338,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:
@@ -68,7 +68,7 @@ class ProviderGoogleGenAI(Provider):
self.api_base = self.api_base[:-1]
self._init_client()
self.set_model(provider_config["model_config"]["model"])
self.set_model(provider_config.get("model", "unknown"))
self._init_safety_settings()
def _init_client(self) -> None:
@@ -652,10 +652,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
@@ -705,10 +704,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
@@ -68,8 +68,7 @@ class ProviderOpenAIOfficial(Provider):
self.client.chat.completions.create,
).parameters.keys()
model_config = provider_config.get("model_config", {})
model = model_config.get("model", "unknown")
model = provider_config.get("model", "unknown")
self.set_model(model)
self.reasoning_key = "reasoning_content"
@@ -358,10 +357,9 @@ class ProviderOpenAIOfficial(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
# xAI origin search tool inject
self._maybe_inject_xai_search(payloads, **kwargs)
+1
View File
@@ -51,6 +51,7 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None:
"model",
"modalities",
"custom_extra_body",
"enable",
}
# Fields that should not go to source
+73 -2
View File
@@ -188,9 +188,80 @@ class ConfigRoute(Route):
"GET",
self.get_provider_source_models,
),
"/config/provider_sources/<provider_source_id>/update": (
"POST",
self.update_provider_source,
),
}
self.register_routes()
async def update_provider_source(self, provider_source_id: str):
"""更新或新增 provider_source,并重载关联的 providers"""
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
new_source_config = post_data.get("config") or post_data
original_id = post_data.get("original_id") or provider_source_id
if not isinstance(new_source_config, dict):
return Response().error("缺少或错误的配置数据").__dict__
# 确保配置中有 id 字段
if not new_source_config.get("id"):
new_source_config["id"] = original_id
provider_sources = self.config.get("provider_sources", [])
# 查找旧的 provider_source,若不存在则追加为新配置
target_idx = next(
(i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id),
-1,
)
old_id = original_id
if target_idx == -1:
provider_sources.append(new_source_config)
else:
old_id = provider_sources[target_idx].get("id")
provider_sources[target_idx] = new_source_config
# 更新引用了该 provider_source 的 providers
affected_providers = []
for provider in self.config.get("provider", []):
if provider.get("provider_source_id") == old_id:
provider["provider_source_id"] = new_source_config["id"]
affected_providers.append(provider)
# 写回配置
self.config["provider_sources"] = provider_sources
try:
save_config(self.config, self.config, is_core=True)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
# 重载受影响的 providers,使新的 source 配置生效
reload_errors = []
prov_mgr = self.core_lifecycle.provider_manager
for provider in affected_providers:
try:
await prov_mgr.reload(provider)
except Exception as e:
logger.error(traceback.format_exc())
reload_errors.append(f"{provider.get('id')}: {e}")
if reload_errors:
return (
Response()
.error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors))
.__dict__
)
return Response().ok(message="更新 provider source 成功").__dict__
async def get_provider_template(self):
provider_config = astrbot_config["provider"]
config_schema = {
@@ -449,8 +520,8 @@ class ConfigRoute(Route):
return Response().error("缺少参数 provider_type").__dict__
provider_type_ls = provider_type.split(",")
provider_list = []
astrbot_config = self.core_lifecycle.astrbot_config
for provider in astrbot_config["provider"]:
pc = self.core_lifecycle.provider_manager.merged_provider_config
for provider in pc.values():
if provider.get("provider_type", None) in provider_type_ls:
provider_list.append(provider)
return Response().ok(provider_list).__dict__
@@ -51,7 +51,7 @@
<v-list-item-title>{{ provider.id }}</v-list-item-title>
<v-list-item-subtitle>
{{ provider.type || provider.provider_type || tm('providerSelector.unknownType') }}
<span v-if="provider.model_config?.model">- {{ provider.model_config.model }}</span>
<span v-if="provider.model">- {{ provider.model }}</span>
</v-list-item-subtitle>
<template v-slot:append>
+141 -40
View File
@@ -100,14 +100,6 @@
<!-- Provider Source 配置 -->
<v-card-title class="pa-4 pb-2 d-flex align-center justify-space-between">
<span class="text-h4 font-weight-bold">{{ selectedProviderSource.id }}</span>
<v-switch
v-model="selectedProviderSource.enable"
density="compact"
hide-details
inset
color="primary"
@change="isSourceModified = true">
</v-switch>
</v-card-title>
<v-card-text class="pa-4">
@@ -195,7 +187,7 @@
<div class="d-flex align-center justify-space-between" style="width: 100%;">
<div>
<strong>{{ provider.id }}</strong>
<span class="text-caption text-grey ml-2">{{ provider.model_config?.model }}</span>
<span class="text-caption text-grey ml-2">{{ provider.model }}</span>
</div>
<div class="d-flex align-center" @click.stop>
<v-switch
@@ -214,6 +206,15 @@
@click.stop="testProvider(provider)"
class="mr-1">
</v-btn>
<v-btn
icon="mdi-content-save"
size="small"
variant="text"
color="success"
:loading="savingProviders.includes(provider.id)"
@click.stop="saveSingleProvider(provider)"
class="mr-1">
</v-btn>
<v-btn
icon="mdi-delete"
size="small"
@@ -359,7 +360,7 @@
</template>
<script setup>
import { ref, computed, onMounted } from 'vue'
import { ref, computed, onMounted, nextTick, watch } from 'vue'
import { useRouter } from 'vue-router'
import axios from 'axios'
import { useModuleI18n } from '@/i18n/composables'
@@ -378,10 +379,13 @@ const providerSources = ref([])
const providers = ref([])
const selectedProviderType = ref('chat_completion')
const selectedProviderSource = ref(null)
const selectedProviderSourceOriginalId = ref(null)
const editableProviderSource = ref(null)
const availableModels = ref([])
const loadingModels = ref(false)
const savingSource = ref(false)
const testingProviders = ref([])
const savingProviders = ref([])
const isSourceModified = ref(false)
const configSchema = ref({})
const providerTemplates = ref({})
@@ -396,6 +400,8 @@ const loading = ref(false)
const providerStatuses = ref([])
const showAgentRunnerDialog = ref(false)
let suppressSourceWatch = false
const snackbar = ref({
show: false,
message: '',
@@ -452,33 +458,46 @@ const sourceProviders = computed(() => {
//
const basicSourceConfig = computed(() => {
if (!selectedProviderSource.value) return null
const basicFields = ['id', 'key', 'api_base']
if (!editableProviderSource.value) return null
const fields = ['id', 'key', 'api_base']
const basic = {}
for (const [key, value] of Object.entries(selectedProviderSource.value)) {
if (basicFields.includes(key)) {
basic[key] = value
}
}
fields.forEach(field => {
Object.defineProperty(basic, field, {
get() {
return editableProviderSource.value[field]
},
set(val) {
editableProviderSource.value[field] = val
},
enumerable: true
})
})
return basic
})
//
const advancedSourceConfig = computed(() => {
if (!selectedProviderSource.value) return null
const basicFields = ['id', 'key', 'api_base', 'enable', 'type', 'provider_type']
if (!editableProviderSource.value) return null
const excluded = ['id', 'key', 'api_base', 'enable', 'type', 'provider_type']
const advanced = {}
for (const [key, value] of Object.entries(selectedProviderSource.value)) {
if (!basicFields.includes(key)) {
advanced[key] = value
}
for (const key of Object.keys(editableProviderSource.value)) {
if (excluded.includes(key)) continue
Object.defineProperty(advanced, key, {
get() {
return editableProviderSource.value[key]
},
set(val) {
editableProviderSource.value[key] = val
},
enumerable: true
})
}
return advanced
})
@@ -524,11 +543,19 @@ function showMessage(message, color = 'success') {
function selectProviderType(type) {
selectedProviderType.value = type
selectedProviderSource.value = null
selectedProviderSourceOriginalId.value = null
editableProviderSource.value = null
availableModels.value = []
}
function selectProviderSource(source) {
selectedProviderSource.value = source
selectedProviderSourceOriginalId.value = source?.id || null
suppressSourceWatch = true
editableProviderSource.value = source ? JSON.parse(JSON.stringify(source)) : null
nextTick(() => {
suppressSourceWatch = false
})
availableModels.value = []
isSourceModified.value = false
}
@@ -549,12 +576,14 @@ function addProviderSource(templateKey) {
type: template.type,
provider_type: template.provider_type,
enable: true,
// id, enable, type, provider_type, model_config provider
// id, enable, type, provider_type provider
...extractSourceFieldsFromTemplate(template)
}
providerSources.value.push(newSource)
selectedProviderSource.value = newSource
selectedProviderSourceOriginalId.value = newId
editableProviderSource.value = JSON.parse(JSON.stringify(newSource))
isSourceModified.value = true
}
@@ -562,7 +591,7 @@ function extractSourceFieldsFromTemplate(template) {
// source
const sourceFields = {}
const excludeKeys = [
'id', 'enable', 'type', 'provider_type', 'model_config', 'model',
'id', 'enable', 'type', 'provider_type', 'model',
'provider_source_id', 'provider', 'hint', 'modalities',
'custom_extra_body', 'custom_headers'
]
@@ -587,6 +616,8 @@ async function deleteProviderSource(source) {
if (selectedProviderSource.value?.id === source.id) {
selectedProviderSource.value = null
selectedProviderSourceOriginalId.value = null
editableProviderSource.value = null
}
await saveConfig()
@@ -600,12 +631,49 @@ async function saveProviderSource() {
if (!selectedProviderSource.value) return
savingSource.value = true
const originalId = selectedProviderSourceOriginalId.value || selectedProviderSource.value.id
try {
await saveConfig()
const response = await axios.post(
`/api/config/provider_sources/${originalId}/update`,
{
config: editableProviderSource.value,
original_id: originalId
}
)
if (response.data.status !== 'ok') {
throw new Error(response.data.message)
}
if (editableProviderSource.value.id !== originalId) {
providers.value = providers.value.map(p =>
p.provider_source_id === originalId
? { ...p, provider_source_id: editableProviderSource.value.id }
: p
)
selectedProviderSourceOriginalId.value = editableProviderSource.value.id
}
// source
const idx = providerSources.value.findIndex(ps => ps.id === originalId)
if (idx !== -1) {
providerSources.value[idx] = JSON.parse(JSON.stringify(editableProviderSource.value))
selectedProviderSource.value = providerSources.value[idx]
}
//
suppressSourceWatch = true
editableProviderSource.value = selectedProviderSource.value
nextTick(() => {
suppressSourceWatch = false
})
isSourceModified.value = false
showMessage(tm('providerSources.saveSuccess'))
showMessage(response.data.message || tm('providerSources.saveSuccess'))
return true
} catch (error) {
showMessage(error.message || tm('providerSources.saveError'), 'error')
showMessage(error.response?.data?.message || error.message || tm('providerSources.saveError'), 'error')
return false
} finally {
savingSource.value = false
}
@@ -617,13 +685,17 @@ async function fetchAvailableModels() {
//
if (isSourceModified.value) {
await saveProviderSource()
const saved = await saveProviderSource()
if (!saved) {
return
}
}
loadingModels.value = true
try {
const sourceId = editableProviderSource.value?.id || selectedProviderSource.value.id
const response = await axios.get(
`/api/config/provider_sources/${selectedProviderSource.value.id}/models`
`/api/config/provider_sources/${sourceId}/models`
)
if (response.data.status === 'ok') {
availableModels.value = response.data.data.models || []
@@ -643,10 +715,11 @@ async function fetchAvailableModels() {
function addModelProvider(modelName) {
if (!selectedProviderSource.value) return
const newId = `${selectedProviderSource.value.id}_${modelName}_${Date.now()}`
const sourceId = editableProviderSource.value?.id || selectedProviderSource.value.id
const newId = `${sourceId}/${modelName}`
const newProvider = {
id: newId,
provider_source_id: selectedProviderSource.value.id,
provider_source_id: sourceId,
model: modelName,
modalities: [],
custom_extra_body: {}
@@ -661,8 +734,8 @@ async function deleteProvider(provider) {
if (!confirm(tm('models.deleteConfirm', { id: provider.id }))) return
try {
await axios.post('/api/config/provider/delete', { id: provider.id })
providers.value = providers.value.filter(p => p.id !== provider.id)
await saveConfig()
showMessage(tm('models.deleteSuccess'))
} catch (error) {
showMessage(error.message || tm('models.deleteError'), 'error')
@@ -673,7 +746,7 @@ async function testProvider(provider) {
testingProviders.value.push(provider.id)
try {
const response = await axios.get('/api/config/provider/check_one', {
params: { provider_id: provider.id }
params: { id: provider.id }
})
if (response.data.status === 'ok') {
showMessage(tm('models.testSuccess', { id: provider.id }))
@@ -687,6 +760,27 @@ async function testProvider(provider) {
}
}
async function saveSingleProvider(provider) {
if (!provider) return
const exists = (config.value.provider || []).some(p => p.id === provider.id)
savingProviders.value.push(provider.id)
try {
const url = exists ? '/api/config/provider/update' : '/api/config/provider/new'
const payload = exists ? { id: provider.id, config: provider } : provider
const res = await axios.post(url, payload)
if (res.data.status === 'error') {
throw new Error(res.data.message)
}
showMessage(res.data.message || tm('providerSources.saveSuccess'))
await loadConfig()
} catch (err) {
showMessage(err.response?.data?.message || err.message || tm('providerSources.saveError'), 'error')
} finally {
savingProviders.value = savingProviders.value.filter(id => id !== provider.id)
}
}
async function saveConfig() {
try {
config.value.provider_sources = providerSources.value
@@ -737,6 +831,13 @@ onMounted(async () => {
await loadMetadata()
})
// provider source
watch(editableProviderSource, () => {
if (suppressSourceWatch) return
if (!editableProviderSource.value) return
isSourceModified.value = true
}, { deep: true })
// ===== chat =====
function getProviderType(provider) {
if (!provider) return undefined