Compare commits

...

35 Commits

Author SHA1 Message Date
Soulter a876efb95f fix: 更新后覆盖文件路径错误 2024-08-10 04:35:07 -04:00
Soulter 95a8cc9498 fix: 修复部分字段未更新导致的错误 2024-08-10 04:13:24 -04:00
Soulter f02731055e fix: 修复插件启用忽略前缀之后可能的逻辑冲突 2024-08-10 03:25:50 -04:00
Soulter 1df83addfc update: add gcc 2024-08-10 14:59:00 +08:00
Soulter 9db43ac5e6 feat: 注册指令支持忽略指令前缀;快捷主动回复 2024-08-10 02:35:54 -04:00
Soulter 0f470cf96f Update README.md 2024-08-09 12:26:00 +08:00
Soulter da3fcb7b86 Merge pull request #186 from itgpt-com/master
优化 docker build
2024-08-08 22:15:48 +08:00
Soulter 73dd4703b9 Update .dockerignore 2024-08-08 22:15:05 +08:00
itgpt 0c679a0151 添加 .dockerignore 过滤 docker cp 不必要文件。缩小镜像 2024-08-08 16:21:30 +08:00
itgpt 1d6ea2dbe6 添加端口输出 2024-08-08 16:16:55 +08:00
itgpt 933df57654 优化 docker build 2024-08-08 15:53:44 +08:00
Soulter cbe761fc33 Update README.md 2024-08-07 00:49:00 +08:00
Soulter 14dbdb2d83 feat: 插件支持正则匹配 2024-08-05 12:12:00 -04:00
Soulter abda226d63 Merge pull request #183 from irorange27/master
fix: fix logo syntax warning
2024-08-05 23:37:57 +08:00
niina a2dc6f0a49 fix: fix logo syntax warning 2024-08-05 22:53:45 +08:00
Soulter 7a94c26333 fix: 修复 wake 唤醒词无法触发 command 的问题 2024-08-05 05:02:57 -04:00
Soulter 9b1ffb384b perf: 优化aiocqhttp适配器的异常处理 2024-08-05 04:46:12 -04:00
Soulter 9566bfe122 workaround for issue #181 2024-08-03 17:03:38 +08:00
Soulter 89ff103bda chore: Add mimetypes workaround for issue #188 2024-08-03 17:02:45 +08:00
Soulter 6c788db53a Merge remote-tracking branch 'refs/remotes/origin/master' 2024-08-03 16:17:25 +08:00
Soulter 344b5fa419 fix: f-string eror 2024-08-03 16:17:04 +08:00
Soulter c6d161b837 Update README.md 2024-08-03 15:04:20 +08:00
Soulter 2065ba0c60 Update README.md 2024-08-03 01:05:27 +08:00
Soulter a481fd1a3e fix: Strip leading and trailing whitespace from llm_wake_prefix 2024-08-02 23:17:35 +08:00
Soulter c50bcdbdb9 fix: Register command only if plugin is found 2024-08-02 22:48:04 +08:00
Soulter 36a2a7632c fix: 优化初始化、消息处理时的配置读取过程,减少性能损耗 2024-07-31 23:38:31 +08:00
Soulter e77b7014e6 fix: 修复更新、卸载插件时的报错 2024-07-30 09:15:45 +08:00
Soulter d57fd0f827 fix: metadata is not seralizable 2024-07-29 09:47:42 +08:00
Soulter 6a83d2a62a update version 2024-07-28 12:11:07 +08:00
Soulter 2d29726c18 fix: 修复带空格路径导致的重启失败 2024-07-28 11:55:57 +08:00
Soulter b241b0f954 update version 2024-07-27 12:31:15 -04:00
Soulter 171dd1dc02 feat: qq 官方机器人接口支持C2C 2024-07-27 12:30:09 -04:00
Soulter af62d969d7 perf: 更改 send_msg 接口 2024-07-27 11:26:02 -04:00
Soulter c4fd9a66c6 update version to 3.3.3 2024-07-27 11:08:51 -04:00
Soulter d191997a39 feat: aiocqhttp 适配器适配主动发送消息接口 2024-07-27 11:07:26 -04:00
28 changed files with 646 additions and 333 deletions
+18
View File
@@ -0,0 +1,18 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github acions
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
# Environments
.env
.venv
env/
venv*/
ENV/
.conda/
README*.md
+32 -13
View File
@@ -4,20 +4,39 @@ on:
release: release:
types: [published] types: [published]
workflow_dispatch: workflow_dispatch:
jobs: jobs:
publish-latest-docker-image: publish-docker:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: Build and publish docker image
steps: steps:
- name: Checkout - name: 拉取源码
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: Build image with:
run: | fetch-depth: 1
git clone https://github.com/Soulter/AstrBot
cd AstrBot - name: 设置 QEMU
docker build -t ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest . uses: docker/setup-qemu-action@v3
- name: Publish image
run: | - name: 设置 Docker Buildx
docker login -u ${{ secrets.DOCKER_HUB_USERNAME }} -p ${{ secrets.DOCKER_HUB_PASSWORD }} uses: docker/setup-buildx-action@v3
docker push ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
- name: 登录到 DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: 构建和推送 Docker hub
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"
+12
View File
@@ -3,6 +3,18 @@ WORKDIR /AstrBot
COPY . /AstrBot/ COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt RUN python -m pip install -r requirements.txt
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ] CMD [ "python", "main.py" ]
+21 -9
View File
@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
<img width="806" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746"> <img width="750" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746">
</p> </p>
<div align="center"> <div align="center">
@@ -12,36 +12,48 @@
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
</a> </a>
<a href="https://astrbot.soulter.top/center">项目部署</a> <a href="https://astrbot.soulter.top/docs/main">快速开始</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a> <a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
<a href="https://astrbot.soulter.top/center/docs/%E5%BC%80%E5%8F%91/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91">插件开发</a> <a href="https://astrbot.soulter.top/docs/develop/plugin4p">插件开发</a>
</div> </div>
## 🛠️ 功能 ## 🛠️ 功能
🌍 支持的消息平台 🌍 支持的消息平台
- QQ 群、QQ 频道(OneBot、QQ 官方接口) - QQ 群、QQ 频道(OneBot、QQ 官方接口)
- Telegram[astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持 - Telegram[astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
- WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件支持) - WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件)
🌍 支持的大模型一览 🌍 支持的大模型/底座
- OpenAI GPT、DallE 系列 - OpenAI GPT、DallE 系列
- Claude(由[LLMs插件](https://github.com/Soulter/llms)支持) - Claude(由[LLMs插件](https://github.com/Soulter/llms)支持)
- HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持) - HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持)
- Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持) - Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持)
- Ollama
- 几乎所有已知模型(可接入 [OneAPI](https://astrbot.soulter.top/docs/docs/adavanced/one-api)
🌍 机器人支持的能力一览: 🌍 机器人支持的能力一览:
- 大模型对话、人格、网页搜索 - 大模型对话、人格、网页搜索
- 可视化管理面板 - 可视化仪表盘
- 同时处理多平台消息 - 同时处理多平台消息
- 精确到个人的会话隔离 - 精确到个人的会话隔离
- 插件支持 - 插件支持
- 文本转图片回复(Markdown - 文本转图片回复(Markdown
## 🧩 插件支持 ## 🧩 插件
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6) 有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
对于新功能的添加,请先通过 Issue 进行讨论。
## 🔭 展望
- [ ] 更多、更开放的 LLM Agent 能力
## ✨ Demo ## ✨ Demo
+20 -7
View File
@@ -10,6 +10,7 @@ from model.plugin.manager import PluginManager
from model.platform.manager import PlatformManager from model.platform.manager import PlatformManager
from typing import Dict, List, Union from typing import Dict, List, Union
from type.types import Context from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from util.cmd_config import CmdConfig from util.cmd_config import CmdConfig
@@ -23,14 +24,25 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap(): class AstrBotBootstrap():
def __init__(self) -> None: def __init__(self) -> None:
self.context = Context() self.context = Context()
self.config_helper: CmdConfig = CmdConfig() self.config_helper = CmdConfig()
# load configs and ensure the backward compatibility # load configs and ensure the backward compatibility
init_configs()
try_migrate_config() try_migrate_config()
self.configs = inject_to_context(self.context)
logger.info("AstrBot v" + self.context.version)
self.context.config_helper = self.config_helper self.context.config_helper = self.config_helper
self.context.base_config = self.config_helper.cached_config
self.context.default_personality = {
"name": "default",
"prompt": self.context.base_config.get("default_personality_str", ""),
}
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
self.context.nick = nick_qq
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
self.context.version = VERSION
logger.info("AstrBot v" + self.context.version)
# apply proxy settings # apply proxy settings
http_proxy = self.context.base_config.get("http_proxy") http_proxy = self.context.base_config.get("http_proxy")
@@ -66,6 +78,7 @@ class AstrBotBootstrap():
self.context.updator = self.updator self.context.updator = self.updator
self.context.plugin_updator = self.plugin_manager.updator self.context.plugin_updator = self.plugin_manager.updator
self.context.message_handler = self.message_handler self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
# load plugins, plugins' commands. # load plugins, plugins' commands.
self.load_plugins() self.load_plugins()
@@ -93,9 +106,9 @@ class AstrBotBootstrap():
await asyncio.sleep(5) await asyncio.sleep(5)
def load_llm(self): def load_llm(self):
if 'openai' in self.configs and \ if 'openai' in self.config_helper.cached_config and \
len(self.configs['openai']['key']) and \ len(self.config_helper.cached_config['openai']['key']) and \
self.configs['openai']['key'][0] is not None: self.config_helper.cached_config['openai']['key'][0] is not None:
from model.provider.openai_official import ProviderOpenAIOfficial from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official_handler import OpenAIOfficialCommandHandler from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
+14 -4
View File
@@ -112,9 +112,11 @@ class MessageHandler():
self.rate_limit_helper = RateLimitHelper(context) self.rate_limit_helper = RateLimitHelper(context)
self.content_safety_helper = ContentSafetyHelper(context) self.content_safety_helper = ContentSafetyHelper(context)
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix'] self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
if self.llm_wake_prefix:
self.llm_wake_prefix = self.llm_wake_prefix.strip()
self.nicks = self.context.nick self.nicks = self.context.nick
self.provider = provider self.provider = provider
self.reply_prefix = self.context.reply_prefix self.reply_prefix = str(self.context.reply_prefix)
def set_provider(self, provider: Provider): def set_provider(self, provider: Provider):
self.provider = provider self.provider = provider
@@ -132,8 +134,8 @@ class MessageHandler():
self.persist_manager.record_message(message.platform.platform_name, message.session_id) self.persist_manager.record_message(message.platform.platform_name, message.session_id)
# TODO: this should be configurable # TODO: this should be configurable
if not message.message_str: # if not message.message_str:
return MessageResult("Hi~") # return MessageResult("Hi~")
# check the rate limit # check the rate limit
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
@@ -144,6 +146,7 @@ class MessageHandler():
if msg_plain.startswith(nick): if msg_plain.startswith(nick):
msg_plain = msg_plain.removeprefix(nick) msg_plain = msg_plain.removeprefix(nick)
break break
message.message_str = msg_plain
# scan candidate commands # scan candidate commands
cmd_res = await self.command_manager.scan_command(message, self.context) cmd_res = await self.command_manager.scan_command(message, self.context)
@@ -155,11 +158,18 @@ class MessageHandler():
use_t2i=cmd_res.is_use_t2i use_t2i=cmd_res.is_use_t2i
) )
# next is the LLM part
if message.only_command:
return
# check if the message is a llm-wake-up command # check if the message is a llm-wake-up command
if not msg_plain.startswith(self.llm_wake_prefix): if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
return return
if not provider: if not provider:
logger.debug("没有任何 LLM 可用,忽略。")
return return
# check the content safety # check the content safety
+5 -2
View File
@@ -46,6 +46,10 @@ class AstrBotDashBoard():
# 返回页面 # 返回页面
return self.dashboard_be.send_static_file("index.html") return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/auth/login")
def _():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/config") @self.dashboard_be.get("/config")
def rt_config(): def rt_config():
return self.dashboard_be.send_static_file("index.html") return self.dashboard_be.send_static_file("index.html")
@@ -86,12 +90,11 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/change_password") @self.dashboard_be.post("/api/change_password")
def change_password(): def change_password():
password = self.context.base_config("dashboard_password", "") password = self.context.base_config.get("dashboard_password", "")
# 获得请求体 # 获得请求体
post_data = request.json post_data = request.json
if post_data["password"] == password: if post_data["password"] == password:
self.context.config_helper.put("dashboard_password", post_data["new_password"]) self.context.config_helper.put("dashboard_password", post_data["new_password"])
self.context.base_config['dashboard_password'] = post_data["new_password"]
return Response( return Response(
status="success", status="success",
message="修改成功。", message="修改成功。",
+7 -1
View File
@@ -4,12 +4,13 @@ import asyncio
import sys import sys
import warnings import warnings
import traceback import traceback
import mimetypes
from astrbot.bootstrap import AstrBotBootstrap from astrbot.bootstrap import AstrBotBootstrap
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Formatter from logging import Formatter
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
logo_tmpl = """ logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________. ___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | | / \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----` / ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
@@ -43,6 +44,11 @@ def check_env():
os.makedirs("data/config", exist_ok=True) os.makedirs("data/config", exist_ok=True)
os.makedirs("temp", exist_ok=True) os.makedirs("temp", exist_ok=True)
# workaround for issue #181
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
if __name__ == "__main__": if __name__ == "__main__":
check_env() check_env()
+8 -4
View File
@@ -62,14 +62,16 @@ class InternalCommandHandler:
return CommandResult().message("你没有权限使用该指令。") return CommandResult().message("你没有权限使用该指令。")
l = message_str.split(" ") l = message_str.split(" ")
if len(l) == 1: if len(l) == 1:
return CommandResult().message("设置机器人唤醒词,支持多唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称1 昵称2 昵称3") return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
nick = l[1:] nick = l[1].strip()
if not nick:
return CommandResult().message("wake: 请指定唤醒词。")
context.config_helper.put("nick_qq", nick) context.config_helper.put("nick_qq", nick)
context.nick = tuple(nick) context.nick = tuple(nick)
return CommandResult( return CommandResult(
hit=True, hit=True,
success=True, success=True,
message_chain=f"已经成功将唤醒词设定为 {nick}", message_chain=f"已经成功将唤醒词设定为 {nick}",
) )
def update(self, message: AstrMessageEvent, context: Context): def update(self, message: AstrMessageEvent, context: Context):
@@ -230,15 +232,17 @@ class InternalCommandHandler:
) )
def t2i_toggle(self, message: AstrMessageEvent, context: Context): def t2i_toggle(self, message: AstrMessageEvent, context: Context):
p = context.config_helper.get("qq_pic_mode", True) p = context.t2i_mode
if p: if p:
context.config_helper.put("qq_pic_mode", False) context.config_helper.put("qq_pic_mode", False)
context.t2i_mode = False
return CommandResult( return CommandResult(
hit=True, hit=True,
success=True, success=True,
message_chain="已关闭文本转图片模式。", message_chain="已关闭文本转图片模式。",
) )
context.config_helper.put("qq_pic_mode", True) context.config_helper.put("qq_pic_mode", True)
context.t2i_mode = True
return CommandResult( return CommandResult(
hit=True, hit=True,
+38 -3
View File
@@ -20,7 +20,9 @@ class CommandMetadata():
inner_command: bool inner_command: bool
plugin_metadata: PluginMetadata plugin_metadata: PluginMetadata
handler: callable handler: callable
description: str use_regex: bool = False
ignore_prefix: bool = False
description: str = ""
class CommandManager(): class CommandManager():
def __init__(self): def __init__(self):
@@ -33,10 +35,14 @@ class CommandManager():
description: str, description: str,
priority: int, priority: int,
handler: callable, handler: callable,
use_regex: bool = False,
ignore_prefix: bool = False,
plugin_metadata: PluginMetadata = None, plugin_metadata: PluginMetadata = None,
): ):
''' '''
优先级越高,越先被处理。 优先级越高,越先被处理。
use_regex: 是否使用正则表达式匹配指令。
''' '''
if command in self.commands_handler: if command in self.commands_handler:
raise ValueError(f"Command {command} already exists.") raise ValueError(f"Command {command} already exists.")
@@ -48,6 +54,8 @@ class CommandManager():
inner_command=plugin_metadata == None, inner_command=plugin_metadata == None,
plugin_metadata=plugin_metadata, plugin_metadata=plugin_metadata,
handler=handler, handler=handler,
use_regex=use_regex,
ignore_prefix=ignore_prefix,
description=description description=description
) )
if plugin_metadata: if plugin_metadata:
@@ -64,15 +72,42 @@ class CommandManager():
break break
if not plugin: if not plugin:
logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}") logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}")
self.register(request.command_name, request.description, request.priority, request.handler, plugin.metadata) else:
self.register(command=request.command_name,
description=request.description,
priority=request.priority,
handler=request.handler,
use_regex=request.use_regex,
ignore_prefix=request.ignore_prefix,
plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = [] self.plugin_commands_waitlist = []
async def check_command_ignore_prefix(self, message_str: str) -> bool:
for _, command in self.commands:
command_metadata = self.commands_handler[command]
if command_metadata.ignore_prefix:
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
return True
return False
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str message_str = message_event.message_str
for _, command in self.commands: for _, command in self.commands:
if message_str.startswith(command): trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
logger.info(f"触发 {command} 指令。") logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context) command_result = await self.execute_handler(command, message_event, context)
if not command_result:
continue
if command_result.hit: if command_result.hit:
return command_result return command_result
+6
View File
@@ -1,3 +1,5 @@
import re
class CommandTokens(): class CommandTokens():
def __init__(self) -> None: def __init__(self) -> None:
self.tokens = [] self.tokens = []
@@ -17,3 +19,7 @@ class CommandParser():
cmd_tokens.tokens = message.split(" ") cmd_tokens.tokens = message.split(" ")
cmd_tokens.len = len(cmd_tokens.tokens) cmd_tokens.len = len(cmd_tokens.tokens)
return cmd_tokens return cmd_tokens
def regex_match(self, message: str, command: str) -> bool:
return re.search(command, message, re.MULTILINE) is not None
+10 -1
View File
@@ -2,6 +2,8 @@ import abc
from typing import Union, Any, List from typing import Union, Any, List
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
from type.astrbot_message import AstrBotMessage from type.astrbot_message import AstrBotMessage
from type.command import CommandResult
from type.astrbot_message import MessageType
class Platform(): class Platform():
@@ -24,7 +26,14 @@ class Platform():
pass pass
@abc.abstractmethod @abc.abstractmethod
async def send_msg(self, target: Any, result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Any, result_message: CommandResult):
'''
发送消息(主动)
'''
pass
@abc.abstractmethod
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
''' '''
发送消息(主动) 发送消息(主动)
''' '''
+2 -2
View File
@@ -58,7 +58,7 @@ class PlatformManager():
try: try:
qq_gocq = QQGOCQ(self.context, self.msg_handler) qq_gocq = QQGOCQ(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform( self.context.platforms.append(RegisteredPlatform(
platform_name="gocq", platform_instance=qq_gocq, origin="internal")) platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
await qq_gocq.run() await qq_gocq.run()
except BaseException as e: except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e)) logger.error("启动 nakuru 适配器时出现错误: " + str(e))
@@ -81,7 +81,7 @@ class PlatformManager():
from model.platform.qq_official import QQOfficial from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(self.context, self.msg_handler) qqchannel_bot = QQOfficial(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform( self.context.platforms.append(RegisteredPlatform(
platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
return qqchannel_bot.run() return qqchannel_bot.run()
except BaseException as e: except BaseException as e:
logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e)) logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e))
+74 -16
View File
@@ -7,6 +7,7 @@ from aiocqhttp.exceptions import ActionFailed
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from typing import Union, List, Dict from typing import Union, List, Dict
from nakuru.entities.components import * from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
@@ -47,6 +48,14 @@ class AIOCQHTTP(Platform):
abm.message = [] abm.message = []
message_str = "" message_str = ""
if not isinstance(event.message, list):
err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。"
logger.critical(err)
try:
self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return
for m in event.message: for m in event.message:
t = m['type'] t = m['type']
a = None a = None
@@ -74,14 +83,12 @@ class AIOCQHTTP(Platform):
abm = self.convert_message(event) abm = self.convert_message(event)
if abm: if abm:
await self.handle_msg(abm) await self.handle_msg(abm)
# return {'reply': event.message}
@self.bot.on_message('private') @self.bot.on_message('private')
async def private(event: Event): async def private(event: Event):
abm = self.convert_message(event) abm = self.convert_message(event)
if abm: if abm:
await self.handle_msg(abm) await self.handle_msg(abm)
# return {'reply': event.message}
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
@@ -96,40 +103,61 @@ class AIOCQHTTP(Platform):
await asyncio.sleep(1) await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool: def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True # if message chain contains Plain components or
# At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE: if message.type == MessageType.FRIEND_MESSAGE:
return True return True, "friend"
for comp in message.message: for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id: if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True, "command"
# check nicks # check nicks
if self.check_nick(message.message_str): if self.check_nick(message.message_str):
return True return True, "nick"
return False return False, "none"
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
logger.info( logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
if not self.pre_check(message): ok, reason = self.pre_check(message)
if not ok:
return return
# 解析 role # 解析 role
sender_id = str(message.sender.user_id) sender_id = str(message.sender.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \ if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []): sender_id in self.context.base_config.get('other_admins', []):
role = 'admin' role = 'admin'
else: else:
role = 'member' role = 'member'
# parse unified message origin
unified_msg_origin = None
assert isinstance(message.raw_message, Event)
if message.type == MessageType.GROUP_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event # construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"aiocqhttp",
message.session_id,
role,
unified_msg_origin,
reason == "command") # only_command
# transfer control to message handler # transfer control to message handler
message_result = await self.message_handler.handle(ame) message_result = await self.message_handler.handle(ame)
if not message_result: return if not message_result: return
await self.reply_msg(message, message_result.result_message) await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -140,7 +168,8 @@ class AIOCQHTTP(Platform):
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: list): result_message: list,
use_t2i: bool = None):
""" """
回复用户唤醒机器人的消息。(被动回复) 回复用户唤醒机器人的消息。(被动回复)
""" """
@@ -153,7 +182,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ] res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture. # if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res) rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images: if rendered_images:
try: try:
@@ -165,7 +194,7 @@ class AIOCQHTTP(Platform):
await self._reply(message, res) await self._reply(message, res)
async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]): async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
if isinstance(message_chain, str): if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ] message_chain = [Plain(text=message_chain), ]
@@ -179,7 +208,15 @@ class AIOCQHTTP(Platform):
image_idx.append(idx) image_idx.append(idx)
ret.append(d) ret.append(d)
try: try:
await self.bot.send(message.raw_message, ret) if isinstance(message, AstrBotMessage):
await self.bot.send(message.raw_message, ret)
if isinstance(message, dict):
if 'group_id' in message:
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
elif 'user_id' in message:
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
else:
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
except ActionFailed as e: except ActionFailed as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(f"回复消息失败: {e}") logger.error(f"回复消息失败: {e}")
@@ -196,3 +233,24 @@ class AIOCQHTTP(Platform):
ret[idx]['data']['file'] = image_url ret[idx]['data']['file'] = image_url
ret[idx]['data']['path'] = image_url ret[idx]['data']['path'] = image_url
await self.bot.send(message.raw_message, ret) await self.bot.send(message.raw_message, ret)
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
'''
以主动的方式给QQ用户、QQ群发送一条消息。
`target` 接收一个 dict 类型的值引用。
- 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号;
- 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号;
'''
await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
if message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({'group_id': int(target)}, result_message)
elif message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({'user_id': int(target)}, result_message)
else:
raise Exception("aiocqhttp: 无法识别的消息类型。")
+74 -18
View File
@@ -14,6 +14,7 @@ from type.types import Context
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from astrbot.message.handler import MessageHandler from astrbot.message.handler import MessageHandler
@@ -73,14 +74,17 @@ class QQGOCQ(Platform):
def pre_check(self, message: AstrBotMessage) -> bool: def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True # if message chain contains Plain components or At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE: if message.type == MessageType.FRIEND_MESSAGE:
return True return True, "friend"
for comp in message.message: for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id: if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True, "command"
# check nicks # check nicks
if self.check_nick(message.message_str): if self.check_nick(message.message_str):
return True return True, "nick"
return False return False, "none"
def run(self): def run(self):
coro = self.client._run() coro = self.client._run()
@@ -94,7 +98,8 @@ class QQGOCQ(Platform):
(GroupMessage, FriendMessage, GuildMessage)) (GroupMessage, FriendMessage, GuildMessage))
# 判断是否响应消息 # 判断是否响应消息
if not self.pre_check(message): ok, reason = self.pre_check(message)
if not ok:
return return
# 解析 session_id # 解析 session_id
@@ -111,20 +116,41 @@ class QQGOCQ(Platform):
# 解析 role # 解析 role
sender_id = str(message.raw_message.user_id) sender_id = str(message.raw_message.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \ if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []): sender_id in self.context.base_config.get('other_admins', []):
role = 'admin' role = 'admin'
else: else:
role = 'member' role = 'member'
# parse unified message origin
unified_msg_origin = None
if message.type == MessageType.GROUP_MESSAGE:
assert isinstance(message.raw_message, GroupMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}"
elif message.type == MessageType.FRIEND_MESSAGE:
assert isinstance(message.raw_message, FriendMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}"
elif message.type == MessageType.GUILD_MESSAGE:
assert isinstance(message.raw_message, GuildMessage)
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}"
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
# construct astrbot message event # construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) ame = AstrMessageEvent.from_astrbot_message(message,
self.context,
"nakuru",
session_id,
role,
unified_msg_origin,
reason == 'command') # only_command
# transfer control to message handler # transfer control to message handler
message_result = await self.message_handler.handle(ame) message_result = await self.message_handler.handle(ame)
if not message_result: return if not message_result: return
await self.reply_msg(message, message_result.result_message) await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -134,7 +160,8 @@ class QQGOCQ(Platform):
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: List[BaseMessageComponent]): result_message: List[BaseMessageComponent],
use_t2i: bool = None):
""" """
回复用户唤醒机器人的消息。(被动回复) 回复用户唤醒机器人的消息。(被动回复)
""" """
@@ -151,7 +178,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ] res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture. # if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res) rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images: if rendered_images:
try: try:
@@ -168,14 +195,26 @@ class QQGOCQ(Platform):
message_chain = [Plain(text=message_chain), ] message_chain = [Plain(text=message_chain), ]
is_dict = isinstance(source, dict) is_dict = isinstance(source, dict)
if source.type == "GuildMessage":
typ = None
if is_dict:
if "group_id" in source:
typ = "GroupMessage"
elif "user_id" in source:
typ = "FriendMessage"
elif "guild_id" in source:
typ = "GuildMessage"
else:
typ = source.type
if typ == "GuildMessage":
guild_id = source['guild_id'] if is_dict else source.guild_id guild_id = source['guild_id'] if is_dict else source.guild_id
chan_id = source['channel_id'] if is_dict else source.channel_id chan_id = source['channel_id'] if is_dict else source.channel_id
await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain) await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain)
elif source.type == "FriendMessage": elif typ == "FriendMessage":
user_id = source['user_id'] if is_dict else source.user_id user_id = source['user_id'] if is_dict else source.user_id
await self.client.sendFriendMessage(user_id, message_chain) await self.client.sendFriendMessage(user_id, message_chain)
elif source.type == "GroupMessage": elif typ == "GroupMessage":
group_id = source['group_id'] if is_dict else source.group_id group_id = source['group_id'] if is_dict else source.group_id
# 过长时forward发送 # 过长时forward发送
plain_text_len = 0 plain_text_len = 0
@@ -185,7 +224,7 @@ class QQGOCQ(Platform):
plain_text_len += len(i.text) plain_text_len += len(i.text)
elif isinstance(i, Image): elif isinstance(i, Image):
image_num += 1 image_num += 1
if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200): if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
# 删除At # 删除At
for i in message_chain: for i in message_chain:
if isinstance(i, At): if isinstance(i, At):
@@ -199,7 +238,7 @@ class QQGOCQ(Platform):
return return
await self.client.sendGroupMessage(group_id, message_chain) await self.client.sendGroupMessage(group_id, message_chain)
async def send_msg(self, target: Dict[str, int], result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
''' '''
以主动的方式给用户、群或者频道发送一条消息。 以主动的方式给用户、群或者频道发送一条消息。
@@ -211,7 +250,24 @@ class QQGOCQ(Platform):
guild_id 不是频道号。 guild_id 不是频道号。
''' '''
await self._reply(target, result_message) await self._reply(target, result_message.message_chain)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
'''
以主动的方式给用户、群或者频道发送一条消息。
`message_type` 为 MessageType 枚举类型。
- 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE
- 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE
- 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。
'''
if message_type == MessageType.FRIEND_MESSAGE:
await self.send_msg({"user_id": int(target)}, result_message)
elif message_type == MessageType.GROUP_MESSAGE:
await self.send_msg({"group_id": int(target)}, result_message)
elif message_type == MessageType.GUILD_MESSAGE:
await self.send_msg({"channel_id": int(target)}, result_message)
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
abm = AstrBotMessage() abm = AstrBotMessage()
@@ -232,7 +288,7 @@ class QQGOCQ(Platform):
str(message.sender.user_id), str(message.sender.user_id),
str(message.sender.nickname) str(message.sender.nickname)
) )
abm.tag = "gocq" abm.tag = "nakuru"
abm.message = message.message abm.message = message.message
return abm return abm
+73 -41
View File
@@ -13,6 +13,7 @@ from util.io import save_temp_img, download_image_by_url
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from typing import Union, List, Dict from typing import Union, List, Dict
from nakuru.entities.components import * from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
@@ -43,10 +44,15 @@ class botClient(Client):
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
await self.platform.handle_msg(abm) await self.platform.handle_msg(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
await self.platform.handle_msg(abm)
class QQOfficial(Platform): class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None: def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
super().__init__() super().__init__()
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
@@ -80,6 +86,8 @@ class QQOfficial(Platform):
self.client.set_platform(self) self.client.set_platform(self)
self.test_mode = test_mode
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
plain_text = "" plain_text = ""
image_path = None # only one img supported image_path = None # only one img supported
@@ -104,14 +112,20 @@ class QQOfficial(Platform):
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.raw_message = message abm.raw_message = message
abm.message_id = message.id abm.message_id = message.id
abm.tag = "qqchan" abm.tag = "qqofficial"
msg: List[BaseMessageComponent] = [] msg: List[BaseMessageComponent] = []
if message_type == MessageType.GROUP_MESSAGE: if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
abm.sender = MessageMember( if isinstance(message, botpy.message.GroupMessage):
message.author.member_openid, abm.sender = MessageMember(
"" message.author.member_openid,
) ""
)
else:
abm.sender = MessageMember(
message.author.user_openid,
""
)
abm.message_str = message.content.strip() abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid" abm.self_id = "unknown_selfid"
@@ -126,8 +140,7 @@ class QQOfficial(Platform):
msg.append(img) msg.append(img)
abm.message = msg abm.message = msg
elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE: elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage):
# 目前对于 FRIEND_MESSAGE 只处理频道私聊
try: try:
abm.self_id = str(message.mentions[0].id) abm.self_id = str(message.mentions[0].id)
except: except:
@@ -175,7 +188,7 @@ class QQOfficial(Platform):
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
assert isinstance(message.raw_message, (botpy.message.Message, assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE is_group = message.type != MessageType.FRIEND_MESSAGE
_t = "/私聊" if not is_group else "" _t = "/私聊" if not is_group else ""
@@ -196,8 +209,8 @@ class QQOfficial(Platform):
# 解析出 role # 解析出 role
sender_id = message.sender.user_id sender_id = message.sender.user_id
if sender_id == self.context.config_helper.get('admin_qqchan', None) or \ if sender_id == self.context.base_config.get('admin_qqchan', None) or \
sender_id in self.context.config_helper.get('other_admins', None): sender_id in self.context.base_config.get('other_admins', None):
role = 'admin' role = 'admin'
else: else:
role = 'member' role = 'member'
@@ -209,7 +222,7 @@ class QQOfficial(Platform):
if not message_result: if not message_result:
return return
await self.reply_msg(message, message_result.result_message) ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
@@ -217,15 +230,18 @@ class QQOfficial(Platform):
if session_id in self.waiting and self.waiting[session_id] == '': if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message self.waiting[session_id] = message
return ret
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
result_message: List[BaseMessageComponent]): result_message: List[BaseMessageComponent],
use_t2i: bool = None):
''' '''
回复频道消息 回复频道消息
''' '''
source = message.raw_message source = message.raw_message
assert isinstance(source, (botpy.message.Message, assert isinstance(source, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
logger.info( logger.info(
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}") f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}")
@@ -234,7 +250,7 @@ class QQOfficial(Platform):
msg_ref = None msg_ref = None
rendered_images = [] rendered_images = []
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list): if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(result_message) rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list): if isinstance(result_message, list):
@@ -253,12 +269,14 @@ class QQOfficial(Platform):
'message_reference': msg_ref 'message_reference': msg_ref
} }
if message.type == MessageType.GROUP_MESSAGE: if isinstance(message.raw_message, botpy.message.GroupMessage):
data['group_openid'] = str(source.group_openid) data['group_openid'] = str(source.group_openid)
elif message.type == MessageType.GUILD_MESSAGE: elif isinstance(message.raw_message, botpy.message.Message):
data['channel_id'] = source.channel_id data['channel_id'] = source.channel_id
elif message.type == MessageType.FRIEND_MESSAGE: elif isinstance(message.raw_message, botpy.message.DirectMessage):
data['guild_id'] = source.guild_id data['guild_id'] = source.guild_id
elif isinstance(message.raw_message, botpy.message.C2CMessage):
data['openid'] = source.author.user_openid
if image_path: if image_path:
data['file_image'] = image_path data['file_image'] = image_path
if rendered_images: if rendered_images:
@@ -269,14 +287,13 @@ class QQOfficial(Platform):
_data['message_reference'] = None _data['message_reference'] = None
try: try:
await self._reply(**_data) return await self._reply(**_data)
return
except BaseException as e: except BaseException as e:
logger.warn(traceback.format_exc()) logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
try: try:
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
# 分割过长的消息 # 分割过长的消息
@@ -286,28 +303,27 @@ class QQOfficial(Platform):
split_res.append(plain_text[len(plain_text)//2:]) split_res.append(plain_text[len(plain_text)//2:])
for i in split_res: for i in split_res:
data['content'] = i data['content'] = i
await self._reply(**data) return await self._reply(**data)
else: else:
try: try:
# 防止被qq频道过滤消息 # 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ") plain_text = plain_text.replace(".", " . ")
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
try: try:
data['content'] = str.join(" ", plain_text) data['content'] = str.join(" ", plain_text)
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
plain_text = re.sub( plain_text = re.sub(
r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·") plain_text = plain_text.replace(".", "·")
data['content'] = plain_text data['content'] = plain_text
await self._reply(**data) return await self._reply(**data)
async def _reply(self, **kwargs): async def _reply(self, **kwargs):
if 'group_openid' in kwargs: if 'group_openid' in kwargs or 'openid' in kwargs:
# QQ群组消息 # QQ群组消息
# qq群组消息需要自行上传,暂时不处理 if 'file_image' in kwargs and kwargs['file_image']:
if 'file_image' in kwargs:
file_image_path = kwargs['file_image'].replace("file:///", "") file_image_path = kwargs['file_image'].replace("file:///", "")
if file_image_path: if file_image_path:
@@ -317,50 +333,66 @@ class QQOfficial(Platform):
logger.debug(f"上传图片: {file_image_path}") logger.debug(f"上传图片: {file_image_path}")
image_url = await self.context.image_uploader.upload_image(file_image_path) image_url = await self.context.image_uploader.upload_image(file_image_path)
logger.debug(f"上传成功: {image_url}") logger.debug(f"上传成功: {image_url}")
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) if 'group_openid' in kwargs:
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
elif 'openid' in kwargs:
media = await self.client.api.post_c2c_file(kwargs['openid'], 1, image_url)
del kwargs['file_image'] del kwargs['file_image']
kwargs['media'] = media kwargs['media'] = media
logger.debug(f"发送群图片: {media}") logger.debug(f"发送群图片: {media}")
kwargs['msg_type'] = 7 # 富媒体 kwargs['msg_type'] = 7 # 富媒体
await self.client.api.post_group_message(**kwargs) if self.test_mode:
return kwargs
if 'group_openid' in kwargs:
await self.client.api.post_group_message(**kwargs)
elif 'openid' in kwargs:
await self.client.api.post_c2c_message(**kwargs)
elif 'channel_id' in kwargs: elif 'channel_id' in kwargs:
# 频道消息 # 频道消息
if 'file_image' in kwargs: if 'file_image' in kwargs and kwargs['file_image']:
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
# 频道消息发图只支持本地 # 频道消息发图只支持本地
if kwargs['file_image'].startswith("http"): if kwargs['file_image'].startswith("http"):
kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
if self.test_mode:
return kwargs
await self.client.api.post_message(**kwargs) await self.client.api.post_message(**kwargs)
else: elif 'guild_id' in kwargs:
# 频道私聊消息 # 频道私聊消息
if 'file_image' in kwargs: if 'file_image' in kwargs and kwargs['file_image']:
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
if kwargs['file_image'].startswith("http"): if kwargs['file_image'].startswith("http"):
kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
if self.test_mode:
return kwargs
await self.client.api.post_dms(**kwargs) await self.client.api.post_dms(**kwargs)
else:
raise ValueError("Unknown target type.")
async def send_msg(self, target: Dict[str, str], result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Dict[str, str], result_message: CommandResult):
''' '''
以主动的方式给用户、群或者频道发送一条消息。 以主动的方式给频道用户、群、频道或者消息列表用户(QQ用户)发送一条消息。
`target` 接收一个 dict 类型的值引用。 `target` 接收一个 dict 类型的值引用。
- 如果目标是 QQ 群,请添加 key `group_openid`。 - 如果目标是 QQ 群,请添加 key `group_openid`。
- 如果目标是 频道消息,请添加 key `channel_id`。 - 如果目标是 频道消息,请添加 key `channel_id`。
- 如果目标是 频道私聊,请添加 key `guild_id`。 - 如果目标是 频道私聊,请添加 key `guild_id`。
- 如果目标是 QQ 用户,请添加 key `openid`。
''' '''
if isinstance(result_message, list): plain_text, image_path = await self._parse_to_qqofficial(result_message.message_chain)
plain_text, image_path = await self._parse_to_qqofficial(result_message)
else:
plain_text = result_message
payload = { payload = {
'content': plain_text, 'content': plain_text,
'file_image': image_path,
**target **target
} }
if image_path:
payload['file_image'] = image_path
await self._reply(**payload) await self._reply(**payload)
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
raise NotImplementedError("qqofficial 不支持此方法。")
def wait_for_message(self, channel_id: int) -> AstrBotMessage: def wait_for_message(self, channel_id: int) -> AstrBotMessage:
''' '''
等待指定 channel_id 的下一条信息,超时 300s 后抛出异常 等待指定 channel_id 的下一条信息,超时 300s 后抛出异常
+4 -2
View File
@@ -13,13 +13,15 @@ class CommandRegisterRequest():
description: str description: str
priority: int priority: int
handler: Callable handler: Callable
use_regex: bool = False
plugin_name: str = None plugin_name: str = None
ignore_prefix: bool = False
class PluginCommandBridge(): class PluginCommandBridge():
def __init__(self, cached_plugins: RegisteredPlugins): def __init__(self, cached_plugins: RegisteredPlugins):
self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
self.cached_plugins = cached_plugins self.cached_plugins = cached_plugins
def register_command(self, plugin_name, command_name, description, priority, handler): def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, plugin_name)) self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
+7 -3
View File
@@ -123,7 +123,7 @@ class PluginManager():
return p return p
def uninstall_plugin(self, plugin_name: str): def uninstall_plugin(self, plugin_name: str):
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins) plugin = self.get_registered_plugin(plugin_name)
if not plugin: if not plugin:
raise Exception("插件不存在。") raise Exception("插件不存在。")
root_dir_name = plugin.root_dir_name root_dir_name = plugin.root_dir_name
@@ -133,7 +133,7 @@ class PluginManager():
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
def update_plugin(self, plugin_name: str): def update_plugin(self, plugin_name: str):
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins) plugin = self.get_registered_plugin(plugin_name)
if not plugin: if not plugin:
raise Exception("插件不存在。") raise Exception("插件不存在。")
@@ -156,6 +156,8 @@ class PluginManager():
module_path = plugin['module_path'] module_path = plugin['module_path']
root_dir_name = plugin['pname'] root_dir_name = plugin['pname']
logger.info(f"正在加载插件 {root_dir_name} ...")
# self.check_plugin_dept_update(cached_plugins, root_dir_name) # self.check_plugin_dept_update(cached_plugins, root_dir_name)
module = __import__("addons.plugins." + module = __import__("addons.plugins." +
@@ -166,8 +168,10 @@ class PluginManager():
try: try:
# 尝试传入 ctx # 尝试传入 ctx
obj = getattr(module, cls[0])(context=self.context) obj = getattr(module, cls[0])(context=self.context)
except: except TypeError:
obj = getattr(module, cls[0])() obj = getattr(module, cls[0])()
except BaseException as e:
raise e
metadata = None metadata = None
+5 -3
View File
@@ -53,7 +53,7 @@ class ProviderOpenAIOfficial(Provider):
os.makedirs("data/openai", exist_ok=True) os.makedirs("data/openai", exist_ok=True)
self.cc = CmdConfig self.context = context
self.key_data_path = "data/openai/keys.json" self.key_data_path = "data/openai/keys.json"
self.api_keys = [] self.api_keys = []
self.chosen_api_key = None self.chosen_api_key = None
@@ -78,7 +78,7 @@ class ProviderOpenAIOfficial(Provider):
) )
self.model_configs: Dict = cfg['chatGPTConfigs'] self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model']) super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None) self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock() self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
@@ -386,6 +386,8 @@ class ProviderOpenAIOfficial(Provider):
assert isinstance(completion, ChatCompletion) assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}") logger.debug(f"openai completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("OpenAI API 返回的 completion 为空。")
choice = completion.choices[0] choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens usage_tokens = completion.usage.total_tokens
@@ -492,7 +494,7 @@ class ProviderOpenAIOfficial(Provider):
def set_model(self, model: str): def set_model(self, model: str):
self.model_configs['model'] = model self.model_configs['model'] = model
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model) self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
super().set_curr_model(model) super().set_curr_model(model)
def get_configs(self): def get_configs(self):
+13 -11
View File
@@ -2,7 +2,6 @@ from typing import Union, List, Callable
from dataclasses import dataclass from dataclasses import dataclass
from nakuru.entities.components import Plain, Image from nakuru.entities.components import Plain, Image
@dataclass @dataclass
class CommandItem(): class CommandItem():
''' '''
@@ -19,12 +18,17 @@ class CommandResult():
用于在Command中返回多个值 用于在Command中返回多个值
''' '''
def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None: def __init__(self,
hit: bool = True,
success: bool = True,
message_chain: list = [],
command_name: str = "unknown_command",
use_t2i: bool = None) -> None:
self.hit = hit self.hit = hit
self.success = success self.success = success
self.message_chain = message_chain self.message_chain = message_chain
self.command_name = command_name self.command_name = command_name
self.is_use_t2i = None # default self.is_use_t2i = use_t2i
def message(self, message: str): def message(self, message: str):
''' '''
@@ -63,14 +67,12 @@ class CommandResult():
self.message_chain = [Image.fromFileSystem(path), ] self.message_chain = [Image.fromFileSystem(path), ]
return self return self
# def use_t2i(self, use_t2i: bool): def use_t2i(self, use_t2i: bool):
# ''' '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
'''
# CommandResult().use_t2i(False) self.is_use_t2i = use_t2i
# ''' return self
# self.is_use_t2i = use_t2i
# return self
def _result_tuple(self): def _result_tuple(self):
return (self.success, self.message_chain, self.command_name) return (self.success, self.message_chain, self.command_name)
+75 -1
View File
@@ -1 +1,75 @@
VERSION = '3.3.2' VERSION = '3.3.8'
DEFAULT_CONFIG = {
"qqbot": {
"enable": False,
"appid": "",
"token": "",
},
"gocqbot": {
"enable": False,
},
"uniqueSessionMode": False,
"dump_history_interval": 10,
"limit": {
"time": 60,
"count": 30,
},
"notice": "",
"direct_message_mode": True,
"reply_prefix": "",
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
},
"openai": {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
},
"qq_forward_threshold": 200,
"qq_welcome": "",
"qq_pic_mode": True,
"gocq_host": "127.0.0.1",
"gocq_http_port": 5700,
"gocq_websocket_port": 6700,
"gocq_react_group": True,
"gocq_react_guild": True,
"gocq_react_friend": True,
"gocq_react_group_increase": True,
"other_admins": [],
"CHATGPT_BASE_URL": "",
"qqbot_secret": "",
"qqofficial_enable_group_message": False,
"admin_qq": "",
"nick_qq": ["/", "!"],
"admin_qqchan": "",
"llm_env_prompt": "",
"llm_wake_prefix": "",
"default_personality_str": "",
"openai_image_generate": {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
},
"http_proxy": "",
"https_proxy": "",
"dashboard_username": "",
"dashboard_password": "",
"aiocqhttp": {
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
}
+21 -10
View File
@@ -2,7 +2,14 @@ from typing import List, Union, Optional
from dataclasses import dataclass from dataclasses import dataclass
from type.register import RegisteredPlatform from type.register import RegisteredPlatform
from type.types import Context from type.types import Context
from type.astrbot_message import AstrBotMessage from type.astrbot_message import AstrBotMessage, MessageType
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
class AstrMessageEvent(): class AstrMessageEvent():
@@ -12,7 +19,9 @@ class AstrMessageEvent():
platform: RegisteredPlatform, platform: RegisteredPlatform,
role: str, role: str,
context: Context, context: Context,
session_id: str = None): session_id: str = None,
unified_msg_origin: str = None,
only_command: bool = False):
''' '''
AstrBot 消息事件。 AstrBot 消息事件。
@@ -22,6 +31,8 @@ class AstrMessageEvent():
`role`: 角色,`admin` or `member` `role`: 角色,`admin` or `member`
`context`: 全局对象 `context`: 全局对象
`session_id`: 会话id `session_id`: 会话id
`unified_msg_origin`: 统一消息来源
`only_command`: 是否只处理指令,而不使用 LLM 回复
''' '''
self.context = context self.context = context
self.message_str = message_str self.message_str = message_str
@@ -29,24 +40,24 @@ class AstrMessageEvent():
self.platform = platform self.platform = platform
self.role = role self.role = role
self.session_id = session_id self.session_id = session_id
self.unified_msg_origin = unified_msg_origin
self.only_command = only_command
def from_astrbot_message(message: AstrBotMessage, def from_astrbot_message(message: AstrBotMessage,
context: Context, context: Context,
platform_name: str, platform_name: str,
session_id: str, session_id: str,
role: str = "member"): role: str = "member",
unified_msg_origin: str = None,
only_command: bool = False):
ame = AstrMessageEvent(message.message_str, ame = AstrMessageEvent(message.message_str,
message, message,
context.find_platform(platform_name), context.find_platform(platform_name),
role, role,
context, context,
session_id) session_id,
unified_msg_origin,
only_command=only_command)
return ame return ame
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
+46 -12
View File
@@ -8,6 +8,8 @@ from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator from util.updator.plugin_updator import PluginUpdator
from type.command import CommandResult
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider from model.provider.provider import Provider
@@ -28,37 +30,54 @@ class Context:
self.unique_session = False # 独立会话 self.unique_session = False # 独立会话
self.version: str = None # 机器人版本 self.version: str = None # 机器人版本
self.nick = None # gocq 的唤醒词 self.nick: tuple = None # gocq 的唤醒词
self.stat = {}
self.t2i_mode = False self.t2i_mode = False
self.web_search = False # 是否开启了网页搜索 self.web_search = False # 是否开启了网页搜索
self.reply_prefix = ""
self.metrics_uploader = None
self.updator: AstrBotUpdator = None self.updator: AstrBotUpdator = None
self.plugin_updator: PluginUpdator = None self.plugin_updator: PluginUpdator = None
self.metrics_uploader = None
self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
self.image_renderer = TextToImageRenderer() self.image_renderer = TextToImageRenderer()
self.image_uploader = ImageUploader() self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = [] self.ext_tasks: List[Task] = []
self.command_manager = None
# useless
self.reply_prefix = ""
def register_commands(self, def register_commands(self,
plugin_name: str, plugin_name: str,
command_name: str, command_name: str,
description: str, description: str,
priority: int, priority: int,
handler: callable): handler: callable,
use_regex: bool = False,
ignore_prefix: bool = False):
''' '''
注册插件指令。 注册插件指令。
`plugin_name`: 插件名,注意需要和你的 metadata 中的一致。 @param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
`command_name`: 指令名,如 "help"。不需要带前缀。 @param command_name: 指令名,如 "help"。不需要带前缀。
`description`: 指令描述。 @param description: 指令描述。
`priority`: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
`handler`: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
.. Example::
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
''' '''
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler) self.plugin_command_bridge.register_command(plugin_name,
command_name,
description,
priority,
handler,
use_regex,
ignore_prefix)
def register_task(self, coro: Awaitable, task_name: str): def register_task(self, coro: Awaitable, task_name: str):
''' '''
@@ -84,3 +103,18 @@ class Context:
return platform return platform
raise ValueError("couldn't find the platform you specified") raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
发送消息。
`unified_msg_origin`: 统一消息来源
`message`: 消息内容
'''
l = unified_msg_origin.split(":")
if len(l) != 3:
raise ValueError("Invalid unified_msg_origin")
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
+38 -29
View File
@@ -1,19 +1,31 @@
import os import os
import json import json
from typing import Union from type.config import DEFAULT_CONFIG
cpath = "data/cmd_config.json" cpath = "data/cmd_config.json"
def check_exist(): def check_exist():
if not os.path.exists(cpath): if not os.path.exists(cpath):
with open(cpath, "w", encoding="utf-8-sig") as f: with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump({}, f, indent=4, ensure_ascii=False) json.dump({}, f, ensure_ascii=False)
f.flush() f.flush()
class CmdConfig(): class CmdConfig():
def __init__(self) -> None:
self.cached_config: dict = {}
self.init_configs()
def init_configs(self):
'''
初始化必需的配置项
'''
self.init_config_items(DEFAULT_CONFIG)
@staticmethod @staticmethod
def get(key, default=None): def get(key, default=None):
'''
从文件系统中直接获取配置
'''
check_exist() check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f: with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f) d = json.load(f)
@@ -22,28 +34,33 @@ class CmdConfig():
else: else:
return default return default
@staticmethod def get_all(self):
def get_all(): '''
从文件系统中获取所有配置
'''
check_exist() check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f: with open(cpath, "r", encoding="utf-8-sig") as f:
return json.load(f) conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
return conf
@staticmethod def put(self, key, value):
def put(key, value):
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f: with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f) d = json.load(f)
d[key] = value d[key] = value
with open(cpath, "w", encoding="utf-8-sig") as f: with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False) json.dump(d, f, indent=2, ensure_ascii=False)
f.flush() f.flush()
self.cached_config[key] = value
@staticmethod @staticmethod
def put_by_dot_str(key: str, value): def put_by_dot_str(key: str, value):
''' '''
根据点分割的字符串将值写入配置文件 根据点分割的字符串将值写入配置文件
''' '''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f: with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f) d = json.load(f)
_d = d _d = d
@@ -54,30 +71,22 @@ class CmdConfig():
else: else:
_d = _d[_ks[i]] _d = _d[_ks[i]]
with open(cpath, "w", encoding="utf-8-sig") as f: with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False) json.dump(d, f, indent=2, ensure_ascii=False)
f.flush() f.flush()
@staticmethod def init_config_items(self, d: dict):
def init_attributes(key: Union[str, list], init_val=""): conf = self.get_all()
check_exist()
conf_str = '' if not self.cached_config:
with open(cpath, "r", encoding="utf-8-sig") as f: self.cached_config = conf
conf_str = f.read()
if conf_str.startswith(u'/ufeff'):
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
d = json.loads(conf_str)
_tag = False _tag = False
if isinstance(key, str): for key, val in d.items():
if key not in d: if key not in conf:
d[key] = init_val conf[key] = val
_tag = True _tag = True
elif isinstance(key, list):
for k in key:
if k not in d:
d[k] = init_val
_tag = True
if _tag: if _tag:
with open(cpath, "w", encoding="utf-8-sig") as f: with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False) json.dump(conf, f, indent=2, ensure_ascii=False)
f.flush() f.flush()
-131
View File
@@ -1,89 +1,5 @@
import json, os import json, os
from util.cmd_config import CmdConfig from util.cmd_config import CmdConfig
from type.config import VERSION
from type.types import Context
def init_configs():
'''
初始化必需的配置项
'''
cc = CmdConfig()
cc.init_attributes("qqbot", {
"enable": False,
"appid": "",
"token": "",
})
cc.init_attributes("gocqbot", {
"enable": False,
})
cc.init_attributes("uniqueSessionMode", False)
cc.init_attributes("dump_history_interval", 10)
cc.init_attributes("limit", {
"time": 60,
"count": 30,
})
cc.init_attributes("notice", "")
cc.init_attributes("direct_message_mode", True)
cc.init_attributes("reply_prefix", "")
cc.init_attributes("baidu_aip", {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
})
cc.init_attributes("openai", {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
})
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", True)
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
cc.init_attributes("gocq_websocket_port", 6700)
cc.init_attributes("gocq_react_group", True)
cc.init_attributes("gocq_react_guild", True)
cc.init_attributes("gocq_react_friend", True)
cc.init_attributes("gocq_react_group_increase", True)
cc.init_attributes("other_admins", [])
cc.init_attributes("CHATGPT_BASE_URL", "")
cc.init_attributes("qqbot_secret", "")
cc.init_attributes("qqofficial_enable_group_message", False)
cc.init_attributes("admin_qq", "")
cc.init_attributes("nick_qq", ["!", "", "ai"])
cc.init_attributes("admin_qqchan", "")
cc.init_attributes("llm_env_prompt", "")
cc.init_attributes("llm_wake_prefix", "")
cc.init_attributes("default_personality_str", "")
cc.init_attributes("openai_image_generate", {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
})
cc.init_attributes("http_proxy", "")
cc.init_attributes("https_proxy", "")
cc.init_attributes("dashboard_username", "")
cc.init_attributes("dashboard_password", "")
# aiocqhttp 适配器
cc.init_attributes("aiocqhttp", {
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 0,
})
def try_migrate_config(): def try_migrate_config():
''' '''
@@ -98,50 +14,3 @@ def try_migrate_config():
os.remove("cmd_config.json") os.remove("cmd_config.json")
except Exception as e: except Exception as e:
pass pass
def inject_to_context(context: Context):
'''
将配置注入到 Context
this method returns all the configs
'''
cc = CmdConfig()
context.version = VERSION
context.base_config = cc.get_all()
cfg = context.base_config
if 'reply_prefix' in cfg:
# 适配旧版配置
if isinstance(cfg['reply_prefix'], dict):
context.reply_prefix = ""
cfg['reply_prefix'] = ""
cc.put("reply_prefix", "")
else:
context.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
context.default_personality = None
else:
context.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
context.unique_session = True
else:
context.unique_session = False
nick_qq = cc.get("nick_qq", None)
if nick_qq == None:
nick_qq = ("/", )
if isinstance(nick_qq, str):
nick_qq = (nick_qq, )
if isinstance(nick_qq, list):
nick_qq = tuple(nick_qq)
context.nick = nick_qq
context.t2i_mode = cc.get("qq_pic_mode", True)
return cfg
+8 -1
View File
@@ -36,7 +36,14 @@ class MetricUploader():
for plugin in context.cached_plugins: for plugin in context.cached_plugins:
self.plugin_stats[plugin.metadata.plugin_name] = { self.plugin_stats[plugin.metadata.plugin_name] = {
"metadata": plugin.metadata "metadata": {
"plugin_name": plugin.metadata.plugin_name,
"plugin_type": plugin.metadata.plugin_type.value,
"author": plugin.metadata.author,
"desc": plugin.metadata.desc,
"version": plugin.metadata.version,
"repo": plugin.metadata.repo,
}
} }
try: try:
+1
View File
@@ -9,3 +9,4 @@ from model.platform import Platform
from model.platform.qq_nakuru import QQGOCQ from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_official import QQOfficial from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP
+7 -2
View File
@@ -9,7 +9,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotUpdator(RepoZipUpdator): class AstrBotUpdator(RepoZipUpdator):
def __init__(self): def __init__(self):
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases" self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
def terminate_child_processes(self): def terminate_child_processes(self):
@@ -34,7 +34,12 @@ class AstrBotUpdator(RepoZipUpdator):
if delay: time.sleep(delay) if delay: time.sleep(delay)
py = sys.executable py = sys.executable
self.terminate_child_processes() self.terminate_child_processes()
os.execl(py, py, *sys.argv) py = py.replace(" ", "\\ ")
try:
os.execl(py, py, *sys.argv)
except Exception as e:
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
raise e
def check_update(self, url: str, current_version: str) -> ReleaseInfo: def check_update(self, url: str, current_version: str) -> ReleaseInfo:
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION) return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)