perf: use request_llm
This commit is contained in:
@@ -1,9 +1,18 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult, MessageChain, CommandResult, EventResultType
|
||||
)
|
||||
MessageEventResult,
|
||||
MessageChain,
|
||||
CommandResult,
|
||||
EventResultType,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
__all__ = [
|
||||
'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent'
|
||||
]
|
||||
"MessageEventResult",
|
||||
"MessageChain",
|
||||
"CommandResult",
|
||||
"EventResultType",
|
||||
"AstrMessageEvent",
|
||||
"ResultContentType",
|
||||
]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
@@ -139,13 +139,13 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult':
|
||||
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
result_type (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
self.result_type = result_type
|
||||
self.result_content_type = typ
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from . import ContentSafetyStrategy
|
||||
from typing import List, Tuple
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
class StrategySelector:
|
||||
def __init__(self, config: dict) -> None:
|
||||
@@ -15,7 +15,8 @@ class StrategySelector:
|
||||
try:
|
||||
from .baidu_aip import BaiduAipStrategy
|
||||
except ImportError:
|
||||
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
|
||||
logger.warning("使用百度内容审核应该先 pip install baidu-aip")
|
||||
return
|
||||
self.enabled_strategies.append(
|
||||
BaiduAipStrategy(
|
||||
config["baidu_aip"]["app_id"],
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -22,8 +22,8 @@ class LLMRequestSubStage(Stage):
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if event.get_extra("provider_request"):
|
||||
print("provider_request")
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
@@ -38,8 +38,9 @@ class LLMRequestSubStage(Stage):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
event.set_extra("provider_request", req)
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if self.prompt_prefix:
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
@@ -61,27 +62,45 @@ class LLMRequestSubStage(Stage):
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
logger.debug(f"请求 LLM:{req.__dict__}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
elif llm_response.role == 'tool':
|
||||
# function calling
|
||||
function_calling_result = {}
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
try:
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args)
|
||||
async for _ in wrapper:
|
||||
yield
|
||||
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
function_calling_result[func_tool_name] = resp
|
||||
else:
|
||||
yield
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
|
||||
if function_calling_result:
|
||||
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
|
||||
# 我们重新执行一遍这个 stage
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
|
||||
for tool_name, tool_result in function_calling_result.items():
|
||||
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
||||
req.prompt += extra_prompt
|
||||
async for _ in self.process(event):
|
||||
yield
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
|
||||
@@ -30,8 +30,8 @@ class StarRequestSubStage(Stage):
|
||||
|
||||
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
async for _ in wrapper:
|
||||
yield
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -5,7 +5,8 @@ from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
@register_stage
|
||||
class ProcessStage(Stage):
|
||||
@@ -24,12 +25,12 @@ class ProcessStage(Stage):
|
||||
'''处理事件
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
|
||||
if activated_handlers:
|
||||
async for resp in self.star_request_sub_stage.process(event):
|
||||
# 生成器返回值处理
|
||||
if isinstance(resp, ProviderRequest):
|
||||
# Handler 的 LLM 请求
|
||||
logger.debug(f"llm request -> {resp.prompt}")
|
||||
event.set_extra("provider_request", resp)
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -49,29 +49,18 @@ class Stage(abc.ABC):
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
async for ret in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
yield ret
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 插件主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
yield ret
|
||||
@@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType
|
||||
from typing import List, Union
|
||||
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
@@ -281,4 +283,29 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
'''LLM 请求相关'''
|
||||
|
||||
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = ""
|
||||
) -> ProviderRequest:
|
||||
'''
|
||||
创建一个 LLM 请求。
|
||||
|
||||
Examples:
|
||||
```py
|
||||
yield event.request_llm(prompt="hi")
|
||||
```
|
||||
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
session_id = session_id,
|
||||
image_urls = image_urls,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
from .provider import Provider, Personality
|
||||
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from .entites import ProviderMetaData
|
||||
|
||||
__all__ = [
|
||||
"Provider",
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str # 提供商适配器名称,如 openai, ollama
|
||||
desc: str = "" # 提供商适配器描述.
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
prompt: str
|
||||
'''提示词'''
|
||||
session_id: str = ""
|
||||
'''会话 ID'''
|
||||
image_urls: List[str] = None
|
||||
'''图片 URL 列表'''
|
||||
func_tool: FuncCall = None
|
||||
'''工具'''
|
||||
contexts: List = None
|
||||
'''上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
'''
|
||||
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
completion_text: str = None
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = None
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = None
|
||||
'''工具调用名称'''
|
||||
@@ -1,23 +1,7 @@
|
||||
import json
|
||||
import textwrap
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Awaitable
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
class FuncNotFoundError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,7 +13,7 @@ class FuncTool:
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
star_handler_metadata: StarHandlerMetadata
|
||||
handler: Awaitable
|
||||
|
||||
active: bool = True
|
||||
'''是否激活'''
|
||||
@@ -55,7 +39,7 @@ class FuncCall:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
star_handler_metadata: StarHandlerMetadata,
|
||||
handler: Awaitable,
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
@@ -78,7 +62,7 @@ class FuncCall:
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
star_handler_metadata=star_handler_metadata,
|
||||
handler=handler,
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
|
||||
@@ -180,7 +164,7 @@ class FuncCall:
|
||||
tool_callable = func.star_handler_metadata.handler
|
||||
break
|
||||
if not tool_callable:
|
||||
raise FuncNotFoundError(f"Request function {func_name} not found.")
|
||||
raise Exception(f"Request function {func_name} not found.")
|
||||
ret = await tool_callable(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
completion_text: str = None
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = None
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = None
|
||||
'''工具调用名称'''
|
||||
@@ -5,8 +5,8 @@ from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core import logger
|
||||
from typing import TypedDict
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from dataclasses import dataclass
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
@@ -112,13 +112,11 @@ class Provider(abc.ABC):
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
|
||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
|
||||
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
|
||||
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str # 提供商适配器名称,如 openai, ollama
|
||||
desc: str = "" # 提供商适配器描述.
|
||||
@@ -1,18 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from .tool import FuncCall
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
prompt: str
|
||||
'''提示词'''
|
||||
session_id: str = ""
|
||||
'''会话 ID'''
|
||||
image_urls: List[str] = None
|
||||
'''图片 URL 列表'''
|
||||
func_tool: FuncCall = None
|
||||
'''工具'''
|
||||
contexts: List = None
|
||||
'''上下文'''
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Dict, Type
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from .entites import ProviderMetaData
|
||||
from astrbot.core import logger
|
||||
from .tool import FuncCall
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
|
||||
@@ -3,92 +3,119 @@ import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
|
||||
|
||||
@register_provider_adapter(
|
||||
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
|
||||
)
|
||||
class LLMTunerModelLoader(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
if not os.path.exists(provider_config['base_model_path']) or not os.path.exists(provider_config['adapter_model_path']):
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper
|
||||
)
|
||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
||||
provider_config["adapter_model_path"]
|
||||
):
|
||||
raise FileNotFoundError("模型文件路径不存在。")
|
||||
self.base_model_path = provider_config['base_model_path']
|
||||
self.adapter_model_path = provider_config['adapter_model_path']
|
||||
self.model = ChatModel({
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config['llmtuner_template'],
|
||||
"finetuning_type": provider_config['finetuning_type'],
|
||||
"quantization_bit": provider_config['quantization_bit'],
|
||||
})
|
||||
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
|
||||
|
||||
self.base_model_path = provider_config["base_model_path"]
|
||||
self.adapter_model_path = provider_config["adapter_model_path"]
|
||||
self.model = ChatModel(
|
||||
{
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config["llmtuner_template"],
|
||||
"finetuning_type": provider_config["finetuning_type"],
|
||||
"quantization_bit": provider_config["quantization_bit"],
|
||||
}
|
||||
)
|
||||
self.set_model(
|
||||
os.path.basename(self.base_model_path)
|
||||
+ "_"
|
||||
+ os.path.basename(self.adapter_model_path)
|
||||
)
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
"""
|
||||
组装上下文。
|
||||
'''
|
||||
"""
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str] = None,
|
||||
tools = None,
|
||||
contexts: List=None,
|
||||
**kwargs) -> str:
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
system_prompt = ""
|
||||
if not contexts:
|
||||
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
|
||||
query_context = [
|
||||
*self.session_memory[session_id],
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(contexts):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + contexts.pop(idx)["content"]
|
||||
|
||||
logger.debug(f"请求上下文:{contexts}")
|
||||
logger.debug(f"请求 System Prompt:{system_prompt}")
|
||||
|
||||
query_context = [*contexts, {"role": "user", "content": prompt}]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(query_context):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + query_context.pop(idx)["content"]
|
||||
|
||||
conf = {
|
||||
"messages": contexts,
|
||||
"messages": query_context,
|
||||
"system": system_prompt,
|
||||
}
|
||||
if tools:
|
||||
conf['tools'] = tools
|
||||
|
||||
if func_tool:
|
||||
conf["tools"] = func_tool
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
logger.debug(f"返回上下文:{responses}")
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
|
||||
self.session_memory[session_id].append({"role": "user", "content": prompt})
|
||||
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
|
||||
|
||||
if session_id:
|
||||
if not contexts:
|
||||
self.session_memory[session_id].append(
|
||||
{"role": "user", "content": prompt}
|
||||
)
|
||||
self.session_memory[session_id].append(
|
||||
{"role": "assistant", "content": responses[-1].response_text}
|
||||
)
|
||||
else:
|
||||
self.session_memory[session_id] = [
|
||||
*contexts,
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": responses[-1].response_text},
|
||||
]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
|
||||
return responses[-1].response_text
|
||||
|
||||
async def forget(self, session_id):
|
||||
logger.info("llmtuner reset")
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
|
||||
async def set_key(self, key):
|
||||
pass
|
||||
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
@@ -96,9 +123,9 @@ class LLMTunerModelLoader(Provider):
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
if record["role"] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
elif record["role"] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
@@ -107,9 +134,9 @@ class LLMTunerModelLoader(Provider):
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
@@ -10,10 +10,10 @@ from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
@@ -131,31 +131,30 @@ class ProviderOpenAIOfficial(Provider):
|
||||
else:
|
||||
raise Exception("Internal Error")
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
else:
|
||||
context_query = contexts
|
||||
|
||||
logger.debug(f"请求上下文:{context_query}, {self.get_model()}")
|
||||
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
@@ -164,7 +163,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
@@ -174,7 +173,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
else:
|
||||
self.session_memory[session_id] = [*contexts, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
return llm_response
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, TypedDict, Union
|
||||
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
@@ -78,7 +78,8 @@ class Context:
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
|
||||
star_handlers_registry.append(md)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd
|
||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||
from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.provider.tool import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
|
||||
"description": arg.description
|
||||
})
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return awaitable
|
||||
|
||||
@@ -14,6 +14,7 @@ from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
|
||||
@@ -181,6 +182,10 @@ class PluginManager:
|
||||
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
|
||||
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
|
||||
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
|
||||
# llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler.__module__ == star_metadata.module_path:
|
||||
func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls)
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
|
||||
@@ -20,18 +20,18 @@ class Main(star.Star):
|
||||
f.write("{}")
|
||||
with open("data/astrbot-reminder.json", "r") as f:
|
||||
self.reminder_data = json.load(f)
|
||||
|
||||
|
||||
self._init_scheduler()
|
||||
self.scheduler.start()
|
||||
|
||||
async def _init_scheduler(self):
|
||||
def _init_scheduler(self):
|
||||
'''Initialize the scheduler.'''
|
||||
for group in self.reminder_data:
|
||||
for reminder in self.reminder_data[group]:
|
||||
if "datetime" in reminder:
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"]], id=group, run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
|
||||
elif "cron" in reminder:
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"]))
|
||||
|
||||
async def _save_data(self):
|
||||
'''Save the reminder data.'''
|
||||
@@ -67,14 +67,14 @@ class Main(star.Star):
|
||||
if cron_expression:
|
||||
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin)
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d])
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
d = { "text": text, "datetime": datetime_str }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled)
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled)
|
||||
reminder_time = datetime_str
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
|
||||
|
||||
@@ -39,19 +39,6 @@ class Main(star.Star):
|
||||
ret = await self._tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
async def _request_from_llm(self, event: AstrMessageEvent, resources: str) -> str:
|
||||
'''使用 LLM 对文本进行生成'''
|
||||
|
||||
if self.context.get_using_provider() is None:
|
||||
raise ValueError("未找到可用的 LLM Provider,无法进行摘要总结")
|
||||
provider = self.context.get_using_provider()
|
||||
summary_prompt = f"""{event.get_message_str()}
|
||||
|
||||
# Provided Sources:
|
||||
{resources}"""
|
||||
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
|
||||
return ret.completion_text
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
|
||||
websearch = self.context.get_config()['provider_settings']['web_search']
|
||||
@@ -84,20 +71,21 @@ class Main(star.Star):
|
||||
'''
|
||||
logger.info("web_searcher - search_from_search_engine: " + query)
|
||||
results = []
|
||||
RESULT_NUM = 5
|
||||
try:
|
||||
results = await self.google.search(query, 3)
|
||||
results = await self.google.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
results = await self.bing_search.search(query, 3)
|
||||
results = await self.bing_search.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
logger.error(f"bing search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search bing failed")
|
||||
try:
|
||||
results = await self.sogo_search.search(query, 3)
|
||||
results = await self.sogo_search.search(query, RESULT_NUM)
|
||||
except BaseException as e:
|
||||
logger.error(f"sogo search error: {e}")
|
||||
if len(results) == 0:
|
||||
@@ -111,12 +99,11 @@ class Main(star.Star):
|
||||
site_result = await self._get_from_url(i.url)
|
||||
except BaseException:
|
||||
site_result = ""
|
||||
site_result = site_result[:1000] + "..." if len(site_result) > 1000 else site_result
|
||||
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
|
||||
resp = await self._request_from_llm(event, ret)
|
||||
event.set_result(MessageEventResult().message(resp))
|
||||
return ret
|
||||
|
||||
@llm_tool("fetch_url")
|
||||
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
|
||||
@@ -126,5 +113,4 @@ class Main(star.Star):
|
||||
url(string): The url of the website to fetch content from
|
||||
'''
|
||||
resp = await self._get_from_url(url)
|
||||
resp = await self._request_from_llm(event, resp)
|
||||
event.set_result(MessageEventResult().message(resp))
|
||||
return resp
|
||||
Reference in New Issue
Block a user