From 85380ade6a6d5bc5a301a057bf67ebdeb9fd9200 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 11 Dec 2024 23:53:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20llmtuner=20perf:?= =?UTF-8?q?=20=E4=BC=98=E5=8C=96=E6=B5=81=E6=B0=B4=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 65 +++++++---- astrbot/core/pipeline/process_stage/stage.py | 14 ++- astrbot/core/platform/astr_message_event.py | 6 +- astrbot/core/provider/manager.py | 3 + astrbot/core/provider/provider.py | 2 +- .../core/provider/sources/llmtuner_source.py | 107 ++++++++++++++++++ astrbot/core/star/context.py | 9 ++ requirements_atri_base.txt | 2 - requirements_atri_ft.txt | 2 - 9 files changed, 176 insertions(+), 34 deletions(-) create mode 100644 astrbot/core/provider/sources/llmtuner_source.py delete mode 100644 requirements_atri_base.txt delete mode 100644 requirements_atri_ft.txt diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 18e1d054a..e6148951a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -23,6 +23,7 @@ DEFAULT_CONFIG = { }, "provider": [], "provider_settings": { + "enable": True, "wake_prefix": "", "web_search": False, "identifier": False, @@ -47,29 +48,7 @@ DEFAULT_CONFIG = { "log_level": "INFO", "t2i_endpoint": "", "pip_install_arg": "", - "plugin_repo_mirror": "", - "project_atri": { - "enable": False, - "long_term_memory": { - "enable": False, - "summary_threshold_cnt": 5, - "embedding_provider_id": "", - "summarize_provider_id": "", - }, - "active_message": {"enable": False}, - "vision": { - "enable": False, - "provider_id_or_ofa_model_path": "", - "reply_meme_prob": 0.4, - "reply_meme_similar_threshold": 0.7, - }, - "persona": "", - "split_response": True, - "chat_provider_id": "", - "chat_base_model_path": "", - "chat_adapter_model_path": "", - "quantization_bit": 4, - }, + "plugin_repo_mirror": "" } @@ -296,6 +275,16 @@ CONFIG_METADATA_2 = { "model": "glm-4-flash", }, }, + "llmtuner": { + "id": "llmtuner_default", + "type": "llm_tuner", + "enable": True, + "base_model_path": "", + "adapter_model_path": "", + "llmtuner_template": "", + "finetuning_type": "lora", + "quantization_bit": 4, + } }, "items": { "id": { @@ -324,6 +313,31 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。", }, + "base_model_path": { + "description": "基座模型路径", + "type": "string", + "hint": "基座模型路径。", + }, + "adapter_model_path": { + "description": "Adapter 模型路径", + "type": "string", + "hint": "Adapter 模型路径。如 Lora", + }, + "llmtuner_template": { + "description": "template", + "type": "string", + "hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。", + }, + "finetuning_type": { + "description": "微调类型", + "type": "string", + "hint": "微调类型。如 `lora`", + }, + "quantization_bit": { + "description": "量化位数", + "type": "int", + "hint": "量化位数。如 4", + }, "model_config": { "description": "文本生成模型", "type": "object", @@ -347,6 +361,11 @@ CONFIG_METADATA_2 = { "description": "大语言模型设置", "type": "object", "items": { + "enable": { + "description": "启用大语言模型聊天", + "type": "bool", + "hint": "是否启用大语言模型聊天。默认启用", + }, "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", "type": "string", diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index b5a50c416..faf3bba45 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -10,6 +10,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata class ProcessStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager self.llm_request_sub_stage = LLMRequestSubStage() @@ -23,10 +24,13 @@ class ProcessStage(Stage): ''' activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") - if not activated_handlers: - async for _ in self.llm_request_sub_stage.process(event): - yield - else: + if activated_handlers: async for _ in self.star_request_sub_stage.process(event): yield - \ No newline at end of file + + if self.ctx.astrbot_config['provider_settings'].get('enable', True): + if not event._has_send_oper: + '''当没有发送操作''' + if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): + async for _ in self.llm_request_sub_stage.process(event): + yield \ No newline at end of file diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 058d42950..2b9366497 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -44,6 +44,9 @@ class AstrMessageEvent(abc.ABC): self._result: MessageEventResult = None '''消息事件的结果''' + + self._has_send_oper = False + '''是否有过至少一次发送操作''' def get_platform_name(self): return self.platform_meta.name @@ -227,4 +230,5 @@ class AstrMessageEvent(abc.ABC): ''' 发送消息到消息平台。 ''' - await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name) \ No newline at end of file + await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name) + self._has_send_oper = True \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index c4e543e7d..6d6610ba7 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -28,6 +28,9 @@ class ProviderManager(): match provider_cfg['type']: case "openai_chat_completion": from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 + case "llm_tuner": + logger.info("加载 LLM Tuner 工具 ...") + from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 async def initialize(self): diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 47fa8f3cc..1cc267679 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -65,7 +65,7 @@ class Provider(abc.ABC): def get_keys(self) -> List[str]: '''获得提供商 Key''' - return self.provider_config['key'] + return self.provider_config.get("key", []) @abc.abstractmethod def set_key(self, key: str): diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py new file mode 100644 index 000000000..751b2ce13 --- /dev/null +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -0,0 +1,107 @@ +import json +import os +from llmtuner.chat import ChatModel +from typing import List +from .. import ProviderMetaData, Provider +from astrbot.core.db import BaseDatabase +from astrbot import logger + +from ..register import register_provider_adapter + +@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型") +class LLMTunerModelLoader(Provider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history = True + ) -> None: + super().__init__(provider_config, provider_settings, persistant_history, db_helper) + self.base_model_path = provider_config['base_model_path'] + self.adapter_model_path = provider_config['adapter_model_path'] + self.model = ChatModel({ + "model_name_or_path": self.base_model_path, + "adapter_name_or_path": self.adapter_model_path, + "template": provider_config['llmtuner_template'], + "finetuning_type": provider_config['finetuning_type'], + "quantization_bit": provider_config['quantization_bit'], + }) + self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path)) + + async def assemble_context(self, text: str, image_urls: List[str] = None): + ''' + 组装上下文。 + ''' + return {"role": "user", "content": text} + + async def text_chat(self, + prompt: str, + session_id: str, + image_urls: List[str] = None, + tools = None, + contexts: List=None, + **kwargs) -> str: + + system_prompt = "" + if not contexts: + contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}] + system_prompt = self.curr_personality["prompt"] + else: + # 提取出系统提示 + system_idxs = [] + for idx, context in enumerate(contexts): + if context["role"] == "system": + system_idxs.append(idx) + for idx in reversed(system_idxs): + system_prompt += " " + contexts.pop(idx)["content"] + + logger.debug(f"请求上下文:{contexts}") + logger.debug(f"请求 System Prompt:{system_prompt}") + + conf = { + "messages": contexts, + "system": system_prompt, + } + if tools: + conf['tools'] = tools + + responses = await self.model.achat(**conf) + logger.debug(f"返回上下文:{responses}") + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id])) + self.session_memory[session_id].append({"role": "user", "content": prompt}) + self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text}) + return responses[-1].response_text + + async def forget(self, session_id): + logger.info("llmtuner reset") + self.session_memory[session_id] = [] + return True + + async def get_current_key(self): + return "none" + + async def set_key(self, key): + pass + + async def get_models(self): + return [self.get_model()] + + + async def get_human_readable_context(self, session_id, page, page_size): + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + contexts = [] + for record in self.session_memory[session_id]: + if record['role'] == "user": + contexts.append(f"User: {record['content']}") + elif record['role'] == "assistant": + contexts.append(f"Assistant: {record['content']}") + + # 计算分页 + paged_contexts = contexts[(page-1)*page_size:page*page_size] + total_pages = len(contexts) // page_size + if len(contexts) % page_size != 0: + total_pages += 1 + + return paged_contexts, total_pages \ No newline at end of file diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 130b9097d..8f59df8eb 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -133,6 +133,15 @@ class Context: 注册一个 LLM Provider。 ''' self.provider_manager.provider_insts.append(provider) + + def get_provider_by_id(self, provider_id: str) -> Provider: + ''' + 通过 ID 获取 LLM Provider。 + ''' + for provider in self.provider_manager.provider_insts: + if provider.meta().id == provider_id: + return provider + return None def get_all_providers(self) -> List[Provider]: ''' diff --git a/requirements_atri_base.txt b/requirements_atri_base.txt deleted file mode 100644 index ea90da291..000000000 --- a/requirements_atri_base.txt +++ /dev/null @@ -1,2 +0,0 @@ -chromadb -openai \ No newline at end of file diff --git a/requirements_atri_ft.txt b/requirements_atri_ft.txt deleted file mode 100644 index d2e4234f2..000000000 --- a/requirements_atri_ft.txt +++ /dev/null @@ -1,2 +0,0 @@ -llmtuner -bitsandbytes \ No newline at end of file