From 766f6a1ba2a7b35bebe773b4d8942602c52d918e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 21 Dec 2024 16:35:16 +0800 Subject: [PATCH] perf: use request_llm --- astrbot/api/event/__init__.py | 17 +- astrbot/api/provider/__init__.py | 2 +- astrbot/core/message/message_event_result.py | 4 +- .../strategies/strategy.py | 5 +- .../process_stage/method/llm_request.py | 39 +++-- .../process_stage/method/star_request.py | 4 +- astrbot/core/pipeline/process_stage/stage.py | 5 +- astrbot/core/pipeline/stage.py | 23 +-- astrbot/core/platform/astr_message_event.py | 29 +++- astrbot/core/provider/__init__.py | 2 +- astrbot/core/provider/entites.py | 40 +++++ .../{tool.py => func_tool_manager.py} | 26 +-- astrbot/core/provider/llm_response.py | 13 -- astrbot/core/provider/provider.py | 8 +- astrbot/core/provider/provider_metadata.py | 6 - astrbot/core/provider/provider_request.py | 18 --- astrbot/core/provider/register.py | 4 +- .../core/provider/sources/llmtuner_source.py | 149 +++++++++++------- .../core/provider/sources/openai_source.py | 44 +++--- astrbot/core/star/context.py | 7 +- astrbot/core/star/register/star_handler.py | 4 +- astrbot/core/star/star_manager.py | 5 + packages/reminder/main.py | 12 +- packages/web_searcher/main.py | 28 +--- 24 files changed, 274 insertions(+), 220 deletions(-) create mode 100644 astrbot/core/provider/entites.py rename astrbot/core/provider/{tool.py => func_tool_manager.py} (89%) delete mode 100644 astrbot/core/provider/llm_response.py delete mode 100644 astrbot/core/provider/provider_metadata.py delete mode 100644 astrbot/core/provider/provider_request.py diff --git a/astrbot/api/event/__init__.py b/astrbot/api/event/__init__.py index ae5f7fc37..1f2fce640 100644 --- a/astrbot/api/event/__init__.py +++ b/astrbot/api/event/__init__.py @@ -1,9 +1,18 @@ from astrbot.core.message.message_event_result import ( - MessageEventResult, MessageChain, CommandResult, EventResultType -) + MessageEventResult, + MessageChain, + CommandResult, + EventResultType, + ResultContentType, +) from astrbot.core.platform import AstrMessageEvent __all__ = [ - 'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent' -] \ No newline at end of file + "MessageEventResult", + "MessageChain", + "CommandResult", + "EventResultType", + "AstrMessageEvent", + "ResultContentType", +] diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 453ffb876..377f8d4b3 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,2 +1,2 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData -from astrbot.core.provider.provider_request import ProviderRequest \ No newline at end of file +from astrbot.core.provider.entites import ProviderRequest \ No newline at end of file diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index d214544f2..b75f0b2a6 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -139,13 +139,13 @@ class MessageEventResult(MessageChain): ''' return self.result_type == EventResultType.STOP - def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult': + def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult': '''设置事件处理的结果类型。 Args: result_type (EventResultType): 事件处理的结果类型。 ''' - self.result_type = result_type + self.result_content_type = typ return self diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index dc1bf7e09..57efd22f9 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -1,6 +1,6 @@ from . import ContentSafetyStrategy from typing import List, Tuple - +from astrbot import logger class StrategySelector: def __init__(self, config: dict) -> None: @@ -15,7 +15,8 @@ class StrategySelector: try: from .baidu_aip import BaiduAipStrategy except ImportError: - raise ImportError("使用百度内容审核应该先 pip install baidu-aip") + logger.warning("使用百度内容审核应该先 pip install baidu-aip") + return self.enabled_strategies.append( BaiduAipStrategy( config["baidu_aip"]["app_id"], diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index a0a2e1a26..a04224984 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -8,7 +8,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.provider_request import ProviderRequest +from astrbot.core.provider.entites import ProviderRequest from astrbot.core.star.star_handler import star_handlers_registry, EventType class LLMRequestSubStage(Stage): @@ -22,8 +22,8 @@ class LLMRequestSubStage(Stage): async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: req: ProviderRequest = None + provider = self.ctx.plugin_manager.context.get_using_provider() if event.get_extra("provider_request"): - print("provider_request") req = event.get_extra("provider_request") assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。" else: @@ -38,8 +38,9 @@ class LLMRequestSubStage(Stage): image_url = comp.url if comp.url else comp.file req.image_urls.append(image_url) req.session_id = event.session_id - - provider = self.ctx.plugin_manager.context.get_using_provider() + event.set_extra("provider_request", req) + session_provider_context = provider.session_memory.get(event.session_id) + req.contexts = session_provider_context if session_provider_context else [] if self.prompt_prefix: req.prompt = self.prompt_prefix + req.prompt @@ -61,27 +62,45 @@ class LLMRequestSubStage(Stage): await handler.handler(event, req) except BaseException: logger.error(traceback.format_exc()) + try: + logger.debug(f"请求 LLM:{req.__dict__}") llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) - + if llm_response.role == 'assistant': # text completion event.set_result(MessageEventResult().message(llm_response.completion_text) .set_result_content_type(ResultContentType.LLM_RESULT)) elif llm_response.role == 'tool': # function calling + function_calling_result = {} for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args): func_tool = req.func_tool.get_func(func_tool_name) logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") try: # 尝试调用工具函数 - wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args) - async for _ in wrapper: - yield + wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args) + async for resp in wrapper: + if resp is not None: + function_calling_result[func_tool_name] = resp + else: + yield event.clear_result() # 清除上一个 handler 的结果 - except BaseException: - logger.error(traceback.format_exc()) + except BaseException as e: + logger.warning(traceback.format_exc()) + function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e) + if function_calling_result: + # 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 + # 我们重新执行一遍这个 stage + req.func_tool = None # 暂时不支持递归工具调用 + extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" + for tool_name, tool_result in function_calling_result.items(): + extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n" + req.prompt += extra_prompt + async for _ in self.process(event): + yield + except BaseException as e: logger.error(traceback.format_exc()) event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 177c7470e..c1b89e9b9 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -30,8 +30,8 @@ class StarRequestSubStage(Stage): logger.debug(f"执行 Star Handler {handler.handler_full_name}") wrapper = self._call_handler(self.ctx, event, handler.handler, **params) - async for _ in wrapper: - yield + async for ret in wrapper: + yield ret event.clear_result() # 清除上一个 handler 的结果 except Exception as e: logger.error(traceback.format_exc()) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 275d45b99..837e349a7 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -5,7 +5,8 @@ from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core.provider.provider_request import ProviderRequest +from astrbot.core.provider.entites import ProviderRequest +from astrbot.core import logger @register_stage class ProcessStage(Stage): @@ -24,12 +25,12 @@ class ProcessStage(Stage): '''处理事件 ''' activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") - if activated_handlers: async for resp in self.star_request_sub_stage.process(event): # 生成器返回值处理 if isinstance(resp, ProviderRequest): # Handler 的 LLM 请求 + logger.debug(f"llm request -> {resp.prompt}") event.set_extra("provider_request", resp) async for _ in self.llm_request_sub_stage.process(event): yield diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 9bb9339d5..77a7dbeea 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -49,29 +49,18 @@ class Stage(abc.ABC): ready_to_call = handler(event, ctx.plugin_manager.context, **params) if isinstance(ready_to_call, AsyncGenerator): - async for mer in ready_to_call: + async for ret in ready_to_call: # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) - if mer: - assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.set_result(mer) + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) yield else: - if event.get_result(): - yield + yield ret elif inspect.iscoroutine(ready_to_call): # 如果只是一个 coroutine ret = await ready_to_call - if ret: - # 如果有返回值 - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + if isinstance(ret, (MessageEventResult, CommandResult)): event.set_result(ret) - # 执行后续步骤来发送消息 - if event.is_stopped() and event.get_result(): - # 插件主动停止事件传播,并且有结果 - event.continue_event() - yield - event.clear_result() - event.stop_event() yield else: - yield \ No newline at end of file + yield ret \ No newline at end of file diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 2f89d11ad..65b06149e 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType from typing import List, Union from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward from astrbot.core.utils.metrics import Metric +from astrbot.core.provider.entites import ProviderRequest + @dataclass class MessageSesion: @@ -281,4 +283,29 @@ class AstrMessageEvent(abc.ABC): '''LLM 请求相关''' - \ No newline at end of file + def request_llm( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + contexts: List = None, + system_prompt: str = "" + ) -> ProviderRequest: + ''' + 创建一个 LLM 请求。 + + Examples: + ```py + yield event.request_llm(prompt="hi") + ``` + + image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。 + contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。 + ''' + return ProviderRequest( + prompt = prompt, + session_id = session_id, + image_urls = image_urls, + contexts = contexts, + system_prompt = system_prompt + ) \ No newline at end of file diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index d7e2bc18d..b1dfe8732 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,6 +1,6 @@ from .provider import Provider, Personality -from .provider_metadata import ProviderMetaData +from .entites import ProviderMetaData __all__ = [ "Provider", diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py new file mode 100644 index 000000000..3ee3379f1 --- /dev/null +++ b/astrbot/core/provider/entites.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import List, Dict +from .func_tool_manager import FuncCall + + +@dataclass +class ProviderMetaData(): + type: str # 提供商适配器名称,如 openai, ollama + desc: str = "" # 提供商适配器描述. + + +@dataclass +class ProviderRequest(): + prompt: str + '''提示词''' + session_id: str = "" + '''会话 ID''' + image_urls: List[str] = None + '''图片 URL 列表''' + func_tool: FuncCall = None + '''工具''' + contexts: List = None + '''上下文。格式与 openai 的上下文格式一致: + 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages + ''' + + system_prompt: str = "" + '''系统提示词''' + + +@dataclass +class LLMResponse: + role: str + '''角色''' + completion_text: str = None + '''LLM 返回的文本''' + tools_call_args: List[Dict[str, any]] = None + '''工具调用参数''' + tools_call_name: List[str] = None + '''工具调用名称''' \ No newline at end of file diff --git a/astrbot/core/provider/tool.py b/astrbot/core/provider/func_tool_manager.py similarity index 89% rename from astrbot/core/provider/tool.py rename to astrbot/core/provider/func_tool_manager.py index 76b0d8d15..4aee3eec8 100644 --- a/astrbot/core/provider/tool.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,23 +1,7 @@ import json import textwrap -from typing import Dict, List +from typing import Dict, List, Awaitable from dataclasses import dataclass -from astrbot.core.star.star_handler import StarHandlerMetadata - -class FuncCallJsonFormatError(Exception): - def __init__(self, msg): - self.msg = msg - - def __str__(self): - return self.msg - - -class FuncNotFoundError(Exception): - def __init__(self, msg): - self.msg = msg - - def __str__(self): - return self.msg @dataclass @@ -29,7 +13,7 @@ class FuncTool: name: str parameters: Dict description: str - star_handler_metadata: StarHandlerMetadata + handler: Awaitable active: bool = True '''是否激活''' @@ -55,7 +39,7 @@ class FuncCall: name: str, func_args: list, desc: str, - star_handler_metadata: StarHandlerMetadata, + handler: Awaitable, ) -> None: """ 为函数调用(function-calling / tools-use)添加工具。 @@ -78,7 +62,7 @@ class FuncCall: name=name, parameters=params, description=desc, - star_handler_metadata=star_handler_metadata, + handler=handler, ) self.func_list.append(_func) @@ -180,7 +164,7 @@ class FuncCall: tool_callable = func.star_handler_metadata.handler break if not tool_callable: - raise FuncNotFoundError(f"Request function {func_name} not found.") + raise Exception(f"Request function {func_name} not found.") ret = await tool_callable(**args) if ret: tool_call_result.append(str(ret)) diff --git a/astrbot/core/provider/llm_response.py b/astrbot/core/provider/llm_response.py deleted file mode 100644 index 89fbf4045..000000000 --- a/astrbot/core/provider/llm_response.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Dict, List -from dataclasses import dataclass - -@dataclass -class LLMResponse: - role: str - '''角色''' - completion_text: str = None - '''LLM 返回的文本''' - tools_call_args: List[Dict[str, any]] = None - '''工具调用参数''' - tools_call_name: List[str] = None - '''工具调用名称''' diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index c4ab049e4..553dd78dc 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -5,8 +5,8 @@ from typing import List from astrbot.core.db import BaseDatabase from astrbot.core import logger from typing import TypedDict -from astrbot.core.provider.tool import FuncCall -from astrbot.core.provider.llm_response import LLMResponse +from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.entites import LLMResponse from dataclasses import dataclass class Personality(TypedDict): prompt: str = "" @@ -112,13 +112,11 @@ class Provider(abc.ABC): kwargs: 其他参数 Notes: + - 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。 - 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话, 并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。 - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 - - 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。 - 传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。 - ''' raise NotImplementedError() diff --git a/astrbot/core/provider/provider_metadata.py b/astrbot/core/provider/provider_metadata.py deleted file mode 100644 index 34299a934..000000000 --- a/astrbot/core/provider/provider_metadata.py +++ /dev/null @@ -1,6 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class ProviderMetaData(): - type: str # 提供商适配器名称,如 openai, ollama - desc: str = "" # 提供商适配器描述. \ No newline at end of file diff --git a/astrbot/core/provider/provider_request.py b/astrbot/core/provider/provider_request.py deleted file mode 100644 index 7b8ebc450..000000000 --- a/astrbot/core/provider/provider_request.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass -from typing import List -from .tool import FuncCall - -@dataclass -class ProviderRequest(): - prompt: str - '''提示词''' - session_id: str = "" - '''会话 ID''' - image_urls: List[str] = None - '''图片 URL 列表''' - func_tool: FuncCall = None - '''工具''' - contexts: List = None - '''上下文''' - system_prompt: str = "" - '''系统提示词''' \ No newline at end of file diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index f1086cb4f..00c3ad877 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,7 +1,7 @@ from typing import List, Dict, Type -from .provider_metadata import ProviderMetaData +from .entites import ProviderMetaData from astrbot.core import logger -from .tool import FuncCall +from .func_tool_manager import FuncCall provider_registry: List[ProviderMetaData] = [] '''维护了通过装饰器注册的 Provider''' diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index 9daeb2387..743999fd0 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -3,92 +3,119 @@ import os from llmtuner.chat import ChatModel from typing import List from .. import Provider +from ..entites import LLMResponse +from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase -from astrbot import logger - from ..register import register_provider_adapter -@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型") + +@register_provider_adapter( + "llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型" +) class LLMTunerModelLoader(Provider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history = True + db_helper: BaseDatabase, + persistant_history=True, ) -> None: - super().__init__(provider_config, provider_settings, persistant_history, db_helper) - if not os.path.exists(provider_config['base_model_path']) or not os.path.exists(provider_config['adapter_model_path']): + super().__init__( + provider_config, provider_settings, persistant_history, db_helper + ) + if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists( + provider_config["adapter_model_path"] + ): raise FileNotFoundError("模型文件路径不存在。") - self.base_model_path = provider_config['base_model_path'] - self.adapter_model_path = provider_config['adapter_model_path'] - self.model = ChatModel({ - "model_name_or_path": self.base_model_path, - "adapter_name_or_path": self.adapter_model_path, - "template": provider_config['llmtuner_template'], - "finetuning_type": provider_config['finetuning_type'], - "quantization_bit": provider_config['quantization_bit'], - }) - self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path)) - + self.base_model_path = provider_config["base_model_path"] + self.adapter_model_path = provider_config["adapter_model_path"] + self.model = ChatModel( + { + "model_name_or_path": self.base_model_path, + "adapter_name_or_path": self.adapter_model_path, + "template": provider_config["llmtuner_template"], + "finetuning_type": provider_config["finetuning_type"], + "quantization_bit": provider_config["quantization_bit"], + } + ) + self.set_model( + os.path.basename(self.base_model_path) + + "_" + + os.path.basename(self.adapter_model_path) + ) + async def assemble_context(self, text: str, image_urls: List[str] = None): - ''' + """ 组装上下文。 - ''' + """ return {"role": "user", "content": text} - - async def text_chat(self, - prompt: str, - session_id: str, - image_urls: List[str] = None, - tools = None, - contexts: List=None, - **kwargs) -> str: - + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + **kwargs, + ) -> LLMResponse: system_prompt = "" if not contexts: - contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}] + query_context = [ + *self.session_memory[session_id], + {"role": "user", "content": prompt}, + ] system_prompt = self.curr_personality["prompt"] else: - # 提取出系统提示 - system_idxs = [] - for idx, context in enumerate(contexts): - if context["role"] == "system": - system_idxs.append(idx) - for idx in reversed(system_idxs): - system_prompt += " " + contexts.pop(idx)["content"] - - logger.debug(f"请求上下文:{contexts}") - logger.debug(f"请求 System Prompt:{system_prompt}") - + query_context = [*contexts, {"role": "user", "content": prompt}] + + # 提取出系统提示 + system_idxs = [] + for idx, context in enumerate(query_context): + if context["role"] == "system": + system_idxs.append(idx) + for idx in reversed(system_idxs): + system_prompt += " " + query_context.pop(idx)["content"] + conf = { - "messages": contexts, + "messages": query_context, "system": system_prompt, } - if tools: - conf['tools'] = tools - + if func_tool: + conf["tools"] = func_tool + responses = await self.model.achat(**conf) - logger.debug(f"返回上下文:{responses}") - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type) - self.session_memory[session_id].append({"role": "user", "content": prompt}) - self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text}) + + if session_id: + if not contexts: + self.session_memory[session_id].append( + {"role": "user", "content": prompt} + ) + self.session_memory[session_id].append( + {"role": "assistant", "content": responses[-1].response_text} + ) + else: + self.session_memory[session_id] = [ + *contexts, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": responses[-1].response_text}, + ] + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type) return responses[-1].response_text async def forget(self, session_id): - logger.info("llmtuner reset") self.session_memory[session_id] = [] return True - + async def get_current_key(self): return "none" - + async def set_key(self, key): pass - + async def get_models(self): return [self.get_model()] - async def get_human_readable_context(self, session_id, page, page_size): if session_id not in self.session_memory: @@ -96,9 +123,9 @@ class LLMTunerModelLoader(Provider): contexts = [] temp_contexts = [] for record in self.session_memory[session_id]: - if record['role'] == "user": + if record["role"] == "user": temp_contexts.append(f"User: {record['content']}") - elif record['role'] == "assistant": + elif record["role"] == "assistant": temp_contexts.append(f"Assistant: {record['content']}") contexts.insert(0, temp_contexts) temp_contexts = [] @@ -107,9 +134,9 @@ class LLMTunerModelLoader(Provider): contexts = [item for sublist in contexts for item in sublist] # 计算分页 - paged_contexts = contexts[(page-1)*page_size:page*page_size] + paged_contexts = contexts[(page - 1) * page_size : page * page_size] total_pages = len(contexts) // page_size if len(contexts) % page_size != 0: total_pages += 1 - - return paged_contexts, total_pages \ No newline at end of file + + return paged_contexts, total_pages diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 042473f98..1cc792bf3 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -10,10 +10,10 @@ from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider from astrbot import logger -from astrbot.core.provider.tool import FuncCall +from astrbot.core.provider.func_tool_manager import FuncCall from typing import List from ..register import register_provider_adapter -from astrbot.core.provider.llm_response import LLMResponse +from astrbot.core.provider.entites import LLMResponse @register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器") class ProviderOpenAIOfficial(Provider): @@ -131,31 +131,30 @@ class ProviderOpenAIOfficial(Provider): else: raise Exception("Internal Error") - async def text_chat(self, - prompt: str, - session_id: str, - image_urls: List[str]=None, - func_tool: FuncCall=None, - contexts=None, - system_prompt=None, - **kwargs - ) -> LLMResponse: + async def text_chat( + self, + prompt: str, + session_id: str, + image_urls: List[str]=None, + func_tool: FuncCall=None, + contexts=None, + system_prompt=None, + **kwargs + ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) context_query = [] if not contexts: context_query = [*self.session_memory[session_id], new_record] - if system_prompt: - context_query.insert(0, {"role": "system", "content": system_prompt}) else: - context_query = contexts - - logger.debug(f"请求上下文:{context_query}, {self.get_model()}") - + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + payloads = { "messages": context_query, **self.provider_config.get("model_config", {}) } - + try: llm_response = await self._query(payloads, func_tool) except Exception as e: @@ -164,7 +163,7 @@ class ProviderOpenAIOfficial(Provider): self.pop_record(session_id) logger.warning(traceback.format_exc()) - if llm_response.role == "assistant": + if llm_response.role == "assistant" and session_id: # 文本回复 if not contexts: # 添加用户 record @@ -174,7 +173,12 @@ class ProviderOpenAIOfficial(Provider): "role": "assistant", "content": llm_response.completion_text }) - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type']) + else: + self.session_memory[session_id] = [*contexts, new_record, { + "role": "assistant", + "content": llm_response.completion_text + }] + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type']) return llm_response diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index c2f366e7d..108c25250 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,10 +1,10 @@ from asyncio import Queue from typing import List, TypedDict, Union -from astrbot.core.provider import Provider +from astrbot.core.provider.provider import Provider from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.provider.tool import FuncCall +from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager @@ -78,7 +78,8 @@ class Context: event_filters=[], desc=desc ) - self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__) + star_handlers_registry.append(md) + self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj) def unregister_llm_tool(self, name: str) -> None: '''删除一个函数调用工具。如果再要启用,需要重新注册。''' diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index ac3ec6a36..db0c46bb1 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -9,7 +9,7 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.regex import RegexFilter from typing import Awaitable -from astrbot.core.provider.tool import SUPPORTED_TYPES +from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools from astrbot.core import logger @@ -185,7 +185,7 @@ def register_llm_tool(name: str = None): "description": arg.description }) md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func(llm_tool_name, args, docstring.short_description, md) + llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler) logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") return awaitable diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 83ed36aa9..6f597b143 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -14,6 +14,7 @@ from . import StarMetadata from .updator import PluginUpdator from astrbot.core.utils.io import remove_dir from .star import star_registry, star_map +from astrbot.core.provider.register import llm_tools from .star_handler import star_handlers_registry @@ -181,6 +182,10 @@ class PluginManager: logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}") # handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self handler.handler = functools.partial(handler.handler, star_metadata.star_cls) + # llm_tool + for func_tool in llm_tools.func_list: + if func_tool.handler.__module__ == star_metadata.module_path: + func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls) else: # v3.4.0 以前的方式注册插件 diff --git a/packages/reminder/main.py b/packages/reminder/main.py index 06071a7ba..1ccbf8cc6 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -20,18 +20,18 @@ class Main(star.Star): f.write("{}") with open("data/astrbot-reminder.json", "r") as f: self.reminder_data = json.load(f) - + self._init_scheduler() self.scheduler.start() - async def _init_scheduler(self): + def _init_scheduler(self): '''Initialize the scheduler.''' for group in self.reminder_data: for reminder in self.reminder_data[group]: if "datetime" in reminder: - self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"]], id=group, run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M")) + self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M")) elif "cron" in reminder: - self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"]) + self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"])) async def _save_data(self): '''Save the reminder data.''' @@ -67,14 +67,14 @@ class Main(star.Star): if cron_expression: d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron } self.reminder_data[event.unified_msg_origin].append(d) - self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin) + self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]) if human_readable_cron: reminder_time = f"{human_readable_cron}(Cron: {cron_expression})" else: d = { "text": text, "datetime": datetime_str } self.reminder_data[event.unified_msg_origin].append(d) datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M") - self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled) + self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled) reminder_time = datetime_str await self._save_data() yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。") diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 1ffffb678..31a285394 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -39,19 +39,6 @@ class Main(star.Star): ret = await self._tidy_text(soup.get_text()) return ret - async def _request_from_llm(self, event: AstrMessageEvent, resources: str) -> str: - '''使用 LLM 对文本进行生成''' - - if self.context.get_using_provider() is None: - raise ValueError("未找到可用的 LLM Provider,无法进行摘要总结") - provider = self.context.get_using_provider() - summary_prompt = f"""{event.get_message_str()} - -# Provided Sources: -{resources}""" - ret = await provider.text_chat(summary_prompt, session_id=event.session_id) - return ret.completion_text - @filter.command("websearch") async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str: websearch = self.context.get_config()['provider_settings']['web_search'] @@ -84,20 +71,21 @@ class Main(star.Star): ''' logger.info("web_searcher - search_from_search_engine: " + query) results = [] + RESULT_NUM = 5 try: - results = await self.google.search(query, 3) + results = await self.google.search(query, RESULT_NUM) except BaseException as e: logger.error(f"google search error: {e}, try the next one...") if len(results) == 0: logger.debug("search google failed") try: - results = await self.bing_search.search(query, 3) + results = await self.bing_search.search(query, RESULT_NUM) except BaseException as e: logger.error(f"bing search error: {e}, try the next one...") if len(results) == 0: logger.debug("search bing failed") try: - results = await self.sogo_search.search(query, 3) + results = await self.sogo_search.search(query, RESULT_NUM) except BaseException as e: logger.error(f"sogo search error: {e}") if len(results) == 0: @@ -111,12 +99,11 @@ class Main(star.Star): site_result = await self._get_from_url(i.url) except BaseException: site_result = "" - site_result = site_result[:1000] + "..." if len(site_result) > 1000 else site_result + site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n" idx += 1 - resp = await self._request_from_llm(event, ret) - event.set_result(MessageEventResult().message(resp)) + return ret @llm_tool("fetch_url") async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str: @@ -126,5 +113,4 @@ class Main(star.Star): url(string): The url of the website to fetch content from ''' resp = await self._get_from_url(url) - resp = await self._request_from_llm(event, resp) - event.set_result(MessageEventResult().message(resp)) \ No newline at end of file + return resp \ No newline at end of file