perf: use request_llm

This commit is contained in:
Soulter
2024-12-21 16:35:16 +08:00
parent 193ff24f4c
commit 766f6a1ba2
24 changed files with 274 additions and 220 deletions
+13 -4
View File
@@ -1,9 +1,18 @@
from astrbot.core.message.message_event_result import (
MessageEventResult, MessageChain, CommandResult, EventResultType
)
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
ResultContentType,
)
from astrbot.core.platform import AstrMessageEvent
__all__ = [
'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent'
]
"MessageEventResult",
"MessageChain",
"CommandResult",
"EventResultType",
"AstrMessageEvent",
"ResultContentType",
]
+1 -1
View File
@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider.provider_request import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest
+2 -2
View File
@@ -139,13 +139,13 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult':
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
'''
self.result_type = result_type
self.result_content_type = typ
return self
@@ -1,6 +1,6 @@
from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
@@ -15,7 +15,8 @@ class StrategySelector:
try:
from .baidu_aip import BaiduAipStrategy
except ImportError:
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
logger.warning("使用百度内容审核应该先 pip install baidu-aip")
return
self.enabled_strategies.append(
BaiduAipStrategy(
config["baidu_aip"]["app_id"],
@@ -8,7 +8,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.provider_request import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
@@ -22,8 +22,8 @@ class LLMRequestSubStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if event.get_extra("provider_request"):
print("provider_request")
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
@@ -38,8 +38,9 @@ class LLMRequestSubStage(Stage):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
provider = self.ctx.plugin_manager.context.get_using_provider()
event.set_extra("provider_request", req)
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
@@ -61,27 +62,45 @@ class LLMRequestSubStage(Stage):
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
try:
logger.debug(f"请求 LLM{req.__dict__}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args)
async for _ in wrapper:
yield
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
async for resp in wrapper:
if resp is not None:
function_calling_result[func_tool_name] = resp
else:
yield
event.clear_result() # 清除上一个 handler 的结果
except BaseException:
logger.error(traceback.format_exc())
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event):
yield
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
@@ -30,8 +30,8 @@ class StarRequestSubStage(Stage):
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for _ in wrapper:
yield
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
+3 -2
View File
@@ -5,7 +5,8 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.provider_request import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core import logger
@register_stage
class ProcessStage(Stage):
@@ -24,12 +25,12 @@ class ProcessStage(Stage):
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
if isinstance(resp, ProviderRequest):
# Handler 的 LLM 请求
logger.debug(f"llm request -> {resp.prompt}")
event.set_extra("provider_request", resp)
async for _ in self.llm_request_sub_stage.process(event):
yield
+6 -17
View File
@@ -49,29 +49,18 @@ class Stage(abc.ABC):
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
if event.get_result():
yield
yield ret
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 插件主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
else:
yield
yield ret
+28 -1
View File
@@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
@dataclass
class MessageSesion:
@@ -281,4 +283,29 @@ class AstrMessageEvent(abc.ABC):
'''LLM 请求相关'''
def request_llm(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
contexts: List = None,
system_prompt: str = ""
) -> ProviderRequest:
'''
创建一个 LLM 请求。
Examples:
```py
yield event.request_llm(prompt="hi")
```
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
contexts = contexts,
system_prompt = system_prompt
)
+1 -1
View File
@@ -1,6 +1,6 @@
from .provider import Provider, Personality
from .provider_metadata import ProviderMetaData
from .entites import ProviderMetaData
__all__ = [
"Provider",
+40
View File
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import List, Dict
from .func_tool_manager import FuncCall
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
@dataclass
class ProviderRequest():
prompt: str
'''提示词'''
session_id: str = ""
'''会话 ID'''
image_urls: List[str] = None
'''图片 URL 列表'''
func_tool: FuncCall = None
'''工具'''
contexts: List = None
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
'''工具调用参数'''
tools_call_name: List[str] = None
'''工具调用名称'''
@@ -1,23 +1,7 @@
import json
import textwrap
from typing import Dict, List
from typing import Dict, List, Awaitable
from dataclasses import dataclass
from astrbot.core.star.star_handler import StarHandlerMetadata
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
@dataclass
@@ -29,7 +13,7 @@ class FuncTool:
name: str
parameters: Dict
description: str
star_handler_metadata: StarHandlerMetadata
handler: Awaitable
active: bool = True
'''是否激活'''
@@ -55,7 +39,7 @@ class FuncCall:
name: str,
func_args: list,
desc: str,
star_handler_metadata: StarHandlerMetadata,
handler: Awaitable,
) -> None:
"""
为函数调用function-calling / tools-use添加工具
@@ -78,7 +62,7 @@ class FuncCall:
name=name,
parameters=params,
description=desc,
star_handler_metadata=star_handler_metadata,
handler=handler,
)
self.func_list.append(_func)
@@ -180,7 +164,7 @@ class FuncCall:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise FuncNotFoundError(f"Request function {func_name} not found.")
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
-13
View File
@@ -1,13 +0,0 @@
from typing import Dict, List
from dataclasses import dataclass
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
'''工具调用参数'''
tools_call_name: List[str] = None
'''工具调用名称'''
+3 -5
View File
@@ -5,8 +5,8 @@ from typing import List
from astrbot.core.db import BaseDatabase
from astrbot.core import logger
from typing import TypedDict
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
@@ -112,13 +112,11 @@ class Provider(abc.ABC):
kwargs: 其他参数
Notes:
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
'''
raise NotImplementedError()
@@ -1,6 +0,0 @@
from dataclasses import dataclass
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
-18
View File
@@ -1,18 +0,0 @@
from dataclasses import dataclass
from typing import List
from .tool import FuncCall
@dataclass
class ProviderRequest():
prompt: str
'''提示词'''
session_id: str = ""
'''会话 ID'''
image_urls: List[str] = None
'''图片 URL 列表'''
func_tool: FuncCall = None
'''工具'''
contexts: List = None
'''上下文'''
system_prompt: str = ""
'''系统提示词'''
+2 -2
View File
@@ -1,7 +1,7 @@
from typing import List, Dict, Type
from .provider_metadata import ProviderMetaData
from .entites import ProviderMetaData
from astrbot.core import logger
from .tool import FuncCall
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
@@ -3,92 +3,119 @@ import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from astrbot import logger
from ..register import register_provider_adapter
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
@register_provider_adapter(
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
def __init__(
self,
provider_config: dict,
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
db_helper: BaseDatabase,
persistant_history=True,
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
if not os.path.exists(provider_config['base_model_path']) or not os.path.exists(provider_config['adapter_model_path']):
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
):
raise FileNotFoundError("模型文件路径不存在。")
self.base_model_path = provider_config['base_model_path']
self.adapter_model_path = provider_config['adapter_model_path']
self.model = ChatModel({
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config['llmtuner_template'],
"finetuning_type": provider_config['finetuning_type'],
"quantization_bit": provider_config['quantization_bit'],
})
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
self.base_model_path = provider_config["base_model_path"]
self.adapter_model_path = provider_config["adapter_model_path"]
self.model = ChatModel(
{
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config["llmtuner_template"],
"finetuning_type": provider_config["finetuning_type"],
"quantization_bit": provider_config["quantization_bit"],
}
)
self.set_model(
os.path.basename(self.base_model_path)
+ "_"
+ os.path.basename(self.adapter_model_path)
)
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
"""
组装上下文。
'''
"""
return {"role": "user", "content": text}
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str] = None,
tools = None,
contexts: List=None,
**kwargs) -> str:
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
if not contexts:
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
query_context = [
*self.session_memory[session_id],
{"role": "user", "content": prompt},
]
system_prompt = self.curr_personality["prompt"]
else:
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(contexts):
if context["role"] == "system":
system_idxs.append(idx)
for idx in reversed(system_idxs):
system_prompt += " " + contexts.pop(idx)["content"]
logger.debug(f"请求上下文:{contexts}")
logger.debug(f"请求 System Prompt{system_prompt}")
query_context = [*contexts, {"role": "user", "content": prompt}]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
conf = {
"messages": contexts,
"messages": query_context,
"system": system_prompt,
}
if tools:
conf['tools'] = tools
if func_tool:
conf["tools"] = func_tool
responses = await self.model.achat(**conf)
logger.debug(f"返回上下文:{responses}")
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
self.session_memory[session_id].append({"role": "user", "content": prompt})
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
if session_id:
if not contexts:
self.session_memory[session_id].append(
{"role": "user", "content": prompt}
)
self.session_memory[session_id].append(
{"role": "assistant", "content": responses[-1].response_text}
)
else:
self.session_memory[session_id] = [
*contexts,
{"role": "user", "content": prompt},
{"role": "assistant", "content": responses[-1].response_text},
]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
return responses[-1].response_text
async def forget(self, session_id):
logger.info("llmtuner reset")
self.session_memory[session_id] = []
return True
async def get_current_key(self):
return "none"
async def set_key(self, key):
pass
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
@@ -96,9 +123,9 @@ class LLMTunerModelLoader(Provider):
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
@@ -107,9 +134,9 @@ class LLMTunerModelLoader(Provider):
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return paged_contexts, total_pages
+24 -20
View File
@@ -10,10 +10,10 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.entites import LLMResponse
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
class ProviderOpenAIOfficial(Provider):
@@ -131,31 +131,30 @@ class ProviderOpenAIOfficial(Provider):
else:
raise Exception("Internal Error")
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
else:
context_query = contexts
logger.debug(f"请求上下文:{context_query}, {self.get_model()}")
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
@@ -164,7 +163,7 @@ class ProviderOpenAIOfficial(Provider):
self.pop_record(session_id)
logger.warning(traceback.format_exc())
if llm_response.role == "assistant":
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
@@ -174,7 +173,12 @@ class ProviderOpenAIOfficial(Provider):
"role": "assistant",
"content": llm_response.completion_text
})
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
else:
self.session_memory[session_id] = [*contexts, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
+4 -3
View File
@@ -1,10 +1,10 @@
from asyncio import Queue
from typing import List, TypedDict, Union
from astrbot.core.provider import Provider
from astrbot.core.provider.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
@@ -78,7 +78,8 @@ class Context:
event_filters=[],
desc=desc
)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
+2 -2
View File
@@ -9,7 +9,7 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.tool import SUPPORTED_TYPES
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core import logger
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
+5
View File
@@ -14,6 +14,7 @@ from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from astrbot.core.provider.register import llm_tools
from .star_handler import star_handlers_registry
@@ -181,6 +182,10 @@ class PluginManager:
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
# llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == star_metadata.module_path:
func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls)
else:
# v3.4.0 以前的方式注册插件
+6 -6
View File
@@ -20,18 +20,18 @@ class Main(star.Star):
f.write("{}")
with open("data/astrbot-reminder.json", "r") as f:
self.reminder_data = json.load(f)
self._init_scheduler()
self.scheduler.start()
async def _init_scheduler(self):
def _init_scheduler(self):
'''Initialize the scheduler.'''
for group in self.reminder_data:
for reminder in self.reminder_data[group]:
if "datetime" in reminder:
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"]], id=group, run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
elif "cron" in reminder:
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"]))
async def _save_data(self):
'''Save the reminder data.'''
@@ -67,14 +67,14 @@ class Main(star.Star):
if cron_expression:
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
self.reminder_data[event.unified_msg_origin].append(d)
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin)
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d])
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
d = { "text": text, "datetime": datetime_str }
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled)
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled)
reminder_time = datetime_str
await self._save_data()
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
+7 -21
View File
@@ -39,19 +39,6 @@ class Main(star.Star):
ret = await self._tidy_text(soup.get_text())
return ret
async def _request_from_llm(self, event: AstrMessageEvent, resources: str) -> str:
'''使用 LLM 对文本进行生成'''
if self.context.get_using_provider() is None:
raise ValueError("未找到可用的 LLM Provider,无法进行摘要总结")
provider = self.context.get_using_provider()
summary_prompt = f"""{event.get_message_str()}
# Provided Sources:
{resources}"""
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
return ret.completion_text
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
websearch = self.context.get_config()['provider_settings']['web_search']
@@ -84,20 +71,21 @@ class Main(star.Star):
'''
logger.info("web_searcher - search_from_search_engine: " + query)
results = []
RESULT_NUM = 5
try:
results = await self.google.search(query, 3)
results = await self.google.search(query, RESULT_NUM)
except BaseException as e:
logger.error(f"google search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search google failed")
try:
results = await self.bing_search.search(query, 3)
results = await self.bing_search.search(query, RESULT_NUM)
except BaseException as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await self.sogo_search.search(query, 3)
results = await self.sogo_search.search(query, RESULT_NUM)
except BaseException as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
@@ -111,12 +99,11 @@ class Main(star.Star):
site_result = await self._get_from_url(i.url)
except BaseException:
site_result = ""
site_result = site_result[:1000] + "..." if len(site_result) > 1000 else site_result
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
resp = await self._request_from_llm(event, ret)
event.set_result(MessageEventResult().message(resp))
return ret
@llm_tool("fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
@@ -126,5 +113,4 @@ class Main(star.Star):
url(string): The url of the website to fetch content from
'''
resp = await self._get_from_url(url)
resp = await self._request_from_llm(event, resp)
event.set_result(MessageEventResult().message(resp))
return resp