diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..e9113832b --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +comment: + # add "condensed_" to "header", "files" and "footer" + layout: "condensed_header, condensed_files, condensed_footer" + hide_project_coverage: TRUE # set to true \ No newline at end of file diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..1385093f4 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + */site-packages/* + */dist-packages/* + your_package_name/tests/* \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..f27cf068c --- /dev/null +++ b/.dockerignore @@ -0,0 +1,18 @@ +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 +# github acions +.github/ +.*ignore +.git/ +# User-specific stuff +.idea/ +# Byte-compiled / optimized / DLL files +__pycache__/ +# Environments +.env +.venv +env/ +venv*/ +ENV/ +.conda/ +README*.md diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 000000000..e7a5263c2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,82 @@ +name: '🐛 报告 Bug' +title: '[Bug]' +description: 提交报告帮助我们改进。 +labels: [ 'bug' ] +body: + - type: markdown + attributes: + value: | + 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。 + - type: textarea + attributes: + label: 发生了什么 + description: 描述你遇到的异常 + placeholder: > + 一个清晰且具体的描述这个异常是什么。 + validations: + required: true + + - type: textarea + attributes: + label: 如何复现? + description: > + 复现该问题的步骤 + placeholder: > + 如: 1. 打开 '...' + validations: + required: true + + - type: textarea + attributes: + label: AstrBot 版本与部署方式 + description: > + 请提供您的 AstrBot 版本和部署方式。 + placeholder: > + 如: 3.1.8 Docker, 3.1.7 Windows启动器 + validations: + required: true + + - type: dropdown + attributes: + label: 操作系统 + description: | + 你在哪个操作系统上遇到了这个问题? + multiple: false + options: + - 'Windows' + - 'macOS' + - 'Linux' + - 'Other' + - 'Not sure' + validations: + required: true + + - type: textarea + attributes: + label: 额外信息 + description: > + 任何额外信息,如报错日志、截图等。 + placeholder: > + 请提供完整的报错日志或截图。 + validations: + required: true + + - type: checkboxes + attributes: + label: 你愿意提交 PR 吗? + description: > + 这绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。 + options: + - label: 是的,我愿意提交 PR! + + - type: checkboxes + attributes: + label: Code of Conduct + options: + - label: > + 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。 + required: true + + - type: markdown + attributes: + value: "感谢您填写我们的表单!" \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 000000000..484959318 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,42 @@ + +name: '🎉 功能建议' +title: "[Feature]" +description: 提交建议帮助我们改进。 +labels: [ "enhancement" ] +body: + - type: markdown + attributes: + value: | + 感谢您抽出时间提出新功能建议,请准确解释您的想法。 + + - type: textarea + attributes: + label: 描述 + description: 简短描述您的功能建议。 + + - type: textarea + attributes: + label: 使用场景 + description: 你想要发生什么? + placeholder: > + 一个清晰且具体的描述这个功能的使用场景。 + + - type: checkboxes + attributes: + label: 你愿意提交PR吗? + description: > + 这不是必须的,但我们欢迎您的贡献。 + options: + - label: 是的, 我愿意提交PR! + + - type: checkboxes + attributes: + label: Code of Conduct + options: + - label: > + 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。 + required: true + + - type: markdown + attributes: + value: "感谢您填写我们的表单!" \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..da603d465 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,10 @@ + +修复了 #XYZ + +### Motivation + + + +### Modifications + + diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..8503bb715 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,93 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + schedule: + - cron: '21 15 * * 5' + +jobs: + analyze: + name: Analyze (${{ matrix.language }}) + # Runner size impacts CodeQL analysis time. To learn more, please see: + # - https://gh.io/recommended-hardware-resources-for-running-codeql + # - https://gh.io/supported-runners-and-hardware-resources + # - https://gh.io/using-larger-runners (GitHub.com only) + # Consider using larger runners or machines with greater resources for possible analysis time improvements. + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + # required for all workflows + security-events: write + + # required to fetch internal or private CodeQL packs + packages: read + + # only required for workflows in private repositories + actions: read + contents: read + + strategy: + fail-fast: false + matrix: + include: + - language: python + build-mode: none + # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' + # Use `c-cpp` to analyze code written in C, C++ or both + # Use 'java-kotlin' to analyze code written in Java, Kotlin or both + # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both + # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, + # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. + # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how + # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + build-mode: ${{ matrix.build-mode }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + # If the analyze step fails for one of the languages you are analyzing with + # "We were unable to automatically build your code", modify the matrix above + # to set the build mode to "manual" for that language. Then modify this step + # to build your code. + # ℹ️ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + - if: matrix.build-mode == 'manual' + shell: bash + run: | + echo 'If you are using a "manual" build mode for one or more of the' \ + 'languages you are analyzing, replace this with the commands to build' \ + 'your code, for example:' + echo ' make bootstrap' + echo ' make release' + exit 1 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml new file mode 100644 index 000000000..a021daa7c --- /dev/null +++ b/.github/workflows/coverage_test.yml @@ -0,0 +1,34 @@ +name: Run tests and upload coverage + +on: + push + +jobs: + test: + name: Run tests and collect coverage + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest pytest-cov pytest-asyncio + mkdir data + mkdir data/config + mkdir temp + + - name: Run tests + run: PYTHONPATH=./ pytest --cov=. tests/ -v + + - name: Upload results to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 1828268bf..e374cfde8 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -4,20 +4,39 @@ on: release: types: [published] workflow_dispatch: + jobs: - publish-latest-docker-image: + publish-docker: runs-on: ubuntu-latest - name: Build and publish docker image steps: - - name: Checkout - uses: actions/checkout@v2 - - name: Build image - run: | - git clone https://github.com/Soulter/AstrBot - cd AstrBot - docker build -t ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest . - - name: Publish image - run: | - docker login -u ${{ secrets.DOCKER_HUB_USERNAME }} -p ${{ secrets.DOCKER_HUB_PASSWORD }} - docker push ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest + - name: 拉取源码 + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 设置 QEMU + uses: docker/setup-qemu-action@v3 + + - name: 设置 Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: 登录到 DockerHub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_HUB_USERNAME }} + password: ${{ secrets.DOCKER_HUB_PASSWORD }} + + - name: 构建和推送 Docker hub + uses: docker/build-push-action@v6 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: true + tags: | + ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest + ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }} + + - name: Post build notifications + run: echo "Docker image has been built and pushed successfully" + diff --git a/Dockerfile b/Dockerfile index 93c0a2914..055d37bae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,18 @@ WORKDIR /AstrBot COPY . /AstrBot/ +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + build-essential \ + python3-dev \ + libffi-dev \ + libssl-dev \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + RUN python -m pip install -r requirements.txt +EXPOSE 6185 +EXPOSE 6186 + CMD [ "python", "main.py" ] diff --git a/README.md b/README.md index 130e24633..9cd432c9c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

-image +image

@@ -8,6 +8,7 @@ [![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest) python Docker pull +[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot) Static Badge @@ -21,27 +22,42 @@ 🌍 支持的消息平台 - QQ 群、QQ 频道(OneBot、QQ 官方接口) -- Telegram(由 [astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持) -- WeChat(微信) (由 [astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件支持) +- Telegram([astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件) -🌍 支持的大模型一览: +🌍 支持的大模型/底座: - OpenAI GPT、DallE 系列 - Claude(由[LLMs插件](https://github.com/Soulter/llms)支持) - HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持) - Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持) +- Ollama +- 几乎所有已知模型(可接入 [OneAPI](https://astrbot.soulter.top/docs/docs/adavanced/one-api)) 🌍 机器人支持的能力一览: - 大模型对话、人格、网页搜索 -- 可视化管理面板 +- 可视化仪表盘 - 同时处理多平台消息 - 精确到个人的会话隔离 - 插件支持 - 文本转图片回复(Markdown) -## 🧩 插件支持 +## 🧩 插件 -有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6) +有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin) + +## 云部署 + +[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot) + +## ❤️ 贡献 + +欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :) + +对于新功能的添加,请先通过 Issue 进行讨论。 + +## 🔭 展望 + +- [ ] 更多、更开放的 LLM Agent 能力 ## ✨ Demo diff --git a/addons/plugins/helloworld/main.py b/addons/plugins/helloworld/main.py index 56eb99162..25f8e6974 100644 --- a/addons/plugins/helloworld/main.py +++ b/addons/plugins/helloworld/main.py @@ -21,6 +21,14 @@ class HelloWorldPlugin: def __init__(self, context: Context) -> None: self.context = context self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld) + self.context.register_llm_tool("welcome_somebody", [{ + "type": "string", + "name": "name", + "description": "要欢迎的人的名字" + }], "给一个用户发送欢迎文本。", self.welcome_somebody) + + async def welcome_somebody(self, name: str): + return CommandResult().message(f"欢迎{name}!") """ 指令处理函数。 diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 81787c904..300e07b8d 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class AstrBotBootstrap(): - def __init__(self) -> None: + def __init__(self) -> None: self.context = Context() # load configs and ensure the backward compatibility @@ -43,6 +43,8 @@ class AstrBotBootstrap(): logger.info(f"使用代理: {http_proxy}, {https_proxy}") else: logger.info("未使用代理。") + + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' async def run(self): self.command_manager = CommandManager() @@ -63,6 +65,10 @@ class AstrBotBootstrap(): self.context.updator = self.updator self.context.plugin_updator = self.plugin_manager.updator self.context.message_handler = self.message_handler + self.context.command_manager = self.command_manager + + if self.test_mode: + return # load plugins, plugins' commands. self.load_plugins() @@ -84,10 +90,13 @@ class AstrBotBootstrap(): try: result = await task return result + except asyncio.CancelledError: + logger.info(f"{task.get_name()} 任务已取消。") + return except Exception as e: logger.error(traceback.format_exc()) - logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。") - await asyncio.sleep(5) + logger.error(f"{task.get_name()} 任务发生错误。") + return def load_llm(self): f = False diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 45c8cff4a..501960825 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -1,5 +1,5 @@ -import time -import re +import time, json +import re, os import asyncio import traceback import astrbot.message.unfit_words as uw @@ -14,7 +14,10 @@ from type.command import CommandResult from SparkleLogging.utils.core import LogManager from logging import Logger from nakuru.entities.components import Image +from util.agent.func_call import FuncCall import util.agent.web_searcher as web_searcher +from openai._exceptions import * +from openai.types.chat.chat_completion_message_tool_call import Function logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -109,8 +112,9 @@ class MessageHandler(): self.llm_wake_prefix = self.llm_wake_prefix.strip() self.nicks = self.context.config_helper.wake_prefix self.provider = self.context.llms[0] if len(self.context.llms) > 0 else None - self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix) - + self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix) + self.llm_tools = FuncCall(self.provider) + def set_provider(self, provider: Provider): self.provider = provider @@ -121,18 +125,19 @@ class MessageHandler(): `llm_provider`: the provider to use for LLM. If None, use the default provider ''' msg_plain = message.message_str.strip() - provider = llm_provider if llm_provider else self.provider - inner_provider = False if llm_provider else True + provider = llm_provider if llm_provider else self.provider - self.persist_manager.record_message(message.platform.platform_name, message.session_id) + if os.environ.get('TEST_MODE', 'off') != 'on': + self.persist_manager.record_message(message.platform.platform_name, message.session_id) # TODO: this should be configurable - if not message.message_str: - return MessageResult("Hi~") + # if not message.message_str: + # return MessageResult("Hi~") # check the rate limit if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): - return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。') + logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。") + return # remove the nick prefix for nick in self.nicks: @@ -151,6 +156,11 @@ class MessageHandler(): use_t2i=cmd_res.is_use_t2i ) + # next is the LLM part + + if message.only_command: + return + # check if the message is a llm-wake-up command if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix): logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。") @@ -169,31 +179,95 @@ class MessageHandler(): if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file break - web_search = self.context.config_helper.llm_settings.web_search - if not web_search and msg_plain.startswith("ws"): - # leverage web search feature - web_search = True - msg_plain = msg_plain.removeprefix("ws").strip() - try: - if web_search: - llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) + if not self.llm_tools.empty(): + # tools-use + tool_use_flag = True + llm_result = await provider.text_chat( + prompt=msg_plain, + session_id=message.session_id, + tools=self.llm_tools.get_func() + ) + + if isinstance(llm_result, Function): + logger.debug(f"function-calling: {llm_result}") + func_obj = None + for i in self.llm_tools.func_list: + if i["name"] == llm_result.name: + func_obj = i["func_obj"] + break + if not func_obj: + return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。") + try: + args = json.loads(llm_result.arguments) + args['ame'] = message + args['context'] = self.context + try: + cmd_res = await func_obj(**args) + except TypeError as e: + args.pop('ame') + args.pop('context') + cmd_res = await func_obj(**args) + if isinstance(cmd_res, CommandResult): + return MessageResult( + cmd_res.message_chain, + is_command_call=True, + use_t2i=cmd_res.is_use_t2i + ) + elif isinstance(cmd_res, str): + return MessageResult(cmd_res) + elif not cmd_res: + return + else: + return MessageResult(f"AstrBot Function-calling 异常:调用:{llm_result} 时,返回了未知的返回值类型。") + except BaseException as e: + traceback.print_exc() + return MessageResult("AstrBot Function-calling 异常:" + str(e)) + else: + return MessageResult(llm_result) + else: + # normal chat + tool_use_flag = False llm_result = await provider.text_chat( prompt=msg_plain, session_id=message.session_id, image_url=image_url ) + except BadRequestError as e: + if tool_use_flag: + # seems like the model don't support function-calling + logger.error(f"error: {e}. Using local function-calling implementation") + + try: + # use local function-calling implementation + args = { + 'question': llm_result, + 'func_definition': self.llm_tools.func_dump(), + } + _, has_func = await self.llm_tools.func_call(**args) + + if not has_func: + # normal chat + llm_result = await provider.text_chat( + prompt=msg_plain, + session_id=message.session_id, + image_url=image_url + ) + except BaseException as e: + logger.error(traceback.format_exc()) + return CommandResult("AstrBot Function-calling 异常:" + str(e)) + except BaseException as e: logger.error(traceback.format_exc()) logger.error(f"LLM 调用失败。") return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e)) - - # concatenate the reply prefix + + # concatenate reply prefix if self.reply_prefix: llm_result = self.reply_prefix + llm_result - # mask the unsafe content + # mask unsafe content llm_result = self.content_safety_helper.filter_content(llm_result) check = self.content_safety_helper.baidu_check(llm_result) if not check: diff --git a/dashboard/server.py b/dashboard/server.py index f507d54cf..f1675b1b2 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -207,10 +207,11 @@ class AstrBotDashBoard(): try: logger.info(f"正在安装插件 {repo_url}") self.plugin_manager.install_plugin(repo_url) - logger.info(f"安装插件 {repo_url} 成功") + threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() + logger.info(f"安装插件 {repo_url} 成功,2秒后重启") return Response( status="success", - message="安装成功~", + message="安装成功,机器人将在 2 秒内重启。", data=None ).__dict__ except Exception as e: @@ -273,10 +274,11 @@ class AstrBotDashBoard(): try: logger.info(f"正在更新插件 {plugin_name}") self.plugin_manager.update_plugin(plugin_name) - logger.info(f"更新插件 {plugin_name} 成功") + threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() + logger.info(f"更新插件 {plugin_name} 成功,2秒后重启") return Response( status="success", - message="更新成功~", + message="更新成功,机器人将在 2 秒内重启。", data=None ).__dict__ except Exception as e: @@ -326,7 +328,7 @@ class AstrBotDashBoard(): latest = False try: self.astrbot_updator.update(latest=latest, version=version) - threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start() + threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() return Response( status="success", message="更新成功,机器人将在 3 秒内重启。", diff --git a/main.py b/main.py index 9803ffe0d..3eead7548 100644 --- a/main.py +++ b/main.py @@ -53,7 +53,7 @@ if __name__ == "__main__": check_env() logger = LogManager.GetLogger( - log_name='astrbot', + log_name='astrbot', out_to_console=True, custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") ) diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index a8b97de25..8245f984f 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -9,6 +9,7 @@ from type.config import VERSION from SparkleLogging.utils.core import LogManager from logging import Logger from nakuru.entities.components import Image +from util.agent.web_searcher import search_from_bing, fetch_website_content logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -116,11 +117,11 @@ class InternalCommandHandler: success=False, message_chain="你没有权限使用该指令", ) - context.updator._reboot(5) + context.updator._reboot(3, context) return CommandResult( hit=True, success=True, - message_chain="AstrBot 将在 5s 后重启。", + message_chain="AstrBot 将在 3s 后重启。", ) def plugin(self, message: AstrMessageEvent, context: Context): @@ -211,6 +212,23 @@ class InternalCommandHandler: ) elif l[1] == 'on': context.web_search = True + context.register_llm_tool("web_search", [{ + "type": "string", + "name": "keyword", + "description": "搜索关键词" + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + search_from_bing + ) + context.register_llm_tool("fetch_website_content", [{ + "type": "string", + "name": "url", + "description": "要获取内容的网页链接" + }], + "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content + ) + return CommandResult( hit=True, success=True, @@ -218,6 +236,9 @@ class InternalCommandHandler: ) elif l[1] == 'off': context.web_search = False + context.unregister_llm_tool("web_search") + context.unregister_llm_tool("fetch_website_content") + return CommandResult( hit=True, success=True, diff --git a/model/command/manager.py b/model/command/manager.py index 3f7608c71..f6e90290c 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -21,6 +21,7 @@ class CommandMetadata(): plugin_metadata: PluginMetadata handler: callable use_regex: bool = False + ignore_prefix: bool = False description: str = "" class CommandManager(): @@ -35,6 +36,7 @@ class CommandManager(): priority: int, handler: callable, use_regex: bool = False, + ignore_prefix: bool = False, plugin_metadata: PluginMetadata = None, ): ''' @@ -53,6 +55,7 @@ class CommandManager(): plugin_metadata=plugin_metadata, handler=handler, use_regex=use_regex, + ignore_prefix=ignore_prefix, description=description ) if plugin_metadata: @@ -75,9 +78,23 @@ class CommandManager(): priority=request.priority, handler=request.handler, use_regex=request.use_regex, + ignore_prefix=request.ignore_prefix, plugin_metadata=plugin.metadata) self.plugin_commands_waitlist = [] - + + async def check_command_ignore_prefix(self, message_str: str) -> bool: + for _, command in self.commands: + command_metadata = self.commands_handler[command] + if command_metadata.ignore_prefix: + trig = False + if self.commands_handler[command].use_regex: + trig = self.command_parser.regex_match(message_str, command) + else: + trig = message_str.startswith(command) + if trig: + return True + return False + async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: message_str = message_event.message_str for _, command in self.commands: @@ -89,6 +106,8 @@ class CommandManager(): if trig: logger.info(f"触发 {command} 指令。") command_result = await self.execute_handler(command, message_event, context) + if not command_result: + continue if command_result.hit: return command_result diff --git a/model/platform/__init__.py b/model/platform/__init__.py index 56ea9ded0..7003af473 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -3,11 +3,13 @@ from typing import Union, Any, List from nakuru.entities.components import Plain, At, Image, BaseMessageComponent from type.astrbot_message import AstrBotMessage from type.command import CommandResult +from type.astrbot_message import MessageType class Platform(): - def __init__(self) -> None: - pass + def __init__(self, platform_name: str, context) -> None: + self.PLATFORM_NAME = platform_name + self.context = context @abc.abstractmethod async def handle_msg(self, message: AstrBotMessage): @@ -30,6 +32,13 @@ class Platform(): 发送消息(主动) ''' pass + + @abc.abstractmethod + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + ''' + 发送消息(主动) + ''' + pass def parse_message_outline(self, message: AstrBotMessage) -> str: ''' @@ -72,4 +81,6 @@ class Platform(): else: rendered_images.append(Image.fromFileSystem(p)) return rendered_images - \ No newline at end of file + + async def record_metrics(self): + self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME) \ No newline at end of file diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 3f03c6cf4..b79f28ff8 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -21,6 +21,7 @@ class AIOCQHTTP(Platform): def __init__(self, context: Context, message_handler: MessageHandler, platform_config: PlatformConfig) -> None: + super().__init__("aiocqhttp", context) assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" self.message_handler = message_handler @@ -74,7 +75,9 @@ class AIOCQHTTP(Platform): message_str += m['data']['text'].strip() abm.message.append(a) if t == 'image': - a = Image(file=m['data']['file']) + file = m['data']['file'] if 'file' in m['data'] else None + url = m['data']['url'] if 'url' in m['data'] else None + a = Image(file=file, url=url) abm.message.append(a) abm.timestamp = int(time.time()) abm.message_str = message_str @@ -84,7 +87,7 @@ class AIOCQHTTP(Platform): def run_aiocqhttp(self): if not self.host or not self.port: return - self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp') + self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) @self.bot.on_message('group') async def group(event: Event): abm = self.convert_message(event) @@ -106,26 +109,31 @@ class AIOCQHTTP(Platform): return bot async def shutdown_trigger_placeholder(self): - while True: + while self.context.running: await asyncio.sleep(1) def pre_check(self, message: AstrBotMessage) -> bool: - # if message chain contains Plain components or At components which points to self_id, return True + # if message chain contains Plain components or + # At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: - return True + return True, "friend" for comp in message.message: if isinstance(comp, At) and str(comp.qq) == message.self_id: - return True + return True, "at" + # check commands which ignore prefix + if self.context.command_manager.check_command_ignore_prefix(message.message_str): + return True, "command" # check nicks if self.check_nick(message.message_str): - return True - return False + return True, "nick" + return False, "none" async def handle_msg(self, message: AstrBotMessage): logger.info( f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") - if not self.pre_check(message): + ok, reason = self.pre_check(message) + if not ok: return # 解析 role @@ -134,15 +142,31 @@ class AIOCQHTTP(Platform): role = 'admin' else: role = 'member' + + # parse unified message origin + unified_msg_origin = None + assert isinstance(message.raw_message, Event) + if message.type == MessageType.GROUP_MESSAGE: + unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}" + elif message.type == MessageType.FRIEND_MESSAGE: + unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}" + + logger.debug(f"unified_msg_origin: {unified_msg_origin}") # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, + self.context, + "aiocqhttp", + message.session_id, + role, + unified_msg_origin, + reason == "command") # only_command # transfer control to message handler message_result = await self.message_handler.handle(ame) if not message_result: return - await self.reply_msg(message, message_result.result_message) + await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -153,20 +177,18 @@ class AIOCQHTTP(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: list): + result_message: list, + use_t2i: bool = None): """ 回复用户唤醒机器人的消息。(被动回复) """ - logger.info( - f"{message.sender.user_id} <- {self.parse_message_outline(message)}") - res = result_message if isinstance(res, str): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.config_helper.t2i and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -179,9 +201,16 @@ class AIOCQHTTP(Platform): await self._reply(message, res) async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]): + await self.record_metrics() if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] - + + if isinstance(message, AstrBotMessage): + logger.info( + f"{message.sender.user_id} <- {self.parse_message_outline(message)}") + else: + logger.info(f"回复消息: {message_chain}") + ret = [] image_idx = [] for idx, segment in enumerate(message_chain): @@ -191,24 +220,17 @@ class AIOCQHTTP(Platform): if isinstance(segment, Image): image_idx.append(idx) ret.append(d) + if os.environ.get('TEST_MODE', 'off') == 'on': + logger.info(f"回复消息: {ret}") + return try: - if isinstance(message, AstrBotMessage): - await self.bot.send(message.raw_message, ret) - if isinstance(message, dict): - if 'group_id' in message: - await self.bot.send_group_msg(group_id=message['group_id'], message=ret) - elif 'user_id' in message: - await self.bot.send_private_msg(user_id=message['user_id'], message=ret) - else: - raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。") + await self._reply_wrapper(message, ret) except ActionFailed as e: - logger.error(traceback.format_exc()) - logger.error(f"回复消息失败: {e}") if e.retcode == 1200: # ENOENT if not image_idx: raise e - logger.info("检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") + logger.warn("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") for idx in image_idx: if ret[idx]['data']['file'].startswith('file://'): logger.info(f"正在上传图片: {ret[idx]['data']['path']}") @@ -216,8 +238,23 @@ class AIOCQHTTP(Platform): logger.info(f"上传成功。") ret[idx]['data']['file'] = image_url ret[idx]['data']['path'] = image_url - await self.bot.send(message.raw_message, ret) - + await self._reply_wrapper(message, ret) + else: + logger.error(traceback.format_exc()) + logger.error(f"回复消息失败: {e}") + raise e + + async def _reply_wrapper(self, message: Union[AstrBotMessage, Dict], ret: List): + if isinstance(message, AstrBotMessage): + await self.bot.send(message.raw_message, ret) + if isinstance(message, dict): + if 'group_id' in message: + await self.bot.send_group_msg(group_id=message['group_id'], message=ret) + elif 'user_id' in message: + await self.bot.send_private_msg(user_id=message['user_id'], message=ret) + else: + raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。") + async def send_msg(self, target: Dict[str, int], result_message: CommandResult): ''' 以主动的方式给QQ用户、QQ群发送一条消息。 @@ -229,4 +266,12 @@ class AIOCQHTTP(Platform): ''' - await self._reply(target, result_message.message_chain) \ No newline at end of file + await self._reply(target, result_message.message_chain) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + if message_type == MessageType.GROUP_MESSAGE: + await self.send_msg({'group_id': int(target)}, result_message) + elif message_type == MessageType.FRIEND_MESSAGE: + await self.send_msg({'user_id': int(target)}, result_message) + else: + raise Exception("aiocqhttp: 无法识别的消息类型。") \ No newline at end of file diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index 0e586f1f2..3edb69d7a 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -33,6 +33,7 @@ class QQNakuru(Platform): def __init__(self, context: Context, message_handler: MessageHandler, platform_config: PlatformConfig) -> None: + super().__init__("nakuru", context) assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。" self.loop = asyncio.new_event_loop() @@ -81,14 +82,17 @@ class QQNakuru(Platform): def pre_check(self, message: AstrBotMessage) -> bool: # if message chain contains Plain components or At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: - return True + return True, "friend" for comp in message.message: if isinstance(comp, At) and str(comp.qq) == message.self_id: - return True + return True, "at" + # check commands which ignore prefix + if self.context.command_manager.check_command_ignore_prefix(message.message_str): + return True, "command" # check nicks if self.check_nick(message.message_str): - return True - return False + return True, "nick" + return False, "none" def run(self): coro = self.client._run() @@ -102,7 +106,8 @@ class QQNakuru(Platform): (GroupMessage, FriendMessage, GuildMessage)) # 判断是否响应消息 - if not self.pre_check(message): + ok, reason = self.pre_check(message) + if not ok: return # 解析 session_id @@ -124,14 +129,35 @@ class QQNakuru(Platform): else: role = 'member' + # parse unified message origin + unified_msg_origin = None + if message.type == MessageType.GROUP_MESSAGE: + assert isinstance(message.raw_message, GroupMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}" + elif message.type == MessageType.FRIEND_MESSAGE: + assert isinstance(message.raw_message, FriendMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}" + elif message.type == MessageType.GUILD_MESSAGE: + assert isinstance(message.raw_message, GuildMessage) + unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}" + + logger.debug(f"unified_msg_origin: {unified_msg_origin}") + + # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, + self.context, + "nakuru", + session_id, + role, + unified_msg_origin, + reason == 'command') # only_command # transfer control to message handler message_result = await self.message_handler.handle(ame) if not message_result: return - await self.reply_msg(message, message_result.result_message) + await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -141,7 +167,8 @@ class QQNakuru(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: List[BaseMessageComponent]): + result_message: List[BaseMessageComponent], + use_t2i: bool = None): """ 回复用户唤醒机器人的消息。(被动回复) """ @@ -158,7 +185,7 @@ class QQNakuru(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.config_helper.t2i and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -171,18 +198,31 @@ class QQNakuru(Platform): await self._reply(source, res) async def _reply(self, source, message_chain: List[BaseMessageComponent]): + await self.record_metrics() if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] is_dict = isinstance(source, dict) - if source.type == "GuildMessage": + + typ = None + if is_dict: + if "group_id" in source: + typ = "GroupMessage" + elif "user_id" in source: + typ = "FriendMessage" + elif "guild_id" in source: + typ = "GuildMessage" + else: + typ = source.type + + if typ == "GuildMessage": guild_id = source['guild_id'] if is_dict else source.guild_id chan_id = source['channel_id'] if is_dict else source.channel_id await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain) - elif source.type == "FriendMessage": + elif typ == "FriendMessage": user_id = source['user_id'] if is_dict else source.user_id await self.client.sendFriendMessage(user_id, message_chain) - elif source.type == "GroupMessage": + elif typ == "GroupMessage": group_id = source['group_id'] if is_dict else source.group_id # 过长时forward发送 plain_text_len = 0 @@ -219,6 +259,23 @@ class QQNakuru(Platform): guild_id 不是频道号。 ''' await self._reply(target, result_message.message_chain) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + ''' + 以主动的方式给用户、群或者频道发送一条消息。 + + `message_type` 为 MessageType 枚举类型。 + + - 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE; + - 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE; + - 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。 + ''' + if message_type == MessageType.FRIEND_MESSAGE: + await self.send_msg({"user_id": int(target)}, result_message) + elif message_type == MessageType.GROUP_MESSAGE: + await self.send_msg({"group_id": int(target)}, result_message) + elif message_type == MessageType.GUILD_MESSAGE: + await self.send_msg({"channel_id": int(target)}, result_message) def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: abm = AstrBotMessage() @@ -239,7 +296,7 @@ class QQNakuru(Platform): str(message.sender.user_id), str(message.sender.nickname) ) - abm.tag = "gocq" + abm.tag = "nakuru" abm.message = message.message return abm diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 297ceb4b8..e6716532f 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -57,6 +57,7 @@ class QQOfficial(Platform): message_handler: MessageHandler, platform_config: PlatformConfig, test_mode = False) -> None: + super().__init__("qqofficial", context) assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。" self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -86,12 +87,13 @@ class QQOfficial(Platform): ) self.client = botClient( intents=self.intents, - bot_log=False + bot_log=False, + timeout=20, ) self.client.set_platform(self) - self.test_mode = test_mode + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): plain_text = "" @@ -117,7 +119,7 @@ class QQOfficial(Platform): abm.timestamp = int(time.time()) abm.raw_message = message abm.message_id = message.id - abm.tag = "qqchan" + abm.tag = "qqofficial" msg: List[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage): @@ -177,7 +179,7 @@ class QQOfficial(Platform): appid=self.appid, secret=self.secret ) - + async def handle_msg(self, message: AstrBotMessage): assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) @@ -207,13 +209,13 @@ class QQOfficial(Platform): role = 'member' # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqchan", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role) message_result = await self.message_handler.handle(ame) if not message_result: return - ret = await self.reply_msg(message, message_result.result_message) + ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i) if message_result.callback: message_result.callback() @@ -225,7 +227,8 @@ class QQOfficial(Platform): async def reply_msg(self, message: AstrBotMessage, - result_message: List[BaseMessageComponent]): + result_message: List[BaseMessageComponent], + use_t2i: bool = None): ''' 回复频道消息 ''' @@ -240,7 +243,7 @@ class QQOfficial(Platform): msg_ref = None rendered_images = [] - if self.context.config_helper.t2i and isinstance(result_message, list): + if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): @@ -311,6 +314,7 @@ class QQOfficial(Platform): return await self._reply(**data) async def _reply(self, **kwargs): + await self.record_metrics() if 'group_openid' in kwargs or 'openid' in kwargs: # QQ群组消息 if 'file_image' in kwargs and kwargs['file_image']: @@ -379,6 +383,9 @@ class QQOfficial(Platform): if image_path: payload['file_image'] = image_path await self._reply(**payload) + + async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): + raise NotImplementedError("qqofficial 不支持此方法。") def wait_for_message(self, channel_id: int) -> AstrBotMessage: ''' @@ -395,4 +402,4 @@ class QQOfficial(Platform): cnt += 1 if cnt > 300: raise Exception("等待消息超时。") - time.sleep(1)() + time.sleep(1) diff --git a/model/plugin/command.py b/model/plugin/command.py index 3321d52c7..1e4d8fab9 100644 --- a/model/plugin/command.py +++ b/model/plugin/command.py @@ -15,12 +15,13 @@ class CommandRegisterRequest(): handler: Callable use_regex: bool = False plugin_name: str = None + ignore_prefix: bool = False class PluginCommandBridge(): def __init__(self, cached_plugins: RegisteredPlugins): self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] self.cached_plugins = cached_plugins - def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False): - self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name)) + def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False): + self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix)) \ No newline at end of file diff --git a/model/plugin/manager.py b/model/plugin/manager.py index 35fcb90e7..43e9cc30f 100644 --- a/model/plugin/manager.py +++ b/model/plugin/manager.py @@ -5,6 +5,7 @@ import traceback import uuid import shutil import yaml +import subprocess from util.updator.plugin_updator import PluginUpdator from util.io import remove_dir, download_file @@ -84,8 +85,28 @@ class PluginManager(): def update_plugin_dept(self, path): mirror = "https://mirrors.aliyun.com/pypi/simple/" py = sys.executable - os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet") - + # os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com") + + process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com", + stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True) + + while True: + output = process.stdout.readline() + if output == '' and process.poll() is not None: + break + if output: + output = output.strip() + if output.startswith("Requirement already satisfied"): + continue + if output.startswith("Using cached"): + continue + if output.startswith("Looking in indexes"): + continue + logger.info(output) + + rc = process.poll() + + def install_plugin(self, repo_url: str): ppath = self.plugin_store_path @@ -95,10 +116,13 @@ class PluginManager(): plugin_path = self.updator.update(repo_url) with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: f.write(repo_url) + + self.check_plugin_dept_update() - ok, err = self.plugin_reload() - if not ok: - raise Exception(err) + return plugin_path + # ok, err = self.plugin_reload() + # if not ok: + # raise Exception(err) def download_from_repo_url(self, target_path: str, repo_url: str): repo_namespace = repo_url.split("/")[-2:] @@ -158,7 +182,7 @@ class PluginManager(): logger.info(f"正在加载插件 {root_dir_name} ...") - # self.check_plugin_dept_update(cached_plugins, root_dir_name) + self.check_plugin_dept_update(target_plugin=root_dir_name) module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p]) @@ -227,10 +251,12 @@ class PluginManager(): # remove the temp dir remove_dir(temp_dir) + + self.check_plugin_dept_update() - ok, err = self.plugin_reload() - if not ok: - raise Exception(err) + # ok, err = self.plugin_reload() + # if not ok: + # raise Exception(err) def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: metadata = None diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 8f2ffb870..e1bf13ab0 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -1,3 +1,5 @@ +import os +import asyncio import json import time import tiktoken @@ -6,13 +8,12 @@ import traceback import base64 from openai import AsyncOpenAI -from openai.types.images_response import ImagesResponse from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import * +from util.io import download_image_by_url from astrbot.persist.helper import dbConn from model.provider.provider import Provider -from util import general_utils as gu from util.cmd_config import LLMConfig from SparkleLogging.utils.core import LogManager from logging import Logger @@ -149,7 +150,7 @@ class ProviderOpenAIOfficial(Provider): 将图片转换为 base64 ''' if image_url.startswith("http"): - image_url = await gu.download_image_by_url(image_url) + image_url = await download_image_by_url(image_url) with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode() @@ -292,6 +293,9 @@ class ProviderOpenAIOfficial(Provider): extra_conf: Dict = None, **kwargs ) -> str: + if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on": + return "这是一个测试消息。" + super().accu_model_stat() if not session_id: session_id = "unknown" @@ -364,7 +368,9 @@ class ProviderOpenAIOfficial(Provider): logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}") await self.switch_to_next_key() rate_limit_retry += 1 - time.sleep(1) + await asyncio.sleep(1) + except NotFoundError as e: + raise e except Exception as e: retry += 1 if retry >= 3: @@ -376,7 +382,7 @@ class ProviderOpenAIOfficial(Provider): logger.warning(traceback.format_exc()) logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") - time.sleep(1) + await asyncio.sleep(1) assert isinstance(completion, ChatCompletion) logger.debug(f"openai completion: {completion.usage}") @@ -446,7 +452,7 @@ class ProviderOpenAIOfficial(Provider): logger.error(traceback.format_exc()) raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。") - time.sleep(1) + await asyncio.sleep(1) async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: if session_id is None: return False diff --git a/tests/mocks/onebot.py b/tests/mocks/onebot.py new file mode 100644 index 000000000..66df3d1ee --- /dev/null +++ b/tests/mocks/onebot.py @@ -0,0 +1,13 @@ +from aiocqhttp import Event + +class MockOneBotMessage(): + def __init__(self): + # 这些数据不是敏感的 + self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470}) + self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'}) + + def create_random_group_message(self): + return self.group_event_sample + + def create_random_direct_message(self): + return self.friend_event_sample \ No newline at end of file diff --git a/tests/mocks/qq_official.py b/tests/mocks/qq_official.py new file mode 100644 index 000000000..0d665d289 --- /dev/null +++ b/tests/mocks/qq_official.py @@ -0,0 +1,45 @@ +import botpy.message + +class MockQQOfficialMessage(): + def __init__(self): + # 这些数据已经经过去敏处理 + self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'} + self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'} + self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'} + self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849" + + self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'} + self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'} + self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'} + self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" + + self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'} + self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'} + self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'} + self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" + + def create_random_group_message(self): + mocked = botpy.message.GroupMessage( + api=None, + event_id=self.group_event_id_sample, + data=self.group_plain_text_sample + ) + return mocked + + def create_random_guild_message(self): + mocked = botpy.message.Message( + api=None, + event_id=self.guild_event_id_sample, + data=self.guild_plain_text_sample + ) + return mocked + + def create_random_direct_message(self): + mocked = botpy.message.DirectMessage( + api=None, + event_id=self.direct_event_id_sample, + data=self.direct_plain_text_sample + ) + return mocked + + diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 000000000..a5fc4578a --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,65 @@ +import asyncio +import pytest +import os + +from tests.mocks.qq_official import MockQQOfficialMessage +from tests.mocks.onebot import MockOneBotMessage + +from astrbot.bootstrap import AstrBotBootstrap +from model.platform.qq_official import QQOfficial +from model.platform.qq_aiocqhttp import AIOCQHTTP +from type.astrbot_message import * +from type.message_event import * +from SparkleLogging.utils.core import LogManager +from logging import Formatter + +logger = LogManager.GetLogger( +log_name='astrbot', + out_to_console=True, + custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") +) +pytest_plugins = ('pytest_asyncio',) + +os.environ['TEST_MODE'] = 'on' +bootstrap = AstrBotBootstrap() +asyncio.run(bootstrap.run()) + +qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler) +aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler) + +class TestBasicMessageHandle(): + @pytest.mark.asyncio + async def test_qqofficial_group_message(self): + group_message = MockQQOfficialMessage().create_random_group_message() + abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE) + ret = await qq_official.handle_msg(abm) + print(ret) + + @pytest.mark.asyncio + async def test_qqofficial_guild_message(self): + guild_message = MockQQOfficialMessage().create_random_guild_message() + abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE) + ret = await qq_official.handle_msg(abm) + print(ret) + + # 有共同性,为了节约开销,不测试频道私聊。 + # @pytest.mark.asyncio + # async def test_qqofficial_private_message(self): + # private_message = MockQQOfficialMessage().create_random_direct_message() + # abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE) + # ret = await qq_official.handle_msg(abm) + # print(ret) + + @pytest.mark.asyncio + async def test_aiocqhttp_group_message(self): + event = MockOneBotMessage().create_random_group_message() + abm = aiocqhttp.convert_message(event) + ret = await aiocqhttp.handle_msg(abm) + print(ret) + + @pytest.mark.asyncio + async def test_aiocqhttp_direct_message(self): + event = MockOneBotMessage().create_random_direct_message() + abm = aiocqhttp.convert_message(event) + ret = await aiocqhttp.handle_msg(abm) + print(ret) \ No newline at end of file diff --git a/type/command.py b/type/command.py index a8723063f..ac9ca0d50 100644 --- a/type/command.py +++ b/type/command.py @@ -2,7 +2,6 @@ from typing import Union, List, Callable from dataclasses import dataclass from nakuru.entities.components import Plain, Image - @dataclass class CommandItem(): ''' @@ -19,12 +18,17 @@ class CommandResult(): 用于在Command中返回多个值 ''' - def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None: + def __init__(self, + hit: bool = True, + success: bool = True, + message_chain: list = [], + command_name: str = "unknown_command", + use_t2i: bool = None) -> None: self.hit = hit self.success = success self.message_chain = message_chain self.command_name = command_name - self.is_use_t2i = None # default + self.is_use_t2i = use_t2i def message(self, message: str): ''' @@ -63,14 +67,12 @@ class CommandResult(): self.message_chain = [Image.fromFileSystem(path), ] return self - # def use_t2i(self, use_t2i: bool): - # ''' - # 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 - - # CommandResult().use_t2i(False) - # ''' - # self.is_use_t2i = use_t2i - # return self + def use_t2i(self, use_t2i: bool): + ''' + 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 + ''' + self.is_use_t2i = use_t2i + return self def _result_tuple(self): return (self.success, self.message_chain, self.command_name) diff --git a/type/config.py b/type/config.py index 17096ea5c..a8230fd7e 100644 --- a/type/config.py +++ b/type/config.py @@ -1,4 +1,4 @@ -VERSION = '3.3.7' +VERSION = '3.3.9' DEFAULT_CONFIG = { "qqbot": { @@ -353,4 +353,4 @@ CONFIG_METADATA_2 = { "password": {"description": "密码", "type": "string"}, } }, -} \ No newline at end of file +} diff --git a/type/message_event.py b/type/message_event.py index 222ac91e0..dc0221897 100644 --- a/type/message_event.py +++ b/type/message_event.py @@ -2,7 +2,14 @@ from typing import List, Union, Optional from dataclasses import dataclass from type.register import RegisteredPlatform from type.types import Context -from type.astrbot_message import AstrBotMessage +from type.astrbot_message import AstrBotMessage, MessageType + +@dataclass +class MessageResult(): + result_message: Union[str, list] + is_command_call: Optional[bool] = False + use_t2i: Optional[bool] = None # None 为跟随用户设置 + callback: Optional[callable] = None class AstrMessageEvent(): @@ -12,7 +19,9 @@ class AstrMessageEvent(): platform: RegisteredPlatform, role: str, context: Context, - session_id: str = None): + session_id: str = None, + unified_msg_origin: str = None, + only_command: bool = False): ''' AstrBot 消息事件。 @@ -22,6 +31,8 @@ class AstrMessageEvent(): `role`: 角色,`admin` or `member` `context`: 全局对象 `session_id`: 会话id + `unified_msg_origin`: 统一消息来源 + `only_command`: 是否只处理指令,而不使用 LLM 回复 ''' self.context = context self.message_str = message_str @@ -29,24 +40,24 @@ class AstrMessageEvent(): self.platform = platform self.role = role self.session_id = session_id + self.unified_msg_origin = unified_msg_origin + self.only_command = only_command def from_astrbot_message(message: AstrBotMessage, context: Context, platform_name: str, session_id: str, - role: str = "member"): + role: str = "member", + unified_msg_origin: str = None, + only_command: bool = False): ame = AstrMessageEvent(message.message_str, message, context.find_platform(platform_name), role, context, - session_id) + session_id, + unified_msg_origin, + only_command=only_command) return ame -@dataclass -class MessageResult(): - result_message: Union[str, list] - is_command_call: Optional[bool] = False - use_t2i: Optional[bool] = None # None 为跟随用户设置 - callback: Optional[callable] = None diff --git a/type/types.py b/type/types.py index 9f9035ca1..a0dace291 100644 --- a/type/types.py +++ b/type/types.py @@ -1,4 +1,4 @@ -import asyncio +import asyncio, os from asyncio import Task from type.register import * from typing import List, Awaitable @@ -8,8 +8,11 @@ from util.t2i.renderer import TextToImageRenderer from util.updator.astrbot_updator import AstrBotUpdator from util.image_uploader import ImageUploader from util.updator.plugin_updator import PluginUpdator +from type.command import CommandResult +from type.astrbot_message import MessageType from model.plugin.command import PluginCommandBridge from model.provider.provider import Provider +from util.agent.func_call import FuncCall class Context: @@ -40,6 +43,9 @@ class Context: self.image_uploader = ImageUploader() self.message_handler = None # see astrbot/message/handler.py self.ext_tasks: List[Task] = [] + + self.command_manager = None + self.running = True # useless # self.reply_prefix = "" @@ -50,7 +56,8 @@ class Context: description: str, priority: int, handler: callable, - use_regex: bool = False): + use_regex: bool = False, + ignore_prefix: bool = False): ''' 注册插件指令。 @@ -60,8 +67,19 @@ class Context: @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context @param use_regex: 是否使用正则表达式匹配指令名。 + @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 + + .. Example:: + + ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 ''' - self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex) + self.plugin_command_bridge.register_command(plugin_name, + command_name, + description, + priority, + handler, + use_regex, + ignore_prefix) def register_task(self, coro: Awaitable, task_name: str): ''' @@ -80,10 +98,48 @@ class Context: `provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。 ''' self.llms.append(RegisteredLLM(llm_name, provider, origin)) + + def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable): + ''' + 为函数调用(function-calling / tools-use)添加工具。 + + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 处理函数 + ''' + self.message_handler.llm_tools.add_func(tool_name, params, desc, func) + def unregister_llm_tool(self, tool_name: str): + ''' + 删除一个函数调用工具。 + ''' + self.message_handler.llm_tools.remove_func(tool_name) + def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: if platform_name == platform.platform_name: return platform - - raise ValueError("couldn't find the platform you specified") + + if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错 + raise ValueError("couldn't find the platform you specified") + + async def send_message(self, unified_msg_origin: str, message: CommandResult): + ''' + 发送消息。 + + `unified_msg_origin`: 统一消息来源 + `message`: 消息内容 + ''' + l = unified_msg_origin.split(":") + if len(l) != 3: + raise ValueError("Invalid unified_msg_origin") + platform_name, message_type, id = l + platform = self.find_platform(platform_name) + await platform.platform_instance.send_msg_new(MessageType(message_type), id, message) + + def get_current_llm_provider(self) -> Provider: + ''' + 获取当前的 LLM Provider。 + ''' + return self.message_handler.provider \ No newline at end of file diff --git a/util/agent/func_call.py b/util/agent/func_call.py index ffacf242b..830496cab 100644 --- a/util/agent/func_call.py +++ b/util/agent/func_call.py @@ -1,9 +1,6 @@ - +from model.provider.provider import Provider import json -import util.general_utils as gu - -import time - +import textwrap class FuncCallJsonFormatError(Exception): def __init__(self, msg): @@ -22,16 +19,24 @@ class FuncNotFoundError(Exception): class FuncCall(): - def __init__(self, provider) -> None: + def __init__(self, provider: Provider) -> None: self.func_list = [] self.provider = provider + + def empty(self) -> bool: + return len(self.func_list) == 0 - def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None: - if name == None or func_args == None or desc == None or func_obj == None: - raise FuncCallJsonFormatError( - "name, func_args, desc must be provided.") + def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None: + ''' + 为函数调用(function-calling / tools-use)添加工具。 + + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 处理函数 + ''' params = { - "type": "object", # hardcore here + "type": "object", # hard-coded here "properties": {} } for param in func_args: @@ -39,15 +44,24 @@ class FuncCall(): "type": param['type'], "description": param['description'] } - self._func = { + _func = { "name": name, "parameters": params, "description": desc, "func_obj": func_obj, } - self.func_list.append(self._func) - - def func_dump(self, intent: int = 2) -> str: + self.func_list.append(_func) + + def remove_func(self, name: str) -> None: + ''' + 删除一个函数调用工具。 + ''' + for i, f in enumerate(self.func_list): + if f["name"] == name: + self.func_list.pop(i) + break + + def func_dump(self) -> str: _l = [] for f in self.func_list: _l.append({ @@ -55,7 +69,7 @@ class FuncCall(): "parameters": f["parameters"], "description": f["description"], }) - return json.dumps(_l, indent=intent, ensur_ascii=False) + return json.dumps(_l, ensure_ascii=False) def get_func(self) -> list: _l = [] @@ -70,64 +84,39 @@ class FuncCall(): }) return _l - def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None): + async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple: + + if not provider: + provider = self.provider - funccall_prompt = """ -我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。 -下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。 -- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段: -``` -{ - "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。 - "func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。 - { - "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。 - "name": str, // 函数的名字 - "args_type": { - "arg1": str, // 函数的参数的类型 - "arg2": str, - ... - }, - "args": { - "arg1": any, // 函数的参数 - "arg2": any, - ... - } - }, - ... // 可能在这个问题中会有多个函数调用 - ], -} -``` -- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。 -- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。) + prompt = textwrap.dedent(f""" + ROLE: + 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 -提供的函数是: + TOOLS: + 可用的函数列表: -""" + {func_definition} - prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n" - prompt += f""" -用户的提问是: -``` -{question} -``` -""" + LIMIT: + 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。 + 2. 你的 Json 返回的格式如下:`[{{"name": "", "args": }}, ...]`。参数根据上面提供的函数列表中的参数来填写。 + 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。 + 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。 - # if is_task: - # # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**" - # prompt += task_prompt + EXAMPLE: + 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}] - # provider.forget() + 用户的提问是:{question} + """) _c = 0 while _c < 3: try: - res = self.provider.text_chat(prompt=prompt, session_id=session_id) + res = await provider.text_chat(prompt, session_id) + print(res) if res.find('```') != -1: res = res[res.find('```json') + 7: res.rfind('```')] - gu.log("REVGPT func_call json result", - bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) - print(res) res = json.loads(res) break except Exception as e: @@ -136,112 +125,25 @@ class FuncCall(): raise e if "The message you submitted was too long" in str(e): raise e + + if 'res' in res and not res['res']: + return "", False - invoke_func_res = "" - - if "func_call" in res and len(res["func_call"]) > 0: - task_list = res["func_call"] - - invoke_func_res_list = [] - - for res in task_list: - # 说明有函数调用 - func_name = res["name"] - # args_type = res["args_type"] - args = res["args"] - # 调用函数 - # func = eval(func_name) - func_target = None - for func in self.func_list: - if func["name"] == func_name: - func_target = func["func_obj"] - break - if func_target == None: - raise FuncNotFoundError( - f"Request function {func_name} not found.") - t_res = str(func_target(**args)) - invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n" - invoke_func_res_list.append(invoke_func_res) - gu.log(f"[FUNC| {func_name} invoked]", - bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) - # print(str(t_res)) - - if is_summary: - - # 生成返回结果 - after_prompt = """ -有以下内容:"""+invoke_func_res+""" -请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 -用户的提问是: -```""" + question + """``` -- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。 -- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释): -```json -{ - "res": string, // 回答的内容 - "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false -} -``` -- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。""" - - _c = 0 - while _c < 5: - try: - res = self.provider.text_chat(prompt=after_prompt, session_id=session_id) - # 截取```之间的内容 - gu.log( - "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) - print(res) - gu.log( - "DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) - if res.find('```') != -1: - res = res[res.find('```json') + - 7: res.rfind('```')] - gu.log("REVGPT after_func_call json result", - bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) - after_prompt_res = res - after_prompt_res = json.loads(after_prompt_res) - break - except Exception as e: - _c += 1 - if _c == 5: - raise e - if "The message you submitted was too long" in str(e): - # 如果返回的内容太长了,那么就截取一部分 - time.sleep(3) - invoke_func_res = invoke_func_res[:int( - len(invoke_func_res) / 2)] - after_prompt = """ -函数返回以下内容:"""+invoke_func_res+""" -请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。 -用户的提问是: -```""" + question + """``` -- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。 -- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释): -```json -{ - "res": string, // 回答的内容 - "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false -} -``` -- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。""" - else: - raise e - - if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]: - # 如果需要重新调用函数 - # 重新调用函数 - gu.log("REVGPT func_call_again", - bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"]) - res = self.func_call(question, func_definition) - return res, True - - gu.log("REVGPT func callback:", - bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"]) - # print(after_prompt_res["res"]) - return after_prompt_res["res"], True - else: - return str(invoke_func_res_list), True - else: - # print(res["res"]) - return res["res"], False + tool_call_result = [] + for tool in res: + # 说明有函数调用 + func_name = tool["name"] + args = tool["args"] + # 调用函数 + tool_callable = None + for func in self.func_list: + if func["name"] == func_name: + tool_callable = func["func_obj"] + break + if not tool_callable: + raise FuncNotFoundError( + f"Request function {func_name} not found.") + ret = await tool_callable(**args) + if ret: + tool_call_result.append(str(ret)) + return tool_call_result, True diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py index 6badf8188..d9b384314 100644 --- a/util/agent/web_searcher.py +++ b/util/agent/web_searcher.py @@ -1,13 +1,11 @@ -import traceback import random -import json -import asyncio import aiohttp import os from readability import Document from bs4 import BeautifulSoup from openai.types.chat.chat_completion_message_tool_call import Function +from openai._exceptions import * from util.agent.func_call import FuncCall from util.websearch.config import HEADERS, USER_AGENTS from util.websearch.bing import Bing @@ -16,6 +14,8 @@ from util.websearch.google import Google from model.provider.provider import Provider from SparkleLogging.utils.core import LogManager from logging import Logger +from type.types import Context +from type.message_event import AstrMessageEvent logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -31,24 +31,7 @@ def tidy_text(text: str) -> str: ''' return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") -# def special_fetch_zhihu(link: str) -> str: -# ''' -# function-calling 函数, 用于获取知乎文章的内容 -# ''' -# response = requests.get(link, headers=HEADERS) -# response.encoding = "utf-8" -# soup = BeautifulSoup(response.text, "html.parser") - -# if "zhuanlan.zhihu.com" in link: -# r = soup.find(class_="Post-RichTextContainer") -# else: -# r = soup.find(class_="List-item").find(class_="RichContent-inner") -# if r is None: -# print("debug: zhihu none") -# raise Exception("zhihu none") -# return tidy_text(r.text) - -async def search_from_bing(keyword: str) -> str: +async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str: ''' tools, 从 bing 搜索引擎搜索 ''' @@ -84,10 +67,11 @@ async def search_from_bing(keyword: str) -> str: site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n" idx += 1 - return ret + + return await summarize(context, ame, ret) -async def fetch_website_content(url): +async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str): header = HEADERS header.update({'User-Agent': random.choice(USER_AGENTS)}) async with aiohttp.ClientSession() as session: @@ -97,87 +81,25 @@ async def fetch_website_content(url): ret = doc.summary(html_partial=True) soup = BeautifulSoup(ret, 'html.parser') ret = tidy_text(soup.get_text()) - return ret - - -async def web_search(prompt, provider: Provider, session_id, official_fc=False): - ''' - official_fc: 使用官方 function-calling - ''' - new_func_call = FuncCall(provider) - - new_func_call.add_func("web_search", [{ - "type": "string", - "name": "keyword", - "description": "搜索关键词" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - search_from_bing - ) - new_func_call.add_func("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "要获取内容的网页链接" - }], - "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) + return await summarize(context, ame, ret) - has_func = False - function_invoked_ret = "" - if official_fc: - # we use official function-calling - result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func()) - if isinstance(result, Function): - logger.debug(f"web_searcher - function-calling: {result}") - func_obj = None - for i in new_func_call.func_list: - if i["name"] == result.name: - func_obj = i["func_obj"] - break - if not func_obj: - return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" - try: - args = json.loads(result.arguments) - function_invoked_ret = await func_obj(**args) - has_func = True - except BaseException as e: - traceback.print_exc() - return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" - else: - return result - else: - # we use our own function-calling - try: - args = { - 'question': prompt, - 'func_definition': new_func_call.func_dump(), - 'is_task': False, - 'is_summary': False, - } - function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) - except BaseException as e: - res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)" - return res - has_func = True - - if has_func: - await provider.forget(session_id=session_id, ) - summary_prompt = f""" +async def summarize(context: Context, ame: AstrMessageEvent, text: str): + + summary_prompt = f""" 你是一个专业且高效的助手,你的任务是 -1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; -2. 简单地发表你对这个问题的简略看法。 +1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结; +2. 简单地发表你对这个问题的看法。 # 例子 1. 从网上的信息来看,可以知道...我个人认为...你觉得呢? 2. 根据网上的最新信息,可以得知...我觉得...你怎么看? # 限制 -1. 限制在 200 字以内; +1. 限制在 200-300 字; 2. 请**直接输出总结**,不要输出多余的内容和提示语。 - + # 相关材料 -{function_invoked_ret}""" - ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id) - return ret - return function_invoked_ret +{text}""" + + provider = context.get_current_llm_provider() + return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id) \ No newline at end of file diff --git a/util/general_utils.py b/util/general_utils.py deleted file mode 100644 index 270faf418..000000000 --- a/util/general_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import time -import asyncio -import requests -import json -import sys -import psutil - -from type.types import Context -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -def run_monitor(global_object: Context): - ''' - 监测机器性能 - - Bot 内存使用量 - - CPU 占用率 - ''' - start_time = time.time() - while True: - stat = global_object.dashboard_data.stats - # 程序占用的内存大小 - mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB - stat['sys_perf'] = { - 'memory': mem, - 'cpu': psutil.cpu_percent() - } - stat['sys_start_time'] = start_time - time.sleep(30) diff --git a/util/metrics.py b/util/metrics.py index e905476f8..84472ba45 100644 --- a/util/metrics.py +++ b/util/metrics.py @@ -66,6 +66,9 @@ class MetricUploader(): except BaseException as e: pass await asyncio.sleep(30*60) + + def increment_platform_stat(self, platform_name: str): + self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1 def clear(self): self.platform_stats.clear() diff --git a/util/updator/astrbot_updator.py b/util/updator/astrbot_updator.py index eccdf2089..b6fbd9243 100644 --- a/util/updator/astrbot_updator.py +++ b/util/updator/astrbot_updator.py @@ -9,7 +9,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class AstrBotUpdator(RepoZipUpdator): def __init__(self): - self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases" def terminate_child_processes(self): @@ -30,9 +30,11 @@ class AstrBotUpdator(RepoZipUpdator): except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = None): - if delay: time.sleep(delay) + def _reboot(self, delay: int = None, context = None): + # if delay: time.sleep(delay) py = sys.executable + context.running = False + time.sleep(3) self.terminate_child_processes() py = py.replace(" ", "\\ ") try: