Compare commits

...

20 Commits

Author SHA1 Message Date
Soulter 33ec92258d Update config.py 2024-08-13 15:05:16 +08:00
Soulter a578edf137 fix: metrics
perf: aiocqhttp image url
2024-08-12 02:50:31 -04:00
Soulter f8949ebead perf: reboot after installing plugin 2024-08-11 23:24:37 -04:00
Soulter 141c91301f perf: Improve sleep time handling in QQOfficial and ProviderOpenAIOfficial 2024-08-11 23:24:37 -04:00
Soulter 8d95e67b5a Update README.md 2024-08-11 17:13:49 +08:00
Soulter 0633e7f25f perf: improve the effects of local function-calling 2024-08-11 03:55:31 -04:00
Soulter 266da0a9d8 fix: 修复重启时 aiocqhttp 没有正常退出导致端口占用的问题 2024-08-11 02:30:49 -04:00
Soulter 121c40f273 perf: raise error when badrequest 2024-08-11 01:49:33 -04:00
Soulter a876efb95f fix: 更新后覆盖文件路径错误 2024-08-10 04:35:07 -04:00
Soulter 95a8cc9498 fix: 修复部分字段未更新导致的错误 2024-08-10 04:13:24 -04:00
Soulter f02731055e fix: 修复插件启用忽略前缀之后可能的逻辑冲突 2024-08-10 03:25:50 -04:00
Soulter 1df83addfc update: add gcc 2024-08-10 14:59:00 +08:00
Soulter 9db43ac5e6 feat: 注册指令支持忽略指令前缀;快捷主动回复 2024-08-10 02:35:54 -04:00
Soulter 0f470cf96f Update README.md 2024-08-09 12:26:00 +08:00
Soulter da3fcb7b86 Merge pull request #186 from itgpt-com/master
优化 docker build
2024-08-08 22:15:48 +08:00
Soulter 73dd4703b9 Update .dockerignore 2024-08-08 22:15:05 +08:00
itgpt 0c679a0151 添加 .dockerignore 过滤 docker cp 不必要文件。缩小镜像 2024-08-08 16:21:30 +08:00
itgpt 1d6ea2dbe6 添加端口输出 2024-08-08 16:16:55 +08:00
itgpt 933df57654 优化 docker build 2024-08-08 15:53:44 +08:00
Soulter cbe761fc33 Update README.md 2024-08-07 00:49:00 +08:00
27 changed files with 459 additions and 327 deletions
+18
View File
@@ -0,0 +1,18 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github acions
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
# Environments
.env
.venv
env/
venv*/
ENV/
.conda/
README*.md
+32 -13
View File
@@ -4,20 +4,39 @@ on:
release: release:
types: [published] types: [published]
workflow_dispatch: workflow_dispatch:
jobs: jobs:
publish-latest-docker-image: publish-docker:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: Build and publish docker image
steps: steps:
- name: Checkout - name: 拉取源码
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Build image with:
run: | fetch-depth: 1
git clone https://github.com/Soulter/AstrBot
cd AstrBot - name: 设置 QEMU
docker build -t ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest . uses: docker/setup-qemu-action@v3
- name: Publish image
run: | - name: 设置 Docker Buildx
docker login -u ${{ secrets.DOCKER_HUB_USERNAME }} -p ${{ secrets.DOCKER_HUB_PASSWORD }} uses: docker/setup-buildx-action@v3
docker push ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
- name: 登录到 DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: 构建和推送 Docker hub
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"
+12
View File
@@ -3,6 +3,18 @@ WORKDIR /AstrBot
COPY . /AstrBot/ COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt RUN python -m pip install -r requirements.txt
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ] CMD [ "python", "main.py" ]
+23 -7
View File
@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<img width="806" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746"> <img width="750" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746">
</p> </p>
<div align="center"> <div align="center">
@@ -21,27 +21,43 @@
🌍 支持的消息平台 🌍 支持的消息平台
- QQ 群、QQ 频道(OneBot、QQ 官方接口) - QQ 群、QQ 频道(OneBot、QQ 官方接口)
- Telegram[astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持 - Telegram[astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
- WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件支持) - WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件)
🌍 支持的大模型一览 🌍 支持的大模型/底座
- OpenAI GPT、DallE 系列 - OpenAI GPT、DallE 系列
- Claude(由[LLMs插件](https://github.com/Soulter/llms)支持) - Claude(由[LLMs插件](https://github.com/Soulter/llms)支持)
- HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持) - HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持)
- Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持) - Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持)
- Ollama
- 几乎所有已知模型(可接入 [OneAPI](https://astrbot.soulter.top/docs/docs/adavanced/one-api)
🌍 机器人支持的能力一览: 🌍 机器人支持的能力一览:
- 大模型对话、人格、网页搜索 - 大模型对话、人格、网页搜索
- 可视化管理面板 - 可视化仪表盘
- 同时处理多平台消息 - 同时处理多平台消息
- 精确到个人的会话隔离 - 精确到个人的会话隔离
- 插件支持 - 插件支持
- 文本转图片回复(Markdown - 文本转图片回复(Markdown
## 🧩 插件支持 ## 🧩 插件
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6) 有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
## 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
对于新功能的添加,请先通过 Issue 进行讨论。
## 🔭 展望
- [ ] 更多、更开放的 LLM Agent 能力
## ✨ Demo ## ✨ Demo
+4
View File
@@ -78,6 +78,7 @@ class AstrBotBootstrap():
self.context.updator = self.updator self.context.updator = self.updator
self.context.plugin_updator = self.plugin_manager.updator self.context.plugin_updator = self.plugin_manager.updator
self.context.message_handler = self.message_handler self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
# load plugins, plugins' commands. # load plugins, plugins' commands.
self.load_plugins() self.load_plugins()
@@ -99,6 +100,9 @@ class AstrBotBootstrap():
try: try:
result = await task result = await task
return result return result
except asyncio.CancelledError:
logger.info(f"{task.get_name()} 任务已取消。")
return
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。") logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
+8 -3
View File
@@ -134,8 +134,8 @@ class MessageHandler():
self.persist_manager.record_message(message.platform.platform_name, message.session_id) self.persist_manager.record_message(message.platform.platform_name, message.session_id)
# TODO: this should be configurable # TODO: this should be configurable
if not message.message_str: # if not message.message_str:
return MessageResult("Hi~") # return MessageResult("Hi~")
# check the rate limit # check the rate limit
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
@@ -158,6 +158,11 @@ class MessageHandler():
use_t2i=cmd_res.is_use_t2i use_t2i=cmd_res.is_use_t2i
) )
# next is the LLM part
if message.only_command:
return
# check if the message is a llm-wake-up command # check if the message is a llm-wake-up command
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix): if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。") logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
@@ -184,7 +189,7 @@ class MessageHandler():
try: try:
if web_search: if web_search:
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
else: else:
llm_result = await provider.text_chat( llm_result = await provider.text_chat(
prompt=msg_plain, prompt=msg_plain,
+8 -6
View File
@@ -192,10 +192,11 @@ class AstrBotDashBoard():
try: try:
logger.info(f"正在安装插件 {repo_url}") logger.info(f"正在安装插件 {repo_url}")
self.plugin_manager.install_plugin(repo_url) self.plugin_manager.install_plugin(repo_url)
logger.info(f"安装插件 {repo_url} 成功") threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
return Response( return Response(
status="success", status="success",
message="安装成功~", message="安装成功,机器人将在 2 秒内重启。",
data=None data=None
).__dict__ ).__dict__
except Exception as e: except Exception as e:
@@ -258,10 +259,11 @@ class AstrBotDashBoard():
try: try:
logger.info(f"正在更新插件 {plugin_name}") logger.info(f"正在更新插件 {plugin_name}")
self.plugin_manager.update_plugin(plugin_name) self.plugin_manager.update_plugin(plugin_name)
logger.info(f"更新插件 {plugin_name} 成功") threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response( return Response(
status="success", status="success",
message="更新成功~", message="更新成功,机器人将在 2 秒内重启。",
data=None data=None
).__dict__ ).__dict__
except Exception as e: except Exception as e:
@@ -311,7 +313,7 @@ class AstrBotDashBoard():
latest = False latest = False
try: try:
self.astrbot_updator.update(latest=latest, version=version) self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start() threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
return Response( return Response(
status="success", status="success",
message="更新成功,机器人将在 3 秒内重启。", message="更新成功,机器人将在 3 秒内重启。",
@@ -374,7 +376,7 @@ class AstrBotDashBoard():
self.dashboard_data, self.context.config_helper.get_all()) self.dashboard_data, self.context.config_helper.get_all())
# 重启 # 重启
threading.Thread(target=self.astrbot_updator._reboot, threading.Thread(target=self.astrbot_updator._reboot,
args=(2, ), daemon=True).start() args=(2, self.context), daemon=True).start()
except Exception as e: except Exception as e:
raise e raise e
+1 -1
View File
@@ -53,7 +53,7 @@ if __name__ == "__main__":
check_env() check_env()
logger = LogManager.GetLogger( logger = LogManager.GetLogger(
log_name='astrbot', log_name='astrbot',
out_to_console=True, out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
) )
+2 -2
View File
@@ -117,11 +117,11 @@ class InternalCommandHandler:
success=False, success=False,
message_chain="你没有权限使用该指令", message_chain="你没有权限使用该指令",
) )
context.updator._reboot(5) context.updator._reboot(3, context)
return CommandResult( return CommandResult(
hit=True, hit=True,
success=True, success=True,
message_chain="AstrBot 将在 5s 后重启。", message_chain="AstrBot 将在 3s 后重启。",
) )
def plugin(self, message: AstrMessageEvent, context: Context): def plugin(self, message: AstrMessageEvent, context: Context):
+20 -1
View File
@@ -21,6 +21,7 @@ class CommandMetadata():
plugin_metadata: PluginMetadata plugin_metadata: PluginMetadata
handler: callable handler: callable
use_regex: bool = False use_regex: bool = False
ignore_prefix: bool = False
description: str = "" description: str = ""
class CommandManager(): class CommandManager():
@@ -35,6 +36,7 @@ class CommandManager():
priority: int, priority: int,
handler: callable, handler: callable,
use_regex: bool = False, use_regex: bool = False,
ignore_prefix: bool = False,
plugin_metadata: PluginMetadata = None, plugin_metadata: PluginMetadata = None,
): ):
''' '''
@@ -53,6 +55,7 @@ class CommandManager():
plugin_metadata=plugin_metadata, plugin_metadata=plugin_metadata,
handler=handler, handler=handler,
use_regex=use_regex, use_regex=use_regex,
ignore_prefix=ignore_prefix,
description=description description=description
) )
if plugin_metadata: if plugin_metadata:
@@ -75,9 +78,23 @@ class CommandManager():
priority=request.priority, priority=request.priority,
handler=request.handler, handler=request.handler,
use_regex=request.use_regex, use_regex=request.use_regex,
ignore_prefix=request.ignore_prefix,
plugin_metadata=plugin.metadata) plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = [] self.plugin_commands_waitlist = []
async def check_command_ignore_prefix(self, message_str: str) -> bool:
for _, command in self.commands:
command_metadata = self.commands_handler[command]
if command_metadata.ignore_prefix:
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
return True
return False
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str message_str = message_event.message_str
for _, command in self.commands: for _, command in self.commands:
@@ -89,6 +106,8 @@ class CommandManager():
if trig: if trig:
logger.info(f"触发 {command} 指令。") logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context) command_result = await self.execute_handler(command, message_event, context)
if not command_result:
continue
if command_result.hit: if command_result.hit:
return command_result return command_result
+14 -3
View File
@@ -3,11 +3,13 @@ from typing import Union, Any, List
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
from type.astrbot_message import AstrBotMessage from type.astrbot_message import AstrBotMessage
from type.command import CommandResult from type.command import CommandResult
from type.astrbot_message import MessageType
class Platform(): class Platform():
def __init__(self) -> None: def __init__(self, platform_name: str, context) -> None:
pass self.PLATFORM_NAME = platform_name
self.context = context
@abc.abstractmethod @abc.abstractmethod
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
@@ -30,6 +32,13 @@ class Platform():
发送消息(主动) 发送消息(主动)
''' '''
pass pass
@abc.abstractmethod
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
发送消息(主动)
'''
pass
def parse_message_outline(self, message: AstrBotMessage) -> str: def parse_message_outline(self, message: AstrBotMessage) -> str:
''' '''
@@ -71,4 +80,6 @@ class Platform():
else: else:
rendered_images.append(Image.fromFileSystem(p)) rendered_images.append(Image.fromFileSystem(p))
return rendered_images return rendered_images
async def record_metrics(self):
self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME)
+2 -2
View File
@@ -58,7 +58,7 @@ class PlatformManager():
try: try:
qq_gocq = QQGOCQ(self.context, self.msg_handler) qq_gocq = QQGOCQ(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform( self.context.platforms.append(RegisteredPlatform(
platform_name="gocq", platform_instance=qq_gocq, origin="internal")) platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
await qq_gocq.run() await qq_gocq.run()
except BaseException as e: except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e)) logger.error("启动 nakuru 适配器时出现错误: " + str(e))
@@ -81,7 +81,7 @@ class PlatformManager():
from model.platform.qq_official import QQOfficial from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(self.context, self.msg_handler) qqchannel_bot = QQOfficial(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform( self.context.platforms.append(RegisteredPlatform(
platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
return qqchannel_bot.run() return qqchannel_bot.run()
except BaseException as e: except BaseException as e:
logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e)) logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e))
+47 -14
View File
@@ -18,6 +18,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AIOCQHTTP(Platform): class AIOCQHTTP(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None: def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("aiocqhttp", context)
self.message_handler = message_handler self.message_handler = message_handler
self.waiting = {} self.waiting = {}
self.context = context self.context = context
@@ -67,7 +68,9 @@ class AIOCQHTTP(Platform):
message_str += m['data']['text'].strip() message_str += m['data']['text'].strip()
abm.message.append(a) abm.message.append(a)
if t == 'image': if t == 'image':
a = Image(file=m['data']['file']) file = m['data']['file'] if 'file' in m['data'] else None
url = m['data']['url'] if 'url' in m['data'] else None
a = Image(file=file, url=url)
abm.message.append(a) abm.message.append(a)
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.message_str = message_str abm.message_str = message_str
@@ -99,26 +102,31 @@ class AIOCQHTTP(Platform):
return bot return bot
async def shutdown_trigger_placeholder(self): async def shutdown_trigger_placeholder(self):
while True: while self.context.running:
await asyncio.sleep(1) await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool: def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True # if message chain contains Plain components or
# At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE: if message.type == MessageType.FRIEND_MESSAGE:
return True return True, "friend"
for comp in message.message: for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id: if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True, "command"
# check nicks # check nicks
if self.check_nick(message.message_str): if self.check_nick(message.message_str):
return True return True, "nick"
return False return False, "none"
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
logger.info( logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
if not self.pre_check(message): ok, reason = self.pre_check(message)
if not ok:
return return
# 解析 role # 解析 role
@@ -129,14 +137,30 @@ class AIOCQHTTP(Platform):
else: else:
role = 'member' role = 'member'
# parse unified message origin
unified_msg_origin = None
assert isinstance(message.raw_message, Event)
if message.type == MessageType.GROUP_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event # construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"aiocqhttp",
message.session_id,
role,
unified_msg_origin,
reason == "command") # only_command
# transfer control to message handler # transfer control to message handler
message_result = await self.message_handler.handle(ame) message_result = await self.message_handler.handle(ame)
if not message_result: return if not message_result: return
await self.reply_msg(message, message_result.result_message) await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -147,7 +171,8 @@ class AIOCQHTTP(Platform):
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: list): result_message: list,
use_t2i: bool = None):
""" """
回复用户唤醒机器人的消息。(被动回复) 回复用户唤醒机器人的消息。(被动回复)
""" """
@@ -160,7 +185,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ] res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture. # if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res) rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images: if rendered_images:
try: try:
@@ -173,9 +198,9 @@ class AIOCQHTTP(Platform):
await self._reply(message, res) await self._reply(message, res)
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]): async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
await self.record_metrics()
if isinstance(message_chain, str): if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ] message_chain = [Plain(text=message_chain), ]
ret = [] ret = []
image_idx = [] image_idx = []
for idx, segment in enumerate(message_chain): for idx, segment in enumerate(message_chain):
@@ -223,4 +248,12 @@ class AIOCQHTTP(Platform):
''' '''
await self._reply(target, result_message.message_chain) await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
if message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({'group_id': int(target)}, result_message)
elif message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({'user_id': int(target)}, result_message)
else:
raise Exception("aiocqhttp: 无法识别的消息类型。")
+69 -13
View File
@@ -30,6 +30,7 @@ class FakeSource:
class QQGOCQ(Platform): class QQGOCQ(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None: def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("nakuru", context)
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
@@ -74,14 +75,17 @@ class QQGOCQ(Platform):
def pre_check(self, message: AstrBotMessage) -> bool: def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True # if message chain contains Plain components or At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE: if message.type == MessageType.FRIEND_MESSAGE:
return True return True, "friend"
for comp in message.message: for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id: if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True, "command"
# check nicks # check nicks
if self.check_nick(message.message_str): if self.check_nick(message.message_str):
return True return True, "nick"
return False return False, "none"
def run(self): def run(self):
coro = self.client._run() coro = self.client._run()
@@ -95,7 +99,8 @@ class QQGOCQ(Platform):
(GroupMessage, FriendMessage, GuildMessage)) (GroupMessage, FriendMessage, GuildMessage))
# 判断是否响应消息 # 判断是否响应消息
if not self.pre_check(message): ok, reason = self.pre_check(message)
if not ok:
return return
# 解析 session_id # 解析 session_id
@@ -118,14 +123,35 @@ class QQGOCQ(Platform):
else: else:
role = 'member' role = 'member'
# parse unified message origin
unified_msg_origin = None
if message.type == MessageType.GROUP_MESSAGE:
assert isinstance(message.raw_message, GroupMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
assert isinstance(message.raw_message, FriendMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}"
elif message.type == MessageType.GUILD_MESSAGE:
assert isinstance(message.raw_message, GuildMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event # construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"nakuru",
session_id,
role,
unified_msg_origin,
reason == 'command') # only_command
# transfer control to message handler # transfer control to message handler
message_result = await self.message_handler.handle(ame) message_result = await self.message_handler.handle(ame)
if not message_result: return if not message_result: return
await self.reply_msg(message, message_result.result_message) await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -135,7 +161,8 @@ class QQGOCQ(Platform):
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: List[BaseMessageComponent]): result_message: List[BaseMessageComponent],
use_t2i: bool = None):
""" """
回复用户唤醒机器人的消息。(被动回复) 回复用户唤醒机器人的消息。(被动回复)
""" """
@@ -152,7 +179,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ] res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture. # if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res) rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images: if rendered_images:
try: try:
@@ -169,14 +196,26 @@ class QQGOCQ(Platform):
message_chain = [Plain(text=message_chain), ] message_chain = [Plain(text=message_chain), ]
is_dict = isinstance(source, dict) is_dict = isinstance(source, dict)
if source.type == "GuildMessage":
typ = None
if is_dict:
if "group_id" in source:
typ = "GroupMessage"
elif "user_id" in source:
typ = "FriendMessage"
elif "guild_id" in source:
typ = "GuildMessage"
else:
typ = source.type
if typ == "GuildMessage":
guild_id = source['guild_id'] if is_dict else source.guild_id guild_id = source['guild_id'] if is_dict else source.guild_id
chan_id = source['channel_id'] if is_dict else source.channel_id chan_id = source['channel_id'] if is_dict else source.channel_id
await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain) await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain)
elif source.type == "FriendMessage": elif typ == "FriendMessage":
user_id = source['user_id'] if is_dict else source.user_id user_id = source['user_id'] if is_dict else source.user_id
await self.client.sendFriendMessage(user_id, message_chain) await self.client.sendFriendMessage(user_id, message_chain)
elif source.type == "GroupMessage": elif typ == "GroupMessage":
group_id = source['group_id'] if is_dict else source.group_id group_id = source['group_id'] if is_dict else source.group_id
# 过长时forward发送 # 过长时forward发送
plain_text_len = 0 plain_text_len = 0
@@ -213,6 +252,23 @@ class QQGOCQ(Platform):
guild_id 不是频道号。 guild_id 不是频道号。
''' '''
await self._reply(target, result_message.message_chain) await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
以主动的方式给用户、群或者频道发送一条消息。
`message_type` 为 MessageType 枚举类型。
- 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE
- 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE
- 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。
'''
if message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({"user_id": int(target)}, result_message)
elif message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({"group_id": int(target)}, result_message)
elif message_type == MessageType.GUILD_MESSAGE:
await self.send_msg({"channel_id": int(target)}, result_message)
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
abm = AstrBotMessage() abm = AstrBotMessage()
@@ -233,7 +289,7 @@ class QQGOCQ(Platform):
str(message.sender.user_id), str(message.sender.user_id),
str(message.sender.nickname) str(message.sender.nickname)
) )
abm.tag = "gocq" abm.tag = "nakuru"
abm.message = message.message abm.message = message.message
return abm return abm
+10 -6
View File
@@ -53,7 +53,7 @@ class botClient(Client):
class QQOfficial(Platform): class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None: def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
super().__init__() super().__init__("qqofficial", context)
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
@@ -112,7 +112,7 @@ class QQOfficial(Platform):
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.raw_message = message abm.raw_message = message
abm.message_id = message.id abm.message_id = message.id
abm.tag = "qqchan" abm.tag = "qqofficial"
msg: List[BaseMessageComponent] = [] msg: List[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage): if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
@@ -222,7 +222,7 @@ class QQOfficial(Platform):
if not message_result: if not message_result:
return return
ret = await self.reply_msg(message, message_result.result_message) ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -234,7 +234,8 @@ class QQOfficial(Platform):
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: List[BaseMessageComponent]): result_message: List[BaseMessageComponent],
use_t2i: bool = None):
''' '''
回复频道消息 回复频道消息
''' '''
@@ -249,7 +250,7 @@ class QQOfficial(Platform):
msg_ref = None msg_ref = None
rendered_images = [] rendered_images = []
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(result_message) rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list): if isinstance(result_message, list):
@@ -388,6 +389,9 @@ class QQOfficial(Platform):
if image_path: if image_path:
payload['file_image'] = image_path payload['file_image'] = image_path
await self._reply(**payload) await self._reply(**payload)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
raise NotImplementedError("qqofficial 不支持此方法。")
def wait_for_message(self, channel_id: int) -> AstrBotMessage: def wait_for_message(self, channel_id: int) -> AstrBotMessage:
''' '''
@@ -404,4 +408,4 @@ class QQOfficial(Platform):
cnt += 1 cnt += 1
if cnt > 300: if cnt > 300:
raise Exception("等待消息超时。") raise Exception("等待消息超时。")
time.sleep(1)() time.sleep(1)
+3 -2
View File
@@ -15,12 +15,13 @@ class CommandRegisterRequest():
handler: Callable handler: Callable
use_regex: bool = False use_regex: bool = False
plugin_name: str = None plugin_name: str = None
ignore_prefix: bool = False
class PluginCommandBridge(): class PluginCommandBridge():
def __init__(self, cached_plugins: RegisteredPlugins): def __init__(self, cached_plugins: RegisteredPlugins):
self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
self.cached_plugins = cached_plugins self.cached_plugins = cached_plugins
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False): def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name)) self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
+35 -9
View File
@@ -5,6 +5,7 @@ import traceback
import uuid import uuid
import shutil import shutil
import yaml import yaml
import subprocess
from util.updator.plugin_updator import PluginUpdator from util.updator.plugin_updator import PluginUpdator
from util.io import remove_dir, download_file from util.io import remove_dir, download_file
@@ -84,8 +85,28 @@ class PluginManager():
def update_plugin_dept(self, path): def update_plugin_dept(self, path):
mirror = "https://mirrors.aliyun.com/pypi/simple/" mirror = "https://mirrors.aliyun.com/pypi/simple/"
py = sys.executable py = sys.executable
os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet") # os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com")
process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com",
stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
while True:
output = process.stdout.readline()
if output == '' and process.poll() is not None:
break
if output:
output = output.strip()
if output.startswith("Requirement already satisfied"):
continue
if output.startswith("Using cached"):
continue
if output.startswith("Looking in indexes"):
continue
logger.info(output)
rc = process.poll()
def install_plugin(self, repo_url: str): def install_plugin(self, repo_url: str):
ppath = self.plugin_store_path ppath = self.plugin_store_path
@@ -95,10 +116,13 @@ class PluginManager():
plugin_path = self.updator.update(repo_url) plugin_path = self.updator.update(repo_url)
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(repo_url) f.write(repo_url)
self.check_plugin_dept_update()
ok, err = self.plugin_reload() return plugin_path
if not ok: # ok, err = self.plugin_reload()
raise Exception(err) # if not ok:
# raise Exception(err)
def download_from_repo_url(self, target_path: str, repo_url: str): def download_from_repo_url(self, target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:] repo_namespace = repo_url.split("/")[-2:]
@@ -158,7 +182,7 @@ class PluginManager():
logger.info(f"正在加载插件 {root_dir_name} ...") logger.info(f"正在加载插件 {root_dir_name} ...")
# self.check_plugin_dept_update(cached_plugins, root_dir_name) self.check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__("addons.plugins." + module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p]) root_dir_name + "." + p, fromlist=[p])
@@ -227,10 +251,12 @@ class PluginManager():
# remove the temp dir # remove the temp dir
remove_dir(temp_dir) remove_dir(temp_dir)
self.check_plugin_dept_update()
ok, err = self.plugin_reload() # ok, err = self.plugin_reload()
if not ok: # if not ok:
raise Exception(err) # raise Exception(err)
def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
metadata = None metadata = None
+7 -7
View File
@@ -1,5 +1,5 @@
import os import os
import sys import asyncio
import json import json
import time import time
import tiktoken import tiktoken
@@ -14,8 +14,6 @@ from openai._exceptions import *
from astrbot.persist.helper import dbConn from astrbot.persist.helper import dbConn
from model.provider.provider import Provider from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from typing import List, Dict from typing import List, Dict
@@ -359,7 +357,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warn(f"OpenAI 请求异常:{e}") logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e): if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。") raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1 raise e
except RateLimitError as e: except RateLimitError as e:
if "You exceeded your current quota" in str(e): if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False self.keys_data[self.chosen_api_key] = False
@@ -369,7 +367,9 @@ class ProviderOpenAIOfficial(Provider):
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}") logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key() await self.switch_to_next_key()
rate_limit_retry += 1 rate_limit_retry += 1
time.sleep(1) await asyncio.sleep(1)
except NotFoundError as e:
raise e
except Exception as e: except Exception as e:
retry += 1 retry += 1
if retry >= 3: if retry >= 3:
@@ -381,7 +381,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(traceback.format_exc()) logger.warning(traceback.format_exc())
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1) await asyncio.sleep(1)
assert isinstance(completion, ChatCompletion) assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}") logger.debug(f"openai completion: {completion.usage}")
@@ -452,7 +452,7 @@ class ProviderOpenAIOfficial(Provider):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。") logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
time.sleep(1) await asyncio.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False if session_id is None: return False
+13 -11
View File
@@ -2,7 +2,6 @@ from typing import Union, List, Callable
from dataclasses import dataclass from dataclasses import dataclass
from nakuru.entities.components import Plain, Image from nakuru.entities.components import Plain, Image
@dataclass @dataclass
class CommandItem(): class CommandItem():
''' '''
@@ -19,12 +18,17 @@ class CommandResult():
用于在Command中返回多个值 用于在Command中返回多个值
''' '''
def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None: def __init__(self,
hit: bool = True,
success: bool = True,
message_chain: list = [],
command_name: str = "unknown_command",
use_t2i: bool = None) -> None:
self.hit = hit self.hit = hit
self.success = success self.success = success
self.message_chain = message_chain self.message_chain = message_chain
self.command_name = command_name self.command_name = command_name
self.is_use_t2i = None # default self.is_use_t2i = use_t2i
def message(self, message: str): def message(self, message: str):
''' '''
@@ -63,14 +67,12 @@ class CommandResult():
self.message_chain = [Image.fromFileSystem(path), ] self.message_chain = [Image.fromFileSystem(path), ]
return self return self
# def use_t2i(self, use_t2i: bool): def use_t2i(self, use_t2i: bool):
# ''' '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
'''
# CommandResult().use_t2i(False) self.is_use_t2i = use_t2i
# ''' return self
# self.is_use_t2i = use_t2i
# return self
def _result_tuple(self): def _result_tuple(self):
return (self.success, self.message_chain, self.command_name) return (self.success, self.message_chain, self.command_name)
+2 -2
View File
@@ -1,4 +1,4 @@
VERSION = '3.3.7' VERSION = '3.3.9'
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"qqbot": { "qqbot": {
@@ -72,4 +72,4 @@ DEFAULT_CONFIG = {
"ws_reverse_host": "", "ws_reverse_host": "",
"ws_reverse_port": 0, "ws_reverse_port": 0,
} }
} }
+21 -10
View File
@@ -2,7 +2,14 @@ from typing import List, Union, Optional
from dataclasses import dataclass from dataclasses import dataclass
from type.register import RegisteredPlatform from type.register import RegisteredPlatform
from type.types import Context from type.types import Context
from type.astrbot_message import AstrBotMessage from type.astrbot_message import AstrBotMessage, MessageType
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
class AstrMessageEvent(): class AstrMessageEvent():
@@ -12,7 +19,9 @@ class AstrMessageEvent():
platform: RegisteredPlatform, platform: RegisteredPlatform,
role: str, role: str,
context: Context, context: Context,
session_id: str = None): session_id: str = None,
unified_msg_origin: str = None,
only_command: bool = False):
''' '''
AstrBot 消息事件。 AstrBot 消息事件。
@@ -22,6 +31,8 @@ class AstrMessageEvent():
`role`: 角色,`admin` or `member` `role`: 角色,`admin` or `member`
`context`: 全局对象 `context`: 全局对象
`session_id`: 会话id `session_id`: 会话id
`unified_msg_origin`: 统一消息来源
`only_command`: 是否只处理指令,而不使用 LLM 回复
''' '''
self.context = context self.context = context
self.message_str = message_str self.message_str = message_str
@@ -29,24 +40,24 @@ class AstrMessageEvent():
self.platform = platform self.platform = platform
self.role = role self.role = role
self.session_id = session_id self.session_id = session_id
self.unified_msg_origin = unified_msg_origin
self.only_command = only_command
def from_astrbot_message(message: AstrBotMessage, def from_astrbot_message(message: AstrBotMessage,
context: Context, context: Context,
platform_name: str, platform_name: str,
session_id: str, session_id: str,
role: str = "member"): role: str = "member",
unified_msg_origin: str = None,
only_command: bool = False):
ame = AstrMessageEvent(message.message_str, ame = AstrMessageEvent(message.message_str,
message, message,
context.find_platform(platform_name), context.find_platform(platform_name),
role, role,
context, context,
session_id) session_id,
unified_msg_origin,
only_command=only_command)
return ame return ame
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
+34 -2
View File
@@ -8,6 +8,8 @@ from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider from model.provider.provider import Provider
@@ -40,6 +42,9 @@ class Context:
self.image_uploader = ImageUploader() self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = [] self.ext_tasks: List[Task] = []
self.command_manager = None
self.running = True
# useless # useless
self.reply_prefix = "" self.reply_prefix = ""
@@ -50,7 +55,8 @@ class Context:
description: str, description: str,
priority: int, priority: int,
handler: callable, handler: callable,
use_regex: bool = False): use_regex: bool = False,
ignore_prefix: bool = False):
''' '''
注册插件指令。 注册插件指令。
@@ -60,8 +66,19 @@ class Context:
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。 @param use_regex: 是否使用正则表达式匹配指令名。
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
.. Example::
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
''' '''
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex) self.plugin_command_bridge.register_command(plugin_name,
command_name,
description,
priority,
handler,
use_regex,
ignore_prefix)
def register_task(self, coro: Awaitable, task_name: str): def register_task(self, coro: Awaitable, task_name: str):
''' '''
@@ -87,3 +104,18 @@ class Context:
return platform return platform
raise ValueError("couldn't find the platform you specified") raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
发送消息。
`unified_msg_origin`: 统一消息来源
`message`: 消息内容
'''
l = unified_msg_origin.split(":")
if len(l) != 3:
raise ValueError("Invalid unified_msg_origin")
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
+45 -165
View File
@@ -1,9 +1,7 @@
from model.provider.provider import Provider
import json import json
import util.general_utils as gu
import time import time
import textwrap
class FuncCallJsonFormatError(Exception): class FuncCallJsonFormatError(Exception):
def __init__(self, msg): def __init__(self, msg):
@@ -22,14 +20,11 @@ class FuncNotFoundError(Exception):
class FuncCall(): class FuncCall():
def __init__(self, provider) -> None: def __init__(self, provider: Provider) -> None:
self.func_list = [] self.func_list = []
self.provider = provider self.provider = provider
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None: def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
if name == None or func_args == None or desc == None or func_obj == None:
raise FuncCallJsonFormatError(
"name, func_args, desc must be provided.")
params = { params = {
"type": "object", # hardcore here "type": "object", # hardcore here
"properties": {} "properties": {}
@@ -47,7 +42,7 @@ class FuncCall():
} }
self.func_list.append(self._func) self.func_list.append(self._func)
def func_dump(self, intent: int = 2) -> str: def func_dump(self) -> str:
_l = [] _l = []
for f in self.func_list: for f in self.func_list:
_l.append({ _l.append({
@@ -55,7 +50,7 @@ class FuncCall():
"parameters": f["parameters"], "parameters": f["parameters"],
"description": f["description"], "description": f["description"],
}) })
return json.dumps(_l, indent=intent, ensur_ascii=False) return json.dumps(_l, ensure_ascii=False)
def get_func(self) -> list: def get_func(self) -> list:
_l = [] _l = []
@@ -70,64 +65,36 @@ class FuncCall():
}) })
return _l return _l
def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None): async def func_call(self, question: str, func_definition: str, session_id: str=None):
funccall_prompt = """ prompt = textwrap.dedent(f"""
我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。 ROLE:
下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
```
{
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
"func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。
{
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
"name": str, // 函数的名字
"args_type": {
"arg1": str, // 函数的参数的类型
"arg2": str,
...
},
"args": {
"arg1": any, // 函数的参数
"arg2": any,
...
}
},
... // 可能在这个问题中会有多个函数调用
],
}
```
- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。
- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。)
提供的函数是: TOOLS:
可用的函数列表:
""" {func_definition}
prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n" LIMIT:
prompt += f""" 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
用户的提问是: 2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
``` 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
{question} 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
```
"""
# if is_task: EXAMPLE:
# # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**" 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
# prompt += task_prompt
# provider.forget() 用户的提问是:{question}
""")
_c = 0 _c = 0
while _c < 3: while _c < 3:
try: try:
res = self.provider.text_chat(prompt=prompt, session_id=session_id) res = await self.provider.text_chat(prompt, session_id)
print(res)
if res.find('```') != -1: if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')] res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
print(res)
res = json.loads(res) res = json.loads(res)
break break
except Exception as e: except Exception as e:
@@ -136,112 +103,25 @@ class FuncCall():
raise e raise e
if "The message you submitted was too long" in str(e): if "The message you submitted was too long" in str(e):
raise e raise e
if 'res' in res and not res['res']:
return "", False
invoke_func_res = "" tool_call_result = []
for tool in res:
if "func_call" in res and len(res["func_call"]) > 0: # 说明有函数调用
task_list = res["func_call"] func_name = tool["name"]
args = tool["args"]
invoke_func_res_list = [] # 调用函数
tool_callable = None
for res in task_list: for func in self.func_list:
# 说明有函数调用 if func["name"] == func_name:
func_name = res["name"] tool_callable = func["func_obj"]
# args_type = res["args_type"] break
args = res["args"] if not tool_callable:
# 调用函数 raise FuncNotFoundError(
# func = eval(func_name) f"Request function {func_name} not found.")
func_target = None ret = await tool_callable(**args)
for func in self.func_list: if ret:
if func["name"] == func_name: tool_call_result.append(str(ret))
func_target = func["func_obj"] return tool_call_result, True
break
if func_target == None:
raise FuncNotFoundError(
f"Request function {func_name} not found.")
t_res = str(func_target(**args))
invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n"
invoke_func_res_list.append(invoke_func_res)
gu.log(f"[FUNC| {func_name} invoked]",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(str(t_res))
if is_summary:
# 生成返回结果
after_prompt = """
有以下内容:"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
用户的提问是:
```""" + question + """```
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
```json
{
"res": string, // 回答的内容
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
}
```
- 如果func_call_again为trueres请你设为空值,否则请你填写回答的内容。"""
_c = 0
while _c < 5:
try:
res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
# 截取```之间的内容
gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
print(res)
gu.log(
"DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
if res.find('```') != -1:
res = res[res.find('```json') +
7: res.rfind('```')]
gu.log("REVGPT after_func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
after_prompt_res = res
after_prompt_res = json.loads(after_prompt_res)
break
except Exception as e:
_c += 1
if _c == 5:
raise e
if "The message you submitted was too long" in str(e):
# 如果返回的内容太长了,那么就截取一部分
time.sleep(3)
invoke_func_res = invoke_func_res[:int(
len(invoke_func_res) / 2)]
after_prompt = """
函数返回以下内容:"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
用户的提问是:
```""" + question + """```
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
```json
{
"res": string, // 回答的内容
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
}
```
- 如果func_call_again为trueres请你设为空值,否则请你填写回答的内容。"""
else:
raise e
if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]:
# 如果需要重新调用函数
# 重新调用函数
gu.log("REVGPT func_call_again",
bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
res = self.func_call(question, func_definition)
return res, True
gu.log("REVGPT func callback:",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(after_prompt_res["res"])
return after_prompt_res["res"], True
else:
return str(invoke_func_res_list), True
else:
# print(res["res"])
return res["res"], False
+21 -15
View File
@@ -1,13 +1,13 @@
import traceback import traceback
import random import random
import json import json
import asyncio
import aiohttp import aiohttp
import os import os
from readability import Document from readability import Document
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat.chat_completion_message_tool_call import Function
from openai._exceptions import *
from util.agent.func_call import FuncCall from util.agent.func_call import FuncCall
from util.websearch.config import HEADERS, USER_AGENTS from util.websearch.config import HEADERS, USER_AGENTS
from util.websearch.bing import Bing from util.websearch.bing import Bing
@@ -100,9 +100,9 @@ async def fetch_website_content(url):
return ret return ret
async def web_search(prompt, provider: Provider, session_id, official_fc=False): async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
''' '''
official_fc: 使用官方 function-calling @param official_fc: 使用官方 function-calling
''' '''
new_func_call = FuncCall(provider) new_func_call = FuncCall(provider)
@@ -127,9 +127,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = "" function_invoked_ret = ""
if official_fc: if official_fc:
# we use official function-calling # we use official function-calling
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func()) try:
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
except BadRequestError as e:
# seems dont support function-calling
logger.error(f"error: {e}. Try to use local function-calling implementation")
return await web_search(prompt, provider, session_id, official_fc=False)
if isinstance(result, Function): if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}") logger.debug(f"function-calling: {result}")
func_obj = None func_obj = None
for i in new_func_call.func_list: for i in new_func_call.func_list:
if i["name"] == result.name: if i["name"] == result.name:
@@ -152,30 +157,31 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
args = { args = {
'question': prompt, 'question': prompt,
'func_definition': new_func_call.func_dump(), 'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
} }
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) function_invoked_ret, has_func = await new_func_call.func_call(**args)
if not has_func:
return await provider.text_chat(prompt, session_id)
except BaseException as e: except BaseException as e:
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)" logger.error(traceback.format_exc())
return res return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
has_func = True
if has_func: if has_func:
await provider.forget(session_id=session_id, ) await provider.forget(session_id=session_id)
summary_prompt = f""" summary_prompt = f"""
你是一个专业且高效的助手,你的任务是 你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; 1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
2. 简单地发表你对这个问题的简略看法。 2. 简单地发表你对这个问题的看法。
# 例子 # 例子
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢? 1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
2. 根据网上的最新信息,可以得知...我觉得...你怎么看? 2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
# 限制 # 限制
1. 限制在 200 字以内 1. 限制在 200-300 字;
2. 请**直接输出总结**,不要输出多余的内容和提示语。 2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料 # 相关材料
{function_invoked_ret}""" {function_invoked_ret}"""
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id) ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
-30
View File
@@ -1,30 +0,0 @@
import time
import asyncio
import requests
import json
import sys
import psutil
from type.types import Context
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
def run_monitor(global_object: Context):
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
+3
View File
@@ -65,6 +65,9 @@ class MetricUploader():
except BaseException as e: except BaseException as e:
pass pass
await asyncio.sleep(30*60) await asyncio.sleep(30*60)
def increment_platform_stat(self, platform_name: str):
self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1
def clear(self): def clear(self):
self.platform_stats.clear() self.platform_stats.clear()
+5 -3
View File
@@ -9,7 +9,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotUpdator(RepoZipUpdator): class AstrBotUpdator(RepoZipUpdator):
def __init__(self): def __init__(self):
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases" self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
def terminate_child_processes(self): def terminate_child_processes(self):
@@ -30,9 +30,11 @@ class AstrBotUpdator(RepoZipUpdator):
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
def _reboot(self, delay: int = None): def _reboot(self, delay: int = None, context = None):
if delay: time.sleep(delay) # if delay: time.sleep(delay)
py = sys.executable py = sys.executable
context.running = False
time.sleep(3)
self.terminate_child_processes() self.terminate_child_processes()
py = py.replace(" ", "\\ ") py = py.replace(" ", "\\ ")
try: try: