feat: 支持 llmtuner
perf: 优化流水线
This commit is contained in:
@@ -23,6 +23,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"identifier": False,
|
||||
@@ -47,29 +48,7 @@ DEFAULT_CONFIG = {
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "",
|
||||
"project_atri": {
|
||||
"enable": False,
|
||||
"long_term_memory": {
|
||||
"enable": False,
|
||||
"summary_threshold_cnt": 5,
|
||||
"embedding_provider_id": "",
|
||||
"summarize_provider_id": "",
|
||||
},
|
||||
"active_message": {"enable": False},
|
||||
"vision": {
|
||||
"enable": False,
|
||||
"provider_id_or_ofa_model_path": "",
|
||||
"reply_meme_prob": 0.4,
|
||||
"reply_meme_similar_threshold": 0.7,
|
||||
},
|
||||
"persona": "",
|
||||
"split_response": True,
|
||||
"chat_provider_id": "",
|
||||
"chat_base_model_path": "",
|
||||
"chat_adapter_model_path": "",
|
||||
"quantization_bit": 4,
|
||||
},
|
||||
"plugin_repo_mirror": ""
|
||||
}
|
||||
|
||||
|
||||
@@ -296,6 +275,16 @@ CONFIG_METADATA_2 = {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
},
|
||||
"llmtuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
"enable": True,
|
||||
"base_model_path": "",
|
||||
"adapter_model_path": "",
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
}
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
@@ -324,6 +313,31 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
"type": "string",
|
||||
"hint": "基座模型路径。",
|
||||
},
|
||||
"adapter_model_path": {
|
||||
"description": "Adapter 模型路径",
|
||||
"type": "string",
|
||||
"hint": "Adapter 模型路径。如 Lora",
|
||||
},
|
||||
"llmtuner_template": {
|
||||
"description": "template",
|
||||
"type": "string",
|
||||
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
|
||||
},
|
||||
"finetuning_type": {
|
||||
"description": "微调类型",
|
||||
"type": "string",
|
||||
"hint": "微调类型。如 `lora`",
|
||||
},
|
||||
"quantization_bit": {
|
||||
"description": "量化位数",
|
||||
"type": "int",
|
||||
"hint": "量化位数。如 4",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "文本生成模型",
|
||||
"type": "object",
|
||||
@@ -347,6 +361,11 @@ CONFIG_METADATA_2 = {
|
||||
"description": "大语言模型设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "是否启用大语言模型聊天。默认启用",
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"type": "string",
|
||||
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
class ProcessStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
self.llm_request_sub_stage = LLMRequestSubStage()
|
||||
@@ -23,10 +24,13 @@ class ProcessStage(Stage):
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
|
||||
if not activated_handlers:
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
else:
|
||||
if activated_handlers:
|
||||
async for _ in self.star_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
if not event._has_send_oper:
|
||||
'''当没有发送操作'''
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
@@ -44,6 +44,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
self._result: MessageEventResult = None
|
||||
'''消息事件的结果'''
|
||||
|
||||
self._has_send_oper = False
|
||||
'''是否有过至少一次发送操作'''
|
||||
|
||||
def get_platform_name(self):
|
||||
return self.platform_meta.name
|
||||
@@ -227,4 +230,5 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''
|
||||
发送消息到消息平台。
|
||||
'''
|
||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
||||
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
|
||||
self._has_send_oper = True
|
||||
@@ -28,6 +28,9 @@ class ProviderManager():
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -65,7 +65,7 @@ class Provider(abc.ABC):
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
'''获得提供商 Key'''
|
||||
return self.provider_config['key']
|
||||
return self.provider_config.get("key", [])
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_key(self, key: str):
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import ProviderMetaData, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
|
||||
class LLMTunerModelLoader(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
self.base_model_path = provider_config['base_model_path']
|
||||
self.adapter_model_path = provider_config['adapter_model_path']
|
||||
self.model = ChatModel({
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config['llmtuner_template'],
|
||||
"finetuning_type": provider_config['finetuning_type'],
|
||||
"quantization_bit": provider_config['quantization_bit'],
|
||||
})
|
||||
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str] = None,
|
||||
tools = None,
|
||||
contexts: List=None,
|
||||
**kwargs) -> str:
|
||||
|
||||
system_prompt = ""
|
||||
if not contexts:
|
||||
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(contexts):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + contexts.pop(idx)["content"]
|
||||
|
||||
logger.debug(f"请求上下文:{contexts}")
|
||||
logger.debug(f"请求 System Prompt:{system_prompt}")
|
||||
|
||||
conf = {
|
||||
"messages": contexts,
|
||||
"system": system_prompt,
|
||||
}
|
||||
if tools:
|
||||
conf['tools'] = tools
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
logger.debug(f"返回上下文:{responses}")
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]))
|
||||
self.session_memory[session_id].append({"role": "user", "content": prompt})
|
||||
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
|
||||
return responses[-1].response_text
|
||||
|
||||
async def forget(self, session_id):
|
||||
logger.info("llmtuner reset")
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
async def set_key(self, key):
|
||||
pass
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
@@ -133,6 +133,15 @@ class Context:
|
||||
注册一个 LLM Provider。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''
|
||||
通过 ID 获取 LLM Provider。
|
||||
'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
chromadb
|
||||
openai
|
||||
@@ -1,2 +0,0 @@
|
||||
llmtuner
|
||||
bitsandbytes
|
||||
Reference in New Issue
Block a user