Merge branch 'master' into config-refactor
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
comment:
|
||||
# add "condensed_" to "header", "files" and "footer"
|
||||
layout: "condensed_header, condensed_files, condensed_footer"
|
||||
hide_project_coverage: TRUE # set to true
|
||||
@@ -0,0 +1,5 @@
|
||||
[run]
|
||||
omit =
|
||||
*/site-packages/*
|
||||
*/dist-packages/*
|
||||
your_package_name/tests/*
|
||||
@@ -0,0 +1,18 @@
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
# github acions
|
||||
.github/
|
||||
.*ignore
|
||||
.git/
|
||||
# User-specific stuff
|
||||
.idea/
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv*/
|
||||
ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
@@ -0,0 +1,82 @@
|
||||
name: '🐛 报告 Bug'
|
||||
title: '[Bug]'
|
||||
description: 提交报告帮助我们改进。
|
||||
labels: [ 'bug' ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 发生了什么
|
||||
description: 描述你遇到的异常
|
||||
placeholder: >
|
||||
一个清晰且具体的描述这个异常是什么。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 如何复现?
|
||||
description: >
|
||||
复现该问题的步骤
|
||||
placeholder: >
|
||||
如: 1. 打开 '...'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: AstrBot 版本与部署方式
|
||||
description: >
|
||||
请提供您的 AstrBot 版本和部署方式。
|
||||
placeholder: >
|
||||
如: 3.1.8 Docker, 3.1.7 Windows启动器
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统
|
||||
description: |
|
||||
你在哪个操作系统上遇到了这个问题?
|
||||
multiple: false
|
||||
options:
|
||||
- 'Windows'
|
||||
- 'macOS'
|
||||
- 'Linux'
|
||||
- 'Other'
|
||||
- 'Not sure'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 额外信息
|
||||
description: >
|
||||
任何额外信息,如报错日志、截图等。
|
||||
placeholder: >
|
||||
请提供完整的报错日志或截图。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 你愿意提交 PR 吗?
|
||||
description: >
|
||||
这绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||
options:
|
||||
- label: 是的,我愿意提交 PR!
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢您填写我们的表单!"
|
||||
@@ -0,0 +1,42 @@
|
||||
|
||||
name: '🎉 功能建议'
|
||||
title: "[Feature]"
|
||||
description: 提交建议帮助我们改进。
|
||||
labels: [ "enhancement" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
感谢您抽出时间提出新功能建议,请准确解释您的想法。
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 描述
|
||||
description: 简短描述您的功能建议。
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 使用场景
|
||||
description: 你想要发生什么?
|
||||
placeholder: >
|
||||
一个清晰且具体的描述这个功能的使用场景。
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 你愿意提交PR吗?
|
||||
description: >
|
||||
这不是必须的,但我们欢迎您的贡献。
|
||||
options:
|
||||
- label: 是的, 我愿意提交PR!
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢您填写我们的表单!"
|
||||
@@ -0,0 +1,10 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
修复了 #XYZ
|
||||
|
||||
### Motivation
|
||||
|
||||
<!--解释为什么要改动-->
|
||||
|
||||
### Modifications
|
||||
|
||||
<!--简单解释你的改动-->
|
||||
@@ -0,0 +1,93 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
#
|
||||
# ******** NOTE ********
|
||||
# We have attempted to detect the languages in your repository. Please check
|
||||
# the `language` matrix defined below to confirm you have the correct set of
|
||||
# supported CodeQL languages.
|
||||
#
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "master" ]
|
||||
pull_request:
|
||||
branches: [ "master" ]
|
||||
schedule:
|
||||
- cron: '21 15 * * 5'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze (${{ matrix.language }})
|
||||
# Runner size impacts CodeQL analysis time. To learn more, please see:
|
||||
# - https://gh.io/recommended-hardware-resources-for-running-codeql
|
||||
# - https://gh.io/supported-runners-and-hardware-resources
|
||||
# - https://gh.io/using-larger-runners (GitHub.com only)
|
||||
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
|
||||
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
|
||||
timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
|
||||
permissions:
|
||||
# required for all workflows
|
||||
security-events: write
|
||||
|
||||
# required to fetch internal or private CodeQL packs
|
||||
packages: read
|
||||
|
||||
# only required for workflows in private repositories
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- language: python
|
||||
build-mode: none
|
||||
# CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
|
||||
# Use `c-cpp` to analyze code written in C, C++ or both
|
||||
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
|
||||
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
|
||||
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
|
||||
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
|
||||
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
|
||||
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
|
||||
# queries: security-extended,security-and-quality
|
||||
|
||||
# If the analyze step fails for one of the languages you are analyzing with
|
||||
# "We were unable to automatically build your code", modify the matrix above
|
||||
# to set the build mode to "manual" for that language. Then modify this step
|
||||
# to build your code.
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
|
||||
- if: matrix.build-mode == 'manual'
|
||||
shell: bash
|
||||
run: |
|
||||
echo 'If you are using a "manual" build mode for one or more of the' \
|
||||
'languages you are analyzing, replace this with the commands to build' \
|
||||
'your code, for example:'
|
||||
echo ' make bootstrap'
|
||||
echo ' make release'
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
@@ -0,0 +1,34 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
push
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run tests and collect coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
@@ -4,20 +4,39 @@ on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
publish-latest-docker-image:
|
||||
publish-docker:
|
||||
runs-on: ubuntu-latest
|
||||
name: Build and publish docker image
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
- name: Build image
|
||||
run: |
|
||||
git clone https://github.com/Soulter/AstrBot
|
||||
cd AstrBot
|
||||
docker build -t ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest .
|
||||
- name: Publish image
|
||||
run: |
|
||||
docker login -u ${{ secrets.DOCKER_HUB_USERNAME }} -p ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
docker push ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
- name: 拉取源码
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: 设置 QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: 构建和推送 Docker hub
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
|
||||
+12
@@ -3,6 +3,18 @@ WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install -r requirements.txt
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
|
||||
<img width="806" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746">
|
||||
<img width="750" alt="image" src="https://github.com/Soulter/AstrBot/assets/37870767/c6f057d9-46d7-4144-8116-00a962941746">
|
||||
|
||||
</p>
|
||||
<div align="center">
|
||||
@@ -8,6 +8,7 @@
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.9+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
</a>
|
||||
@@ -21,27 +22,42 @@
|
||||
|
||||
🌍 支持的消息平台
|
||||
- QQ 群、QQ 频道(OneBot、QQ 官方接口)
|
||||
- Telegram(由 [astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持)
|
||||
- WeChat(微信) (由 [astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件支持)
|
||||
- Telegram([astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
|
||||
|
||||
🌍 支持的大模型一览:
|
||||
🌍 支持的大模型/底座:
|
||||
|
||||
- OpenAI GPT、DallE 系列
|
||||
- Claude(由[LLMs插件](https://github.com/Soulter/llms)支持)
|
||||
- HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持)
|
||||
- Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持)
|
||||
- Ollama
|
||||
- 几乎所有已知模型(可接入 [OneAPI](https://astrbot.soulter.top/docs/docs/adavanced/one-api))
|
||||
|
||||
🌍 机器人支持的能力一览:
|
||||
- 大模型对话、人格、网页搜索
|
||||
- 可视化管理面板
|
||||
- 可视化仪表盘
|
||||
- 同时处理多平台消息
|
||||
- 精确到个人的会话隔离
|
||||
- 插件支持
|
||||
- 文本转图片回复(Markdown)
|
||||
|
||||
## 🧩 插件支持
|
||||
## 🧩 插件
|
||||
|
||||
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6)
|
||||
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
|
||||
|
||||
## 云部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
对于新功能的添加,请先通过 Issue 进行讨论。
|
||||
|
||||
## 🔭 展望
|
||||
|
||||
- [ ] 更多、更开放的 LLM Agent 能力
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
|
||||
@@ -21,6 +21,14 @@ class HelloWorldPlugin:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)
|
||||
self.context.register_llm_tool("welcome_somebody", [{
|
||||
"type": "string",
|
||||
"name": "name",
|
||||
"description": "要欢迎的人的名字"
|
||||
}], "给一个用户发送欢迎文本。", self.welcome_somebody)
|
||||
|
||||
async def welcome_somebody(self, name: str):
|
||||
return CommandResult().message(f"欢迎{name}!")
|
||||
|
||||
"""
|
||||
指令处理函数。
|
||||
|
||||
+12
-3
@@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
|
||||
# load configs and ensure the backward compatibility
|
||||
@@ -43,6 +43,8 @@ class AstrBotBootstrap():
|
||||
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
|
||||
else:
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
@@ -63,6 +65,10 @@ class AstrBotBootstrap():
|
||||
self.context.updator = self.updator
|
||||
self.context.plugin_updator = self.plugin_manager.updator
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
|
||||
# load plugins, plugins' commands.
|
||||
self.load_plugins()
|
||||
@@ -84,10 +90,13 @@ class AstrBotBootstrap():
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{task.get_name()} 任务已取消。")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
|
||||
await asyncio.sleep(5)
|
||||
logger.error(f"{task.get_name()} 任务发生错误。")
|
||||
return
|
||||
|
||||
def load_llm(self):
|
||||
f = False
|
||||
|
||||
+95
-21
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
import re
|
||||
import time, json
|
||||
import re, os
|
||||
import asyncio
|
||||
import traceback
|
||||
import astrbot.message.unfit_words as uw
|
||||
@@ -14,7 +14,10 @@ from type.command import CommandResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.func_call import FuncCall
|
||||
import util.agent.web_searcher as web_searcher
|
||||
from openai._exceptions import *
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -109,8 +112,9 @@ class MessageHandler():
|
||||
self.llm_wake_prefix = self.llm_wake_prefix.strip()
|
||||
self.nicks = self.context.config_helper.wake_prefix
|
||||
self.provider = self.context.llms[0] if len(self.context.llms) > 0 else None
|
||||
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
|
||||
|
||||
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
|
||||
self.llm_tools = FuncCall(self.provider)
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
|
||||
@@ -121,18 +125,19 @@ class MessageHandler():
|
||||
`llm_provider`: the provider to use for LLM. If None, use the default provider
|
||||
'''
|
||||
msg_plain = message.message_str.strip()
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
inner_provider = False if llm_provider else True
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
|
||||
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
|
||||
if os.environ.get('TEST_MODE', 'off') != 'on':
|
||||
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
|
||||
|
||||
# TODO: this should be configurable
|
||||
if not message.message_str:
|
||||
return MessageResult("Hi~")
|
||||
# if not message.message_str:
|
||||
# return MessageResult("Hi~")
|
||||
|
||||
# check the rate limit
|
||||
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
|
||||
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
|
||||
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。")
|
||||
return
|
||||
|
||||
# remove the nick prefix
|
||||
for nick in self.nicks:
|
||||
@@ -151,6 +156,11 @@ class MessageHandler():
|
||||
use_t2i=cmd_res.is_use_t2i
|
||||
)
|
||||
|
||||
# next is the LLM part
|
||||
|
||||
if message.only_command:
|
||||
return
|
||||
|
||||
# check if the message is a llm-wake-up command
|
||||
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
|
||||
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
|
||||
@@ -169,31 +179,95 @@ class MessageHandler():
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
web_search = self.context.config_helper.llm_settings.web_search
|
||||
if not web_search and msg_plain.startswith("ws"):
|
||||
# leverage web search feature
|
||||
web_search = True
|
||||
msg_plain = msg_plain.removeprefix("ws").strip()
|
||||
|
||||
try:
|
||||
if web_search:
|
||||
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
|
||||
if not self.llm_tools.empty():
|
||||
# tools-use
|
||||
tool_use_flag = True
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
tools=self.llm_tools.get_func()
|
||||
)
|
||||
|
||||
if isinstance(llm_result, Function):
|
||||
logger.debug(f"function-calling: {llm_result}")
|
||||
func_obj = None
|
||||
for i in self.llm_tools.func_list:
|
||||
if i["name"] == llm_result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
|
||||
try:
|
||||
args = json.loads(llm_result.arguments)
|
||||
args['ame'] = message
|
||||
args['context'] = self.context
|
||||
try:
|
||||
cmd_res = await func_obj(**args)
|
||||
except TypeError as e:
|
||||
args.pop('ame')
|
||||
args.pop('context')
|
||||
cmd_res = await func_obj(**args)
|
||||
if isinstance(cmd_res, CommandResult):
|
||||
return MessageResult(
|
||||
cmd_res.message_chain,
|
||||
is_command_call=True,
|
||||
use_t2i=cmd_res.is_use_t2i
|
||||
)
|
||||
elif isinstance(cmd_res, str):
|
||||
return MessageResult(cmd_res)
|
||||
elif not cmd_res:
|
||||
return
|
||||
else:
|
||||
return MessageResult(f"AstrBot Function-calling 异常:调用:{llm_result} 时,返回了未知的返回值类型。")
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return MessageResult("AstrBot Function-calling 异常:" + str(e))
|
||||
else:
|
||||
return MessageResult(llm_result)
|
||||
|
||||
else:
|
||||
# normal chat
|
||||
tool_use_flag = False
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BadRequestError as e:
|
||||
if tool_use_flag:
|
||||
# seems like the model don't support function-calling
|
||||
logger.error(f"error: {e}. Using local function-calling implementation")
|
||||
|
||||
try:
|
||||
# use local function-calling implementation
|
||||
args = {
|
||||
'question': llm_result,
|
||||
'func_definition': self.llm_tools.func_dump(),
|
||||
}
|
||||
_, has_func = await self.llm_tools.func_call(**args)
|
||||
|
||||
if not has_func:
|
||||
# normal chat
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return CommandResult("AstrBot Function-calling 异常:" + str(e))
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"LLM 调用失败。")
|
||||
return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))
|
||||
|
||||
# concatenate the reply prefix
|
||||
|
||||
# concatenate reply prefix
|
||||
if self.reply_prefix:
|
||||
llm_result = self.reply_prefix + llm_result
|
||||
|
||||
# mask the unsafe content
|
||||
# mask unsafe content
|
||||
llm_result = self.content_safety_helper.filter_content(llm_result)
|
||||
check = self.content_safety_helper.baidu_check(llm_result)
|
||||
if not check:
|
||||
|
||||
+7
-5
@@ -207,10 +207,11 @@ class AstrBotDashBoard():
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
self.plugin_manager.install_plugin(repo_url)
|
||||
logger.info(f"安装插件 {repo_url} 成功")
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
|
||||
return Response(
|
||||
status="success",
|
||||
message="安装成功~",
|
||||
message="安装成功,机器人将在 2 秒内重启。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
@@ -273,10 +274,11 @@ class AstrBotDashBoard():
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
self.plugin_manager.update_plugin(plugin_name)
|
||||
logger.info(f"更新插件 {plugin_name} 成功")
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response(
|
||||
status="success",
|
||||
message="更新成功~",
|
||||
message="更新成功,机器人将在 2 秒内重启。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
@@ -326,7 +328,7 @@ class AstrBotDashBoard():
|
||||
latest = False
|
||||
try:
|
||||
self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start()
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
|
||||
return Response(
|
||||
status="success",
|
||||
message="更新成功,机器人将在 3 秒内重启。",
|
||||
|
||||
@@ -53,7 +53,7 @@ if __name__ == "__main__":
|
||||
check_env()
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -116,11 +117,11 @@ class InternalCommandHandler:
|
||||
success=False,
|
||||
message_chain="你没有权限使用该指令",
|
||||
)
|
||||
context.updator._reboot(5)
|
||||
context.updator._reboot(3, context)
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="AstrBot 将在 5s 后重启。",
|
||||
message_chain="AstrBot 将在 3s 后重启。",
|
||||
)
|
||||
|
||||
def plugin(self, message: AstrMessageEvent, context: Context):
|
||||
@@ -211,6 +212,23 @@ class InternalCommandHandler:
|
||||
)
|
||||
elif l[1] == 'on':
|
||||
context.web_search = True
|
||||
context.register_llm_tool("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
context.register_llm_tool("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
@@ -218,6 +236,9 @@ class InternalCommandHandler:
|
||||
)
|
||||
elif l[1] == 'off':
|
||||
context.web_search = False
|
||||
context.unregister_llm_tool("web_search")
|
||||
context.unregister_llm_tool("fetch_website_content")
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
|
||||
@@ -21,6 +21,7 @@ class CommandMetadata():
|
||||
plugin_metadata: PluginMetadata
|
||||
handler: callable
|
||||
use_regex: bool = False
|
||||
ignore_prefix: bool = False
|
||||
description: str = ""
|
||||
|
||||
class CommandManager():
|
||||
@@ -35,6 +36,7 @@ class CommandManager():
|
||||
priority: int,
|
||||
handler: callable,
|
||||
use_regex: bool = False,
|
||||
ignore_prefix: bool = False,
|
||||
plugin_metadata: PluginMetadata = None,
|
||||
):
|
||||
'''
|
||||
@@ -53,6 +55,7 @@ class CommandManager():
|
||||
plugin_metadata=plugin_metadata,
|
||||
handler=handler,
|
||||
use_regex=use_regex,
|
||||
ignore_prefix=ignore_prefix,
|
||||
description=description
|
||||
)
|
||||
if plugin_metadata:
|
||||
@@ -75,9 +78,23 @@ class CommandManager():
|
||||
priority=request.priority,
|
||||
handler=request.handler,
|
||||
use_regex=request.use_regex,
|
||||
ignore_prefix=request.ignore_prefix,
|
||||
plugin_metadata=plugin.metadata)
|
||||
self.plugin_commands_waitlist = []
|
||||
|
||||
|
||||
async def check_command_ignore_prefix(self, message_str: str) -> bool:
|
||||
for _, command in self.commands:
|
||||
command_metadata = self.commands_handler[command]
|
||||
if command_metadata.ignore_prefix:
|
||||
trig = False
|
||||
if self.commands_handler[command].use_regex:
|
||||
trig = self.command_parser.regex_match(message_str, command)
|
||||
else:
|
||||
trig = message_str.startswith(command)
|
||||
if trig:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
|
||||
message_str = message_event.message_str
|
||||
for _, command in self.commands:
|
||||
@@ -89,6 +106,8 @@ class CommandManager():
|
||||
if trig:
|
||||
logger.info(f"触发 {command} 指令。")
|
||||
command_result = await self.execute_handler(command, message_event, context)
|
||||
if not command_result:
|
||||
continue
|
||||
if command_result.hit:
|
||||
return command_result
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@ from typing import Union, Any, List
|
||||
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
|
||||
from type.astrbot_message import AstrBotMessage
|
||||
from type.command import CommandResult
|
||||
from type.astrbot_message import MessageType
|
||||
|
||||
|
||||
class Platform():
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, platform_name: str, context) -> None:
|
||||
self.PLATFORM_NAME = platform_name
|
||||
self.context = context
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
@@ -30,6 +32,13 @@ class Platform():
|
||||
发送消息(主动)
|
||||
'''
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
'''
|
||||
发送消息(主动)
|
||||
'''
|
||||
pass
|
||||
|
||||
def parse_message_outline(self, message: AstrBotMessage) -> str:
|
||||
'''
|
||||
@@ -72,4 +81,6 @@ class Platform():
|
||||
else:
|
||||
rendered_images.append(Image.fromFileSystem(p))
|
||||
return rendered_images
|
||||
|
||||
|
||||
async def record_metrics(self):
|
||||
self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME)
|
||||
@@ -21,6 +21,7 @@ class AIOCQHTTP(Platform):
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("aiocqhttp", context)
|
||||
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
|
||||
|
||||
self.message_handler = message_handler
|
||||
@@ -74,7 +75,9 @@ class AIOCQHTTP(Platform):
|
||||
message_str += m['data']['text'].strip()
|
||||
abm.message.append(a)
|
||||
if t == 'image':
|
||||
a = Image(file=m['data']['file'])
|
||||
file = m['data']['file'] if 'file' in m['data'] else None
|
||||
url = m['data']['url'] if 'url' in m['data'] else None
|
||||
a = Image(file=file, url=url)
|
||||
abm.message.append(a)
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
@@ -84,7 +87,7 @@ class AIOCQHTTP(Platform):
|
||||
def run_aiocqhttp(self):
|
||||
if not self.host or not self.port:
|
||||
return
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp')
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
@@ -106,26 +109,31 @@ class AIOCQHTTP(Platform):
|
||||
return bot
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while True:
|
||||
while self.context.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or At components which points to self_id, return True
|
||||
# if message chain contains Plain components or
|
||||
# At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
return True
|
||||
return True, "friend"
|
||||
for comp in message.message:
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True
|
||||
return True, "at"
|
||||
# check commands which ignore prefix
|
||||
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
return True, "command"
|
||||
# check nicks
|
||||
if self.check_nick(message.message_str):
|
||||
return True
|
||||
return False
|
||||
return True, "nick"
|
||||
return False, "none"
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
|
||||
|
||||
if not self.pre_check(message):
|
||||
ok, reason = self.pre_check(message)
|
||||
if not ok:
|
||||
return
|
||||
|
||||
# 解析 role
|
||||
@@ -134,15 +142,31 @@ class AIOCQHTTP(Platform):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
# parse unified message origin
|
||||
unified_msg_origin = None
|
||||
assert isinstance(message.raw_message, Event)
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}"
|
||||
elif message.type == MessageType.FRIEND_MESSAGE:
|
||||
unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}"
|
||||
|
||||
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message,
|
||||
self.context,
|
||||
"aiocqhttp",
|
||||
message.session_id,
|
||||
role,
|
||||
unified_msg_origin,
|
||||
reason == "command") # only_command
|
||||
|
||||
# transfer control to message handler
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
if not message_result: return
|
||||
|
||||
await self.reply_msg(message, message_result.result_message)
|
||||
await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
@@ -153,20 +177,18 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: list):
|
||||
result_message: list,
|
||||
use_t2i: bool = None):
|
||||
"""
|
||||
回复用户唤醒机器人的消息。(被动回复)
|
||||
"""
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
|
||||
res = result_message
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.t2i and isinstance(res, list):
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
@@ -179,9 +201,16 @@ class AIOCQHTTP(Platform):
|
||||
await self._reply(message, res)
|
||||
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
|
||||
await self.record_metrics()
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
|
||||
if isinstance(message, AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
else:
|
||||
logger.info(f"回复消息: {message_chain}")
|
||||
|
||||
ret = []
|
||||
image_idx = []
|
||||
for idx, segment in enumerate(message_chain):
|
||||
@@ -191,24 +220,17 @@ class AIOCQHTTP(Platform):
|
||||
if isinstance(segment, Image):
|
||||
image_idx.append(idx)
|
||||
ret.append(d)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
logger.info(f"回复消息: {ret}")
|
||||
return
|
||||
try:
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, dict):
|
||||
if 'group_id' in message:
|
||||
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
|
||||
elif 'user_id' in message:
|
||||
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
|
||||
await self._reply_wrapper(message, ret)
|
||||
except ActionFailed as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
if e.retcode == 1200:
|
||||
# ENOENT
|
||||
if not image_idx:
|
||||
raise e
|
||||
logger.info("检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
|
||||
logger.warn("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
|
||||
for idx in image_idx:
|
||||
if ret[idx]['data']['file'].startswith('file://'):
|
||||
logger.info(f"正在上传图片: {ret[idx]['data']['path']}")
|
||||
@@ -216,8 +238,23 @@ class AIOCQHTTP(Platform):
|
||||
logger.info(f"上传成功。")
|
||||
ret[idx]['data']['file'] = image_url
|
||||
ret[idx]['data']['path'] = image_url
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
|
||||
await self._reply_wrapper(message, ret)
|
||||
else:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _reply_wrapper(self, message: Union[AstrBotMessage, Dict], ret: List):
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, dict):
|
||||
if 'group_id' in message:
|
||||
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
|
||||
elif 'user_id' in message:
|
||||
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
|
||||
|
||||
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给QQ用户、QQ群发送一条消息。
|
||||
@@ -229,4 +266,12 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
'''
|
||||
|
||||
await self._reply(target, result_message.message_chain)
|
||||
await self._reply(target, result_message.message_chain)
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
if message_type == MessageType.GROUP_MESSAGE:
|
||||
await self.send_msg({'group_id': int(target)}, result_message)
|
||||
elif message_type == MessageType.FRIEND_MESSAGE:
|
||||
await self.send_msg({'user_id': int(target)}, result_message)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息类型。")
|
||||
+70
-13
@@ -33,6 +33,7 @@ class QQNakuru(Platform):
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("nakuru", context)
|
||||
assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。"
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
@@ -81,14 +82,17 @@ class QQNakuru(Platform):
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
return True
|
||||
return True, "friend"
|
||||
for comp in message.message:
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True
|
||||
return True, "at"
|
||||
# check commands which ignore prefix
|
||||
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
return True, "command"
|
||||
# check nicks
|
||||
if self.check_nick(message.message_str):
|
||||
return True
|
||||
return False
|
||||
return True, "nick"
|
||||
return False, "none"
|
||||
|
||||
def run(self):
|
||||
coro = self.client._run()
|
||||
@@ -102,7 +106,8 @@ class QQNakuru(Platform):
|
||||
(GroupMessage, FriendMessage, GuildMessage))
|
||||
|
||||
# 判断是否响应消息
|
||||
if not self.pre_check(message):
|
||||
ok, reason = self.pre_check(message)
|
||||
if not ok:
|
||||
return
|
||||
|
||||
# 解析 session_id
|
||||
@@ -124,14 +129,35 @@ class QQNakuru(Platform):
|
||||
else:
|
||||
role = 'member'
|
||||
|
||||
# parse unified message origin
|
||||
unified_msg_origin = None
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
assert isinstance(message.raw_message, GroupMessage)
|
||||
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}"
|
||||
elif message.type == MessageType.FRIEND_MESSAGE:
|
||||
assert isinstance(message.raw_message, FriendMessage)
|
||||
unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}"
|
||||
elif message.type == MessageType.GUILD_MESSAGE:
|
||||
assert isinstance(message.raw_message, GuildMessage)
|
||||
unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}"
|
||||
|
||||
logger.debug(f"unified_msg_origin: {unified_msg_origin}")
|
||||
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message,
|
||||
self.context,
|
||||
"nakuru",
|
||||
session_id,
|
||||
role,
|
||||
unified_msg_origin,
|
||||
reason == 'command') # only_command
|
||||
|
||||
# transfer control to message handler
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
if not message_result: return
|
||||
|
||||
await self.reply_msg(message, message_result.result_message)
|
||||
await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
@@ -141,7 +167,8 @@ class QQNakuru(Platform):
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: List[BaseMessageComponent]):
|
||||
result_message: List[BaseMessageComponent],
|
||||
use_t2i: bool = None):
|
||||
"""
|
||||
回复用户唤醒机器人的消息。(被动回复)
|
||||
"""
|
||||
@@ -158,7 +185,7 @@ class QQNakuru(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.t2i and isinstance(res, list):
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
@@ -171,18 +198,31 @@ class QQNakuru(Platform):
|
||||
await self._reply(source, res)
|
||||
|
||||
async def _reply(self, source, message_chain: List[BaseMessageComponent]):
|
||||
await self.record_metrics()
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
is_dict = isinstance(source, dict)
|
||||
if source.type == "GuildMessage":
|
||||
|
||||
typ = None
|
||||
if is_dict:
|
||||
if "group_id" in source:
|
||||
typ = "GroupMessage"
|
||||
elif "user_id" in source:
|
||||
typ = "FriendMessage"
|
||||
elif "guild_id" in source:
|
||||
typ = "GuildMessage"
|
||||
else:
|
||||
typ = source.type
|
||||
|
||||
if typ == "GuildMessage":
|
||||
guild_id = source['guild_id'] if is_dict else source.guild_id
|
||||
chan_id = source['channel_id'] if is_dict else source.channel_id
|
||||
await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain)
|
||||
elif source.type == "FriendMessage":
|
||||
elif typ == "FriendMessage":
|
||||
user_id = source['user_id'] if is_dict else source.user_id
|
||||
await self.client.sendFriendMessage(user_id, message_chain)
|
||||
elif source.type == "GroupMessage":
|
||||
elif typ == "GroupMessage":
|
||||
group_id = source['group_id'] if is_dict else source.group_id
|
||||
# 过长时forward发送
|
||||
plain_text_len = 0
|
||||
@@ -219,6 +259,23 @@ class QQNakuru(Platform):
|
||||
guild_id 不是频道号。
|
||||
'''
|
||||
await self._reply(target, result_message.message_chain)
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给用户、群或者频道发送一条消息。
|
||||
|
||||
`message_type` 为 MessageType 枚举类型。
|
||||
|
||||
- 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE;
|
||||
- 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE;
|
||||
- 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。
|
||||
'''
|
||||
if message_type == MessageType.FRIEND_MESSAGE:
|
||||
await self.send_msg({"user_id": int(target)}, result_message)
|
||||
elif message_type == MessageType.GROUP_MESSAGE:
|
||||
await self.send_msg({"group_id": int(target)}, result_message)
|
||||
elif message_type == MessageType.GUILD_MESSAGE:
|
||||
await self.send_msg({"channel_id": int(target)}, result_message)
|
||||
|
||||
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
|
||||
abm = AstrBotMessage()
|
||||
@@ -239,7 +296,7 @@ class QQNakuru(Platform):
|
||||
str(message.sender.user_id),
|
||||
str(message.sender.nickname)
|
||||
)
|
||||
abm.tag = "gocq"
|
||||
abm.tag = "nakuru"
|
||||
abm.message = message.message
|
||||
return abm
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class QQOfficial(Platform):
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig,
|
||||
test_mode = False) -> None:
|
||||
super().__init__("qqofficial", context)
|
||||
assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
@@ -86,12 +87,13 @@ class QQOfficial(Platform):
|
||||
)
|
||||
self.client = botClient(
|
||||
intents=self.intents,
|
||||
bot_log=False
|
||||
bot_log=False,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
self.client.set_platform(self)
|
||||
|
||||
self.test_mode = test_mode
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
|
||||
plain_text = ""
|
||||
@@ -117,7 +119,7 @@ class QQOfficial(Platform):
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = message
|
||||
abm.message_id = message.id
|
||||
abm.tag = "qqchan"
|
||||
abm.tag = "qqofficial"
|
||||
msg: List[BaseMessageComponent] = []
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
||||
@@ -177,7 +179,7 @@ class QQOfficial(Platform):
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
@@ -207,13 +209,13 @@ class QQOfficial(Platform):
|
||||
role = 'member'
|
||||
|
||||
# construct astrbot message event
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqchan", session_id, role)
|
||||
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role)
|
||||
|
||||
message_result = await self.message_handler.handle(ame)
|
||||
if not message_result:
|
||||
return
|
||||
|
||||
ret = await self.reply_msg(message, message_result.result_message)
|
||||
ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
@@ -225,7 +227,8 @@ class QQOfficial(Platform):
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
result_message: List[BaseMessageComponent]):
|
||||
result_message: List[BaseMessageComponent],
|
||||
use_t2i: bool = None):
|
||||
'''
|
||||
回复频道消息
|
||||
'''
|
||||
@@ -240,7 +243,7 @@ class QQOfficial(Platform):
|
||||
msg_ref = None
|
||||
rendered_images = []
|
||||
|
||||
if self.context.config_helper.t2i and isinstance(result_message, list):
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
|
||||
if isinstance(result_message, list):
|
||||
@@ -311,6 +314,7 @@ class QQOfficial(Platform):
|
||||
return await self._reply(**data)
|
||||
|
||||
async def _reply(self, **kwargs):
|
||||
await self.record_metrics()
|
||||
if 'group_openid' in kwargs or 'openid' in kwargs:
|
||||
# QQ群组消息
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
@@ -379,6 +383,9 @@ class QQOfficial(Platform):
|
||||
if image_path:
|
||||
payload['file_image'] = image_path
|
||||
await self._reply(**payload)
|
||||
|
||||
async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
|
||||
raise NotImplementedError("qqofficial 不支持此方法。")
|
||||
|
||||
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
|
||||
'''
|
||||
@@ -395,4 +402,4 @@ class QQOfficial(Platform):
|
||||
cnt += 1
|
||||
if cnt > 300:
|
||||
raise Exception("等待消息超时。")
|
||||
time.sleep(1)()
|
||||
time.sleep(1)
|
||||
|
||||
@@ -15,12 +15,13 @@ class CommandRegisterRequest():
|
||||
handler: Callable
|
||||
use_regex: bool = False
|
||||
plugin_name: str = None
|
||||
ignore_prefix: bool = False
|
||||
|
||||
class PluginCommandBridge():
|
||||
def __init__(self, cached_plugins: RegisteredPlugins):
|
||||
self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
|
||||
self.cached_plugins = cached_plugins
|
||||
|
||||
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False):
|
||||
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name))
|
||||
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
|
||||
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
|
||||
|
||||
+35
-9
@@ -5,6 +5,7 @@ import traceback
|
||||
import uuid
|
||||
import shutil
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
from util.updator.plugin_updator import PluginUpdator
|
||||
from util.io import remove_dir, download_file
|
||||
@@ -84,8 +85,28 @@ class PluginManager():
|
||||
def update_plugin_dept(self, path):
|
||||
mirror = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
py = sys.executable
|
||||
os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet")
|
||||
|
||||
# os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com")
|
||||
|
||||
process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com",
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
|
||||
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == '' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
output = output.strip()
|
||||
if output.startswith("Requirement already satisfied"):
|
||||
continue
|
||||
if output.startswith("Using cached"):
|
||||
continue
|
||||
if output.startswith("Looking in indexes"):
|
||||
continue
|
||||
logger.info(output)
|
||||
|
||||
rc = process.poll()
|
||||
|
||||
|
||||
def install_plugin(self, repo_url: str):
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
@@ -95,10 +116,13 @@ class PluginManager():
|
||||
plugin_path = self.updator.update(repo_url)
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(repo_url)
|
||||
|
||||
self.check_plugin_dept_update()
|
||||
|
||||
ok, err = self.plugin_reload()
|
||||
if not ok:
|
||||
raise Exception(err)
|
||||
return plugin_path
|
||||
# ok, err = self.plugin_reload()
|
||||
# if not ok:
|
||||
# raise Exception(err)
|
||||
|
||||
def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
@@ -158,7 +182,7 @@ class PluginManager():
|
||||
|
||||
logger.info(f"正在加载插件 {root_dir_name} ...")
|
||||
|
||||
# self.check_plugin_dept_update(cached_plugins, root_dir_name)
|
||||
self.check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
|
||||
module = __import__("addons.plugins." +
|
||||
root_dir_name + "." + p, fromlist=[p])
|
||||
@@ -227,10 +251,12 @@ class PluginManager():
|
||||
|
||||
# remove the temp dir
|
||||
remove_dir(temp_dir)
|
||||
|
||||
self.check_plugin_dept_update()
|
||||
|
||||
ok, err = self.plugin_reload()
|
||||
if not ok:
|
||||
raise Exception(err)
|
||||
# ok, err = self.plugin_reload()
|
||||
# if not ok:
|
||||
# raise Exception(err)
|
||||
|
||||
def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
|
||||
metadata = None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import tiktoken
|
||||
@@ -6,13 +8,12 @@ import traceback
|
||||
import base64
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import *
|
||||
from util.io import download_image_by_url
|
||||
|
||||
from astrbot.persist.helper import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util import general_utils as gu
|
||||
from util.cmd_config import LLMConfig
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
@@ -149,7 +150,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
将图片转换为 base64
|
||||
'''
|
||||
if image_url.startswith("http"):
|
||||
image_url = await gu.download_image_by_url(image_url)
|
||||
image_url = await download_image_by_url(image_url)
|
||||
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode()
|
||||
@@ -292,6 +293,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
extra_conf: Dict = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on":
|
||||
return "这是一个测试消息。"
|
||||
|
||||
super().accu_model_stat()
|
||||
if not session_id:
|
||||
session_id = "unknown"
|
||||
@@ -364,7 +368,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
|
||||
await self.switch_to_next_key()
|
||||
rate_limit_retry += 1
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
except NotFoundError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
retry += 1
|
||||
if retry >= 3:
|
||||
@@ -376,7 +382,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"openai completion: {completion.usage}")
|
||||
@@ -446,7 +452,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.error(traceback.format_exc())
|
||||
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
|
||||
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
|
||||
if session_id is None: return False
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from aiocqhttp import Event
|
||||
|
||||
class MockOneBotMessage():
|
||||
def __init__(self):
|
||||
# 这些数据不是敏感的
|
||||
self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470})
|
||||
self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'})
|
||||
|
||||
def create_random_group_message(self):
|
||||
return self.group_event_sample
|
||||
|
||||
def create_random_direct_message(self):
|
||||
return self.friend_event_sample
|
||||
@@ -0,0 +1,45 @@
|
||||
import botpy.message
|
||||
|
||||
class MockQQOfficialMessage():
|
||||
def __init__(self):
|
||||
# 这些数据已经经过去敏处理
|
||||
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
|
||||
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
|
||||
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
|
||||
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
|
||||
|
||||
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
|
||||
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
|
||||
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
|
||||
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
|
||||
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
|
||||
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
|
||||
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
def create_random_group_message(self):
|
||||
mocked = botpy.message.GroupMessage(
|
||||
api=None,
|
||||
event_id=self.group_event_id_sample,
|
||||
data=self.group_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_guild_message(self):
|
||||
mocked = botpy.message.Message(
|
||||
api=None,
|
||||
event_id=self.guild_event_id_sample,
|
||||
data=self.guild_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_direct_message(self):
|
||||
mocked = botpy.message.DirectMessage(
|
||||
api=None,
|
||||
event_id=self.direct_event_id_sample,
|
||||
data=self.direct_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from tests.mocks.qq_official import MockQQOfficialMessage
|
||||
from tests.mocks.onebot import MockOneBotMessage
|
||||
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
|
||||
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
|
||||
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
|
||||
|
||||
class TestBasicMessageHandle():
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_group_message(self):
|
||||
group_message = MockQQOfficialMessage().create_random_group_message()
|
||||
abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_guild_message(self):
|
||||
guild_message = MockQQOfficialMessage().create_random_guild_message()
|
||||
abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
# 有共同性,为了节约开销,不测试频道私聊。
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_qqofficial_private_message(self):
|
||||
# private_message = MockQQOfficialMessage().create_random_direct_message()
|
||||
# abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE)
|
||||
# ret = await qq_official.handle_msg(abm)
|
||||
# print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_group_message(self):
|
||||
event = MockOneBotMessage().create_random_group_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_direct_message(self):
|
||||
event = MockOneBotMessage().create_random_direct_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
+13
-11
@@ -2,7 +2,6 @@ from typing import Union, List, Callable
|
||||
from dataclasses import dataclass
|
||||
from nakuru.entities.components import Plain, Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandItem():
|
||||
'''
|
||||
@@ -19,12 +18,17 @@ class CommandResult():
|
||||
用于在Command中返回多个值
|
||||
'''
|
||||
|
||||
def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None:
|
||||
def __init__(self,
|
||||
hit: bool = True,
|
||||
success: bool = True,
|
||||
message_chain: list = [],
|
||||
command_name: str = "unknown_command",
|
||||
use_t2i: bool = None) -> None:
|
||||
self.hit = hit
|
||||
self.success = success
|
||||
self.message_chain = message_chain
|
||||
self.command_name = command_name
|
||||
self.is_use_t2i = None # default
|
||||
self.is_use_t2i = use_t2i
|
||||
|
||||
def message(self, message: str):
|
||||
'''
|
||||
@@ -63,14 +67,12 @@ class CommandResult():
|
||||
self.message_chain = [Image.fromFileSystem(path), ]
|
||||
return self
|
||||
|
||||
# def use_t2i(self, use_t2i: bool):
|
||||
# '''
|
||||
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
|
||||
|
||||
# CommandResult().use_t2i(False)
|
||||
# '''
|
||||
# self.is_use_t2i = use_t2i
|
||||
# return self
|
||||
def use_t2i(self, use_t2i: bool):
|
||||
'''
|
||||
设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
|
||||
'''
|
||||
self.is_use_t2i = use_t2i
|
||||
return self
|
||||
|
||||
def _result_tuple(self):
|
||||
return (self.success, self.message_chain, self.command_name)
|
||||
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
VERSION = '3.3.7'
|
||||
VERSION = '3.3.9'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
@@ -353,4 +353,4 @@ CONFIG_METADATA_2 = {
|
||||
"password": {"description": "密码", "type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+21
-10
@@ -2,7 +2,14 @@ from typing import List, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
from type.register import RegisteredPlatform
|
||||
from type.types import Context
|
||||
from type.astrbot_message import AstrBotMessage
|
||||
from type.astrbot_message import AstrBotMessage, MessageType
|
||||
|
||||
@dataclass
|
||||
class MessageResult():
|
||||
result_message: Union[str, list]
|
||||
is_command_call: Optional[bool] = False
|
||||
use_t2i: Optional[bool] = None # None 为跟随用户设置
|
||||
callback: Optional[callable] = None
|
||||
|
||||
class AstrMessageEvent():
|
||||
|
||||
@@ -12,7 +19,9 @@ class AstrMessageEvent():
|
||||
platform: RegisteredPlatform,
|
||||
role: str,
|
||||
context: Context,
|
||||
session_id: str = None):
|
||||
session_id: str = None,
|
||||
unified_msg_origin: str = None,
|
||||
only_command: bool = False):
|
||||
'''
|
||||
AstrBot 消息事件。
|
||||
|
||||
@@ -22,6 +31,8 @@ class AstrMessageEvent():
|
||||
`role`: 角色,`admin` or `member`
|
||||
`context`: 全局对象
|
||||
`session_id`: 会话id
|
||||
`unified_msg_origin`: 统一消息来源
|
||||
`only_command`: 是否只处理指令,而不使用 LLM 回复
|
||||
'''
|
||||
self.context = context
|
||||
self.message_str = message_str
|
||||
@@ -29,24 +40,24 @@ class AstrMessageEvent():
|
||||
self.platform = platform
|
||||
self.role = role
|
||||
self.session_id = session_id
|
||||
self.unified_msg_origin = unified_msg_origin
|
||||
self.only_command = only_command
|
||||
|
||||
def from_astrbot_message(message: AstrBotMessage,
|
||||
context: Context,
|
||||
platform_name: str,
|
||||
session_id: str,
|
||||
role: str = "member"):
|
||||
role: str = "member",
|
||||
unified_msg_origin: str = None,
|
||||
only_command: bool = False):
|
||||
|
||||
ame = AstrMessageEvent(message.message_str,
|
||||
message,
|
||||
context.find_platform(platform_name),
|
||||
role,
|
||||
context,
|
||||
session_id)
|
||||
session_id,
|
||||
unified_msg_origin,
|
||||
only_command=only_command)
|
||||
return ame
|
||||
|
||||
@dataclass
|
||||
class MessageResult():
|
||||
result_message: Union[str, list]
|
||||
is_command_call: Optional[bool] = False
|
||||
use_t2i: Optional[bool] = None # None 为跟随用户设置
|
||||
callback: Optional[callable] = None
|
||||
|
||||
+61
-5
@@ -1,4 +1,4 @@
|
||||
import asyncio
|
||||
import asyncio, os
|
||||
from asyncio import Task
|
||||
from type.register import *
|
||||
from typing import List, Awaitable
|
||||
@@ -8,8 +8,11 @@ from util.t2i.renderer import TextToImageRenderer
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.image_uploader import ImageUploader
|
||||
from util.updator.plugin_updator import PluginUpdator
|
||||
from type.command import CommandResult
|
||||
from type.astrbot_message import MessageType
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
from util.agent.func_call import FuncCall
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -40,6 +43,9 @@ class Context:
|
||||
self.image_uploader = ImageUploader()
|
||||
self.message_handler = None # see astrbot/message/handler.py
|
||||
self.ext_tasks: List[Task] = []
|
||||
|
||||
self.command_manager = None
|
||||
self.running = True
|
||||
|
||||
# useless
|
||||
# self.reply_prefix = ""
|
||||
@@ -50,7 +56,8 @@ class Context:
|
||||
description: str,
|
||||
priority: int,
|
||||
handler: callable,
|
||||
use_regex: bool = False):
|
||||
use_regex: bool = False,
|
||||
ignore_prefix: bool = False):
|
||||
'''
|
||||
注册插件指令。
|
||||
|
||||
@@ -60,8 +67,19 @@ class Context:
|
||||
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
|
||||
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
|
||||
@param use_regex: 是否使用正则表达式匹配指令名。
|
||||
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
|
||||
|
||||
.. Example::
|
||||
|
||||
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
|
||||
'''
|
||||
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex)
|
||||
self.plugin_command_bridge.register_command(plugin_name,
|
||||
command_name,
|
||||
description,
|
||||
priority,
|
||||
handler,
|
||||
use_regex,
|
||||
ignore_prefix)
|
||||
|
||||
def register_task(self, coro: Awaitable, task_name: str):
|
||||
'''
|
||||
@@ -80,10 +98,48 @@ class Context:
|
||||
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
|
||||
'''
|
||||
self.llms.append(RegisteredLLM(llm_name, provider, origin))
|
||||
|
||||
def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
|
||||
|
||||
def unregister_llm_tool(self, tool_name: str):
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
self.message_handler.llm_tools.remove_func(tool_name)
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
if platform_name == platform.platform_name:
|
||||
return platform
|
||||
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
async def send_message(self, unified_msg_origin: str, message: CommandResult):
|
||||
'''
|
||||
发送消息。
|
||||
|
||||
`unified_msg_origin`: 统一消息来源
|
||||
`message`: 消息内容
|
||||
'''
|
||||
l = unified_msg_origin.split(":")
|
||||
if len(l) != 3:
|
||||
raise ValueError("Invalid unified_msg_origin")
|
||||
platform_name, message_type, id = l
|
||||
platform = self.find_platform(platform_name)
|
||||
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
|
||||
|
||||
def get_current_llm_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前的 LLM Provider。
|
||||
'''
|
||||
return self.message_handler.provider
|
||||
+72
-170
@@ -1,9 +1,6 @@
|
||||
|
||||
from model.provider.provider import Provider
|
||||
import json
|
||||
import util.general_utils as gu
|
||||
|
||||
import time
|
||||
|
||||
import textwrap
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
@@ -22,16 +19,24 @@ class FuncNotFoundError(Exception):
|
||||
|
||||
|
||||
class FuncCall():
|
||||
def __init__(self, provider) -> None:
|
||||
def __init__(self, provider: Provider) -> None:
|
||||
self.func_list = []
|
||||
self.provider = provider
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None:
|
||||
if name == None or func_args == None or desc == None or func_obj == None:
|
||||
raise FuncCallJsonFormatError(
|
||||
"name, func_args, desc must be provided.")
|
||||
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {}
|
||||
}
|
||||
for param in func_args:
|
||||
@@ -39,15 +44,24 @@ class FuncCall():
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
self._func = {
|
||||
_func = {
|
||||
"name": name,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
self.func_list.append(self._func)
|
||||
|
||||
def func_dump(self, intent: int = 2) -> str:
|
||||
self.func_list.append(_func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
for i, f in enumerate(self.func_list):
|
||||
if f["name"] == name:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def func_dump(self) -> str:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
@@ -55,7 +69,7 @@ class FuncCall():
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
return json.dumps(_l, indent=intent, ensur_ascii=False)
|
||||
return json.dumps(_l, ensure_ascii=False)
|
||||
|
||||
def get_func(self) -> list:
|
||||
_l = []
|
||||
@@ -70,64 +84,39 @@ class FuncCall():
|
||||
})
|
||||
return _l
|
||||
|
||||
def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None):
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple:
|
||||
|
||||
if not provider:
|
||||
provider = self.provider
|
||||
|
||||
funccall_prompt = """
|
||||
我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。
|
||||
下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。
|
||||
- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
|
||||
```
|
||||
{
|
||||
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
|
||||
"func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。
|
||||
{
|
||||
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
|
||||
"name": str, // 函数的名字
|
||||
"args_type": {
|
||||
"arg1": str, // 函数的参数的类型
|
||||
"arg2": str,
|
||||
...
|
||||
},
|
||||
"args": {
|
||||
"arg1": any, // 函数的参数
|
||||
"arg2": any,
|
||||
...
|
||||
}
|
||||
},
|
||||
... // 可能在这个问题中会有多个函数调用
|
||||
],
|
||||
}
|
||||
```
|
||||
- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。
|
||||
- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。)
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
|
||||
提供的函数是:
|
||||
TOOLS:
|
||||
可用的函数列表:
|
||||
|
||||
"""
|
||||
{func_definition}
|
||||
|
||||
prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n"
|
||||
prompt += f"""
|
||||
用户的提问是:
|
||||
```
|
||||
{question}
|
||||
```
|
||||
"""
|
||||
LIMIT:
|
||||
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
|
||||
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
|
||||
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
|
||||
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
|
||||
|
||||
# if is_task:
|
||||
# # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**"
|
||||
# prompt += task_prompt
|
||||
EXAMPLE:
|
||||
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
|
||||
|
||||
# provider.forget()
|
||||
用户的提问是:{question}
|
||||
""")
|
||||
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = self.provider.text_chat(prompt=prompt, session_id=session_id)
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
print(res)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
gu.log("REVGPT func_call json result",
|
||||
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
|
||||
print(res)
|
||||
res = json.loads(res)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -136,112 +125,25 @@ class FuncCall():
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
|
||||
if 'res' in res and not res['res']:
|
||||
return "", False
|
||||
|
||||
invoke_func_res = ""
|
||||
|
||||
if "func_call" in res and len(res["func_call"]) > 0:
|
||||
task_list = res["func_call"]
|
||||
|
||||
invoke_func_res_list = []
|
||||
|
||||
for res in task_list:
|
||||
# 说明有函数调用
|
||||
func_name = res["name"]
|
||||
# args_type = res["args_type"]
|
||||
args = res["args"]
|
||||
# 调用函数
|
||||
# func = eval(func_name)
|
||||
func_target = None
|
||||
for func in self.func_list:
|
||||
if func["name"] == func_name:
|
||||
func_target = func["func_obj"]
|
||||
break
|
||||
if func_target == None:
|
||||
raise FuncNotFoundError(
|
||||
f"Request function {func_name} not found.")
|
||||
t_res = str(func_target(**args))
|
||||
invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n"
|
||||
invoke_func_res_list.append(invoke_func_res)
|
||||
gu.log(f"[FUNC| {func_name} invoked]",
|
||||
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
|
||||
# print(str(t_res))
|
||||
|
||||
if is_summary:
|
||||
|
||||
# 生成返回结果
|
||||
after_prompt = """
|
||||
有以下内容:"""+invoke_func_res+"""
|
||||
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
|
||||
用户的提问是:
|
||||
```""" + question + """```
|
||||
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
|
||||
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
|
||||
```json
|
||||
{
|
||||
"res": string, // 回答的内容
|
||||
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
|
||||
}
|
||||
```
|
||||
- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。"""
|
||||
|
||||
_c = 0
|
||||
while _c < 5:
|
||||
try:
|
||||
res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
|
||||
# 截取```之间的内容
|
||||
gu.log(
|
||||
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
|
||||
print(res)
|
||||
gu.log(
|
||||
"DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') +
|
||||
7: res.rfind('```')]
|
||||
gu.log("REVGPT after_func_call json result",
|
||||
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
|
||||
after_prompt_res = res
|
||||
after_prompt_res = json.loads(after_prompt_res)
|
||||
break
|
||||
except Exception as e:
|
||||
_c += 1
|
||||
if _c == 5:
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
# 如果返回的内容太长了,那么就截取一部分
|
||||
time.sleep(3)
|
||||
invoke_func_res = invoke_func_res[:int(
|
||||
len(invoke_func_res) / 2)]
|
||||
after_prompt = """
|
||||
函数返回以下内容:"""+invoke_func_res+"""
|
||||
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
|
||||
用户的提问是:
|
||||
```""" + question + """```
|
||||
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
|
||||
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
|
||||
```json
|
||||
{
|
||||
"res": string, // 回答的内容
|
||||
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
|
||||
}
|
||||
```
|
||||
- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。"""
|
||||
else:
|
||||
raise e
|
||||
|
||||
if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]:
|
||||
# 如果需要重新调用函数
|
||||
# 重新调用函数
|
||||
gu.log("REVGPT func_call_again",
|
||||
bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
|
||||
res = self.func_call(question, func_definition)
|
||||
return res, True
|
||||
|
||||
gu.log("REVGPT func callback:",
|
||||
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
|
||||
# print(after_prompt_res["res"])
|
||||
return after_prompt_res["res"], True
|
||||
else:
|
||||
return str(invoke_func_res_list), True
|
||||
else:
|
||||
# print(res["res"])
|
||||
return res["res"], False
|
||||
tool_call_result = []
|
||||
for tool in res:
|
||||
# 说明有函数调用
|
||||
func_name = tool["name"]
|
||||
args = tool["args"]
|
||||
# 调用函数
|
||||
tool_callable = None
|
||||
for func in self.func_list:
|
||||
if func["name"] == func_name:
|
||||
tool_callable = func["func_obj"]
|
||||
break
|
||||
if not tool_callable:
|
||||
raise FuncNotFoundError(
|
||||
f"Request function {func_name} not found.")
|
||||
ret = await tool_callable(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
|
||||
+19
-97
@@ -1,13 +1,11 @@
|
||||
import traceback
|
||||
import random
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from openai._exceptions import *
|
||||
from util.agent.func_call import FuncCall
|
||||
from util.websearch.config import HEADERS, USER_AGENTS
|
||||
from util.websearch.bing import Bing
|
||||
@@ -16,6 +14,8 @@ from util.websearch.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.message_event import AstrMessageEvent
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -31,24 +31,7 @@ def tidy_text(text: str) -> str:
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
# def special_fetch_zhihu(link: str) -> str:
|
||||
# '''
|
||||
# function-calling 函数, 用于获取知乎文章的内容
|
||||
# '''
|
||||
# response = requests.get(link, headers=HEADERS)
|
||||
# response.encoding = "utf-8"
|
||||
# soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# if "zhuanlan.zhihu.com" in link:
|
||||
# r = soup.find(class_="Post-RichTextContainer")
|
||||
# else:
|
||||
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
# if r is None:
|
||||
# print("debug: zhihu none")
|
||||
# raise Exception("zhihu none")
|
||||
# return tidy_text(r.text)
|
||||
|
||||
async def search_from_bing(keyword: str) -> str:
|
||||
async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str:
|
||||
'''
|
||||
tools, 从 bing 搜索引擎搜索
|
||||
'''
|
||||
@@ -84,10 +67,11 @@ async def search_from_bing(keyword: str) -> str:
|
||||
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
return ret
|
||||
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
|
||||
async def fetch_website_content(url):
|
||||
async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str):
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -97,87 +81,25 @@ async def fetch_website_content(url):
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
|
||||
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
|
||||
new_func_call.add_func("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
|
||||
if isinstance(result, Function):
|
||||
logger.debug(f"web_searcher - function-calling: {result}")
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
args = {
|
||||
'question': prompt,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
'is_task': False,
|
||||
'is_summary': False,
|
||||
}
|
||||
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
|
||||
except BaseException as e:
|
||||
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id=session_id, )
|
||||
summary_prompt = f"""
|
||||
async def summarize(context: Context, ame: AstrMessageEvent, text: str):
|
||||
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你的任务是
|
||||
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
|
||||
2. 简单地发表你对这个问题的简略看法。
|
||||
1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结;
|
||||
2. 简单地发表你对这个问题的看法。
|
||||
|
||||
# 例子
|
||||
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
|
||||
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
|
||||
|
||||
# 限制
|
||||
1. 限制在 200 字以内;
|
||||
1. 限制在 200-300 字;
|
||||
2. 请**直接输出总结**,不要输出多余的内容和提示语。
|
||||
|
||||
|
||||
# 相关材料
|
||||
{function_invoked_ret}"""
|
||||
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
|
||||
return ret
|
||||
return function_invoked_ret
|
||||
{text}"""
|
||||
|
||||
provider = context.get_current_llm_provider()
|
||||
return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id)
|
||||
@@ -1,30 +0,0 @@
|
||||
import time
|
||||
import asyncio
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
import psutil
|
||||
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
def run_monitor(global_object: Context):
|
||||
'''
|
||||
监测机器性能
|
||||
- Bot 内存使用量
|
||||
- CPU 占用率
|
||||
'''
|
||||
start_time = time.time()
|
||||
while True:
|
||||
stat = global_object.dashboard_data.stats
|
||||
# 程序占用的内存大小
|
||||
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
|
||||
stat['sys_perf'] = {
|
||||
'memory': mem,
|
||||
'cpu': psutil.cpu_percent()
|
||||
}
|
||||
stat['sys_start_time'] = start_time
|
||||
time.sleep(30)
|
||||
@@ -66,6 +66,9 @@ class MetricUploader():
|
||||
except BaseException as e:
|
||||
pass
|
||||
await asyncio.sleep(30*60)
|
||||
|
||||
def increment_platform_stat(self, platform_name: str):
|
||||
self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1
|
||||
|
||||
def clear(self):
|
||||
self.platform_stats.clear()
|
||||
|
||||
@@ -9,7 +9,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
def __init__(self):
|
||||
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
||||
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
|
||||
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
@@ -30,9 +30,11 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
def _reboot(self, delay: int = None):
|
||||
if delay: time.sleep(delay)
|
||||
def _reboot(self, delay: int = None, context = None):
|
||||
# if delay: time.sleep(delay)
|
||||
py = sys.executable
|
||||
context.running = False
|
||||
time.sleep(3)
|
||||
self.terminate_child_processes()
|
||||
py = py.replace(" ", "\\ ")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user