feat: 支持 llmtuner

perf: 优化流水线
This commit is contained in:
Soulter
2024-12-11 23:53:10 +08:00
parent 86f53deade
commit 85380ade6a
9 changed files with 176 additions and 34 deletions
+42 -23
View File
@@ -23,6 +23,7 @@ DEFAULT_CONFIG = {
},
"provider": [],
"provider_settings": {
"enable": True,
"wake_prefix": "",
"web_search": False,
"identifier": False,
@@ -47,29 +48,7 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "",
"project_atri": {
"enable": False,
"long_term_memory": {
"enable": False,
"summary_threshold_cnt": 5,
"embedding_provider_id": "",
"summarize_provider_id": "",
},
"active_message": {"enable": False},
"vision": {
"enable": False,
"provider_id_or_ofa_model_path": "",
"reply_meme_prob": 0.4,
"reply_meme_similar_threshold": 0.7,
},
"persona": "",
"split_response": True,
"chat_provider_id": "",
"chat_base_model_path": "",
"chat_adapter_model_path": "",
"quantization_bit": 4,
},
"plugin_repo_mirror": ""
}
@@ -296,6 +275,16 @@ CONFIG_METADATA_2 = {
"model": "glm-4-flash",
},
},
"llmtuner": {
"id": "llmtuner_default",
"type": "llm_tuner",
"enable": True,
"base_model_path": "",
"adapter_model_path": "",
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
}
},
"items": {
"id": {
@@ -324,6 +313,31 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
},
"base_model_path": {
"description": "基座模型路径",
"type": "string",
"hint": "基座模型路径。",
},
"adapter_model_path": {
"description": "Adapter 模型路径",
"type": "string",
"hint": "Adapter 模型路径。如 Lora",
},
"llmtuner_template": {
"description": "template",
"type": "string",
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
},
"finetuning_type": {
"description": "微调类型",
"type": "string",
"hint": "微调类型。如 `lora`",
},
"quantization_bit": {
"description": "量化位数",
"type": "int",
"hint": "量化位数。如 4",
},
"model_config": {
"description": "文本生成模型",
"type": "object",
@@ -347,6 +361,11 @@ CONFIG_METADATA_2 = {
"description": "大语言模型设置",
"type": "object",
"items": {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "是否启用大语言模型聊天。默认启用",
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"type": "string",
+9 -5
View File
@@ -10,6 +10,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
class ProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
@@ -23,10 +24,13 @@ class ProcessStage(Stage):
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
if not activated_handlers:
async for _ in self.llm_request_sub_stage.process(event):
yield
else:
if activated_handlers:
async for _ in self.star_request_sub_stage.process(event):
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
'''当没有发送操作'''
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
async for _ in self.llm_request_sub_stage.process(event):
yield
+5 -1
View File
@@ -44,6 +44,9 @@ class AstrMessageEvent(abc.ABC):
self._result: MessageEventResult = None
'''消息事件的结果'''
self._has_send_oper = False
'''是否有过至少一次发送操作'''
def get_platform_name(self):
    '''Return the ``name`` attribute of this event's platform metadata.'''
    platform_meta = self.platform_meta
    return platform_meta.name
@@ -227,4 +230,5 @@ class AstrMessageEvent(abc.ABC):
'''
发送消息到消息平台。
'''
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
+3
View File
@@ -28,6 +28,9 @@ class ProviderManager():
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
async def initialize(self):
+1 -1
View File
@@ -65,7 +65,7 @@ class Provider(abc.ABC):
def get_keys(self) -> List[str]:
    '''Return the provider's configured API keys.

    Reads the ``key`` entry of ``self.provider_config``. Providers whose
    configuration carries no ``key`` entry yield an empty list instead of
    raising ``KeyError``.
    '''
    # Only the tolerant lookup is kept: a preceding ``provider_config['key']``
    # return would make this line unreachable and fail for key-less configs.
    return self.provider_config.get("key", [])
@abc.abstractmethod
def set_key(self, key: str):
@@ -0,0 +1,107 @@
import json
import os
from llmtuner.chat import ChatModel
from typing import List
from .. import ProviderMetaData, Provider
from astrbot.core.db import BaseDatabase
from astrbot import logger
from ..register import register_provider_adapter
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
class LLMTunerModelLoader(Provider):
    '''Provider adapter that answers chat requests with a local llmtuner ``ChatModel``.

    The model is built once in ``__init__`` from the base-model path, adapter
    path, template, finetuning type and quantization bit found in
    ``provider_config``.
    '''

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
        db_helper: BaseDatabase,
        persistant_history = True
    ) -> None:
        '''Load the model described by ``provider_config``.

        Args:
            provider_config: Must contain ``base_model_path``,
                ``adapter_model_path``, ``llmtuner_template``,
                ``finetuning_type`` and ``quantization_bit``.
            provider_settings: Forwarded unchanged to the base ``Provider``.
            db_helper: Database helper forwarded to the base ``Provider``.
            persistant_history: Forwarded unchanged to the base ``Provider``.

        Raises:
            KeyError: If one of the required config entries is missing.
        '''
        super().__init__(provider_config, provider_settings, persistant_history, db_helper)
        self.base_model_path = provider_config['base_model_path']
        self.adapter_model_path = provider_config['adapter_model_path']
        self.model = ChatModel({
            "model_name_or_path": self.base_model_path,
            "adapter_name_or_path": self.adapter_model_path,
            "template": provider_config['llmtuner_template'],
            "finetuning_type": provider_config['finetuning_type'],
            "quantization_bit": provider_config['quantization_bit'],
        })
        # The reported model name is "<base dir name>_<adapter dir name>".
        self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))

    async def assemble_context(self, text: str, image_urls: List[str] = None):
        '''Build a single user message from ``text``.

        ``image_urls`` is accepted but not used by this adapter.
        '''
        return {"role": "user", "content": text}

    async def text_chat(self,
                        prompt: str,
                        session_id: str,
                        image_urls: List[str] = None,
                        tools = None,
                        contexts: List=None,
                        **kwargs) -> str:
        '''Send a chat request to the local model and record the exchange.

        Args:
            prompt: The user's message for this turn.
            session_id: Key into ``self.session_memory`` for the conversation.
            image_urls: Accepted but not used by this adapter.
            tools: Forwarded to the model as ``tools`` when truthy.
            contexts: Optional explicit message list. When empty/None, the
                session history plus ``prompt`` is used and the system prompt
                comes from ``self.curr_personality["prompt"]``. When given,
                messages with role ``system`` are taken out and merged into
                the system prompt; the caller's list is left untouched.
            **kwargs: Ignored.

        Returns:
            The ``response_text`` of the last response returned by the model.
        '''
        if not contexts:
            contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
            system_prompt = self.curr_personality["prompt"]
        else:
            # Separate system messages from the rest by building new objects,
            # so the list passed in by the caller is never mutated. System
            # prompts are concatenated in the order they appear.
            system_prompt = "".join(
                " " + message["content"]
                for message in contexts
                if message["role"] == "system"
            )
            contexts = [message for message in contexts if message["role"] != "system"]

        logger.debug(f"请求上下文:{contexts}")
        logger.debug(f"请求 System Prompt{system_prompt}")

        conf = {
            "messages": contexts,
            "system": system_prompt,
        }
        if tools:
            conf['tools'] = tools

        responses = await self.model.achat(**conf)
        logger.debug(f"返回上下文:{responses}")
        reply_text = responses[-1].response_text

        # Record the new turn first, then persist, so the stored history
        # includes the exchange that just happened.
        history = self.session_memory[session_id]
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": reply_text})
        self.db_helper.update_llm_history(session_id, json.dumps(history))

        return reply_text

    async def forget(self, session_id):
        '''Reset the in-memory history of ``session_id`` to an empty list.'''
        logger.info("llmtuner reset")
        self.session_memory[session_id] = []
        return True

    # NOTE(review): get_current_key and set_key are coroutines here; confirm
    # that the base class and callers expect to await them.
    async def get_current_key(self):
        '''Return the placeholder ``"none"``; this adapter has no API key.'''
        return "none"

    async def set_key(self, key):
        '''No-op: this adapter has no API key to switch.'''
        pass

    async def get_models(self):
        '''Return a one-element list holding the currently set model name.'''
        return [self.get_model()]

    async def get_human_readable_context(self, session_id, page, page_size):
        '''Return one page of the session history as display strings.

        Args:
            session_id: Conversation to read from ``self.session_memory``.
            page: 1-based page number.
            page_size: Number of entries per page.

        Returns:
            A tuple ``(paged_contexts, total_pages)``. Only ``user`` and
            ``assistant`` records are rendered; other roles are skipped.

        Raises:
            Exception: If ``session_id`` is not in ``self.session_memory``.
        '''
        if session_id not in self.session_memory:
            raise Exception("会话 ID 不存在")

        labels = {"user": "User", "assistant": "Assistant"}
        contexts = [
            f"{labels[record['role']]}: {record['content']}"
            for record in self.session_memory[session_id]
            if record['role'] in labels
        ]

        # Pagination: slice out the requested page, then round the page count up.
        paged_contexts = contexts[(page-1)*page_size:page*page_size]
        total_pages, remainder = divmod(len(contexts), page_size)
        if remainder != 0:
            total_pages += 1

        return paged_contexts, total_pages
+9
View File
@@ -133,6 +133,15 @@ class Context:
注册一个 LLM Provider。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> "Provider | None":
    '''Look up a registered LLM provider by its ID.

    Scans ``self.provider_manager.provider_insts`` in order and compares each
    provider's ``meta().id`` with ``provider_id``.

    Returns:
        The first matching provider, or ``None`` when no provider has that ID.
    '''
    # The annotation is a string so that the "not found" case (None) is part
    # of the declared return type without evaluating names at definition time.
    return next(
        (
            provider
            for provider in self.provider_manager.provider_insts
            if provider.meta().id == provider_id
        ),
        None,
    )
def get_all_providers(self) -> List[Provider]:
'''
-2
View File
@@ -1,2 +0,0 @@
chromadb
openai
-2
View File
@@ -1,2 +0,0 @@
llmtuner
bitsandbytes