Compare commits

...

16 Commits

Author SHA1 Message Date
Soulter 32e2a7830a feat: Add timeout parameter to QQOfficial bot client initialization 2024-08-17 03:20:08 -04:00
Soulter 6992249e53 refactor: Update image downloading method in ProviderOpenAIOfficial 2024-08-17 15:06:13 +08:00
Soulter 107214ac53 fix: Handle errors in AstrBotBootstrap gracefully 2024-08-17 15:01:55 +08:00
Soulter 8a58772911 perf: fill the missing metric record 2024-08-17 14:58:43 +08:00
Soulter e21736b470 perf: remove message reply when rate limit occur 2024-08-17 14:54:11 +08:00
Soulter e8679f8984 Create codeql.yml 2024-08-17 14:34:02 +08:00
Soulter 970fe02027 fix: 修复QQ官方机器人API聊天时不能找到平台的问题 #189 2024-08-17 14:30:35 +08:00
Soulter 12216853c5 chore: issue and pr template 2024-08-17 11:20:36 +08:00
Soulter 33ec92258d Update config.py 2024-08-13 15:05:16 +08:00
Soulter a578edf137 fix: metrics
perf: aiocqhttp image url
2024-08-12 02:50:31 -04:00
Soulter f8949ebead perf: reboot after installing plugin 2024-08-11 23:24:37 -04:00
Soulter 141c91301f perf: Improve sleep time handling in QQOfficial and ProviderOpenAIOfficial 2024-08-11 23:24:37 -04:00
Soulter 8d95e67b5a Update README.md 2024-08-11 17:13:49 +08:00
Soulter 0633e7f25f perf: improve the effects of local function-calling 2024-08-11 03:55:31 -04:00
Soulter 266da0a9d8 fix: 修复重启时 aiocqhttp 没有正常退出导致端口占用的问题 2024-08-11 02:30:49 -04:00
Soulter 121c40f273 perf: raise error when badrequest 2024-08-11 01:49:33 -04:00
23 changed files with 395 additions and 257 deletions
+82
View File
@@ -0,0 +1,82 @@
name: '🐛 报告 Bug'
title: '[Bug]'
description: 提交报告帮助我们改进。
labels: [ 'bug' ]
body:
- type: markdown
attributes:
value: |
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
- type: textarea
attributes:
label: 发生了什么
description: 描述你遇到的异常
placeholder: >
一个清晰且具体的描述这个异常是什么。
validations:
required: true
- type: textarea
attributes:
label: 如何复现?
description: >
复现该问题的步骤
placeholder: >
如: 1. 打开 '...'
validations:
required: true
- type: textarea
attributes:
label: AstrBot 版本与部署方式
description: >
请提供您的 AstrBot 版本和部署方式。
placeholder: >
如: 3.1.8 Docker, 3.1.7 Windows启动器
validations:
required: true
- type: dropdown
attributes:
label: 操作系统
description: |
你在哪个操作系统上遇到了这个问题?
multiple: false
options:
- 'Windows'
- 'macOS'
- 'Linux'
- 'Other'
- 'Not sure'
validations:
required: true
- type: textarea
attributes:
label: 额外信息
description: >
任何额外信息,如报错日志、截图等。
placeholder: >
请提供完整的报错日志或截图。
validations:
required: true
- type: checkboxes
attributes:
label: 你愿意提交 PR 吗?
description: >
这绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
options:
- label: 是的,我愿意提交 PR!
- type: checkboxes
attributes:
label: Code of Conduct
options:
- label: >
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true
- type: markdown
attributes:
value: "感谢您填写我们的表单!"
@@ -0,0 +1,42 @@
name: '🎉 功能建议'
title: "[Feature]"
description: 提交建议帮助我们改进。
labels: [ "enhancement" ]
body:
- type: markdown
attributes:
value: |
感谢您抽出时间提出新功能建议,请准确解释您的想法。
- type: textarea
attributes:
label: 描述
description: 简短描述您的功能建议。
- type: textarea
attributes:
label: 使用场景
description: 你想要发生什么?
placeholder: >
一个清晰且具体的描述这个功能的使用场景。
- type: checkboxes
attributes:
label: 你愿意提交PR吗?
description: >
这不是必须的,但我们欢迎您的贡献。
options:
- label: 是的, 我愿意提交PR!
- type: checkboxes
attributes:
label: Code of Conduct
options:
- label: >
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true
- type: markdown
attributes:
value: "感谢您填写我们的表单!"
+10
View File
@@ -0,0 +1,10 @@
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
修复了 #XYZ
### Motivation
<!--解释为什么要改动-->
### Modifications
<!--简单解释你的改动-->
+93
View File
@@ -0,0 +1,93 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
schedule:
- cron: '21 15 * * 5'
jobs:
analyze:
name: Analyze (${{ matrix.language }})
# Runner size impacts CodeQL analysis time. To learn more, please see:
# - https://gh.io/recommended-hardware-resources-for-running-codeql
# - https://gh.io/supported-runners-and-hardware-resources
# - https://gh.io/using-larger-runners (GitHub.com only)
# Consider using larger runners or machines with greater resources for possible analysis time improvements.
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
permissions:
# required for all workflows
security-events: write
# required to fetch internal or private CodeQL packs
packages: read
# only required for workflows in private repositories
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: python
build-mode: none
# CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
# Use `c-cpp` to analyze code written in C, C++ or both
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
# queries: security-extended,security-and-quality
# If the analyze step fails for one of the languages you are analyzing with
# "We were unable to automatically build your code", modify the matrix above
# to set the build mode to "manual" for that language. Then modify this step
# to build your code.
# ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
- if: matrix.build-mode == 'manual'
shell: bash
run: |
echo 'If you are using a "manual" build mode for one or more of the' \
'languages you are analyzing, replace this with the commands to build' \
'your code, for example:'
echo ' make bootstrap'
echo ' make release'
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"
+4
View File
@@ -45,6 +45,10 @@
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
## 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
+5 -2
View File
@@ -100,10 +100,13 @@ class AstrBotBootstrap():
try:
result = await task
return result
except asyncio.CancelledError:
logger.info(f"{task.get_name()} 任务已取消。")
return
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试")
await asyncio.sleep(5)
logger.error(f"{task.get_name()} 任务发生错误。")
return
def load_llm(self):
if 'openai' in self.config_helper.cached_config and \
+6 -4
View File
@@ -138,9 +138,11 @@ class MessageHandler():
# return MessageResult("Hi~")
# check the rate limit
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
# return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
return
# remove the nick prefix
for nick in self.nicks:
if msg_plain.startswith(nick):
@@ -189,7 +191,7 @@ class MessageHandler():
try:
if web_search:
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
else:
llm_result = await provider.text_chat(
prompt=msg_plain,
+8 -6
View File
@@ -192,10 +192,11 @@ class AstrBotDashBoard():
try:
logger.info(f"正在安装插件 {repo_url}")
self.plugin_manager.install_plugin(repo_url)
logger.info(f"安装插件 {repo_url} 成功")
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
return Response(
status="success",
message="安装成功~",
message="安装成功,机器人将在 2 秒内重启。",
data=None
).__dict__
except Exception as e:
@@ -258,10 +259,11 @@ class AstrBotDashBoard():
try:
logger.info(f"正在更新插件 {plugin_name}")
self.plugin_manager.update_plugin(plugin_name)
logger.info(f"更新插件 {plugin_name} 成功")
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response(
status="success",
message="更新成功~",
message="更新成功,机器人将在 2 秒内重启。",
data=None
).__dict__
except Exception as e:
@@ -311,7 +313,7 @@ class AstrBotDashBoard():
latest = False
try:
self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start()
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
return Response(
status="success",
message="更新成功,机器人将在 3 秒内重启。",
@@ -374,7 +376,7 @@ class AstrBotDashBoard():
self.dashboard_data, self.context.config_helper.get_all())
# 重启
threading.Thread(target=self.astrbot_updator._reboot,
args=(2, ), daemon=True).start()
args=(2, self.context), daemon=True).start()
except Exception as e:
raise e
+1 -1
View File
@@ -53,7 +53,7 @@ if __name__ == "__main__":
check_env()
logger = LogManager.GetLogger(
log_name='astrbot',
log_name='astrbot',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
+2 -2
View File
@@ -117,11 +117,11 @@ class InternalCommandHandler:
success=False,
message_chain="你没有权限使用该指令",
)
context.updator._reboot(5)
context.updator._reboot(3, context)
return CommandResult(
hit=True,
success=True,
message_chain="AstrBot 将在 5s 后重启。",
message_chain="AstrBot 将在 3s 后重启。",
)
def plugin(self, message: AstrMessageEvent, context: Context):
+6 -3
View File
@@ -7,8 +7,9 @@ from type.astrbot_message import MessageType
class Platform():
def __init__(self) -> None:
pass
def __init__(self, platform_name: str, context) -> None:
self.PLATFORM_NAME = platform_name
self.context = context
@abc.abstractmethod
async def handle_msg(self, message: AstrBotMessage):
@@ -79,4 +80,6 @@ class Platform():
else:
rendered_images.append(Image.fromFileSystem(p))
return rendered_images
async def record_metrics(self):
self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME)
+6 -3
View File
@@ -18,6 +18,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AIOCQHTTP(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("aiocqhttp", context)
self.message_handler = message_handler
self.waiting = {}
self.context = context
@@ -67,7 +68,9 @@ class AIOCQHTTP(Platform):
message_str += m['data']['text'].strip()
abm.message.append(a)
if t == 'image':
a = Image(file=m['data']['file'])
file = m['data']['file'] if 'file' in m['data'] else None
url = m['data']['url'] if 'url' in m['data'] else None
a = Image(file=file, url=url)
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -99,7 +102,7 @@ class AIOCQHTTP(Platform):
return bot
async def shutdown_trigger_placeholder(self):
while True:
while self.context.running:
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
@@ -195,9 +198,9 @@ class AIOCQHTTP(Platform):
await self._reply(message, res)
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
ret = []
image_idx = []
for idx, segment in enumerate(message_chain):
+2
View File
@@ -30,6 +30,7 @@ class FakeSource:
class QQGOCQ(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("nakuru", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
@@ -191,6 +192,7 @@ class QQGOCQ(Platform):
await self._reply(source, res)
async def _reply(self, source, message_chain: List[BaseMessageComponent]):
await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
+8 -5
View File
@@ -53,7 +53,7 @@ class botClient(Client):
class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
super().__init__()
super().__init__("qqofficial", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
@@ -81,7 +81,8 @@ class QQOfficial(Platform):
)
self.client = botClient(
intents=self.intents,
bot_log=False
bot_log=False,
timeout=20,
)
self.client.set_platform(self)
@@ -178,7 +179,8 @@ class QQOfficial(Platform):
logger.error(traceback.format_exc())
self.client = botClient(
intents=self.intents,
bot_log=False
bot_log=False,
timeout=20,
)
self.client.set_platform(self)
return self.client.start(
@@ -216,7 +218,7 @@ class QQOfficial(Platform):
role = 'member'
# construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqchan", session_id, role)
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role)
message_result = await self.message_handler.handle(ame)
if not message_result:
@@ -321,6 +323,7 @@ class QQOfficial(Platform):
return await self._reply(**data)
async def _reply(self, **kwargs):
await self.record_metrics()
if 'group_openid' in kwargs or 'openid' in kwargs:
# QQ群组消息
if 'file_image' in kwargs and kwargs['file_image']:
@@ -408,4 +411,4 @@ class QQOfficial(Platform):
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
time.sleep(1)()
time.sleep(1)
+35 -9
View File
@@ -5,6 +5,7 @@ import traceback
import uuid
import shutil
import yaml
import subprocess
from util.updator.plugin_updator import PluginUpdator
from util.io import remove_dir, download_file
@@ -84,8 +85,28 @@ class PluginManager():
def update_plugin_dept(self, path):
mirror = "https://mirrors.aliyun.com/pypi/simple/"
py = sys.executable
os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet")
# os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com")
process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com",
stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
while True:
output = process.stdout.readline()
if output == '' and process.poll() is not None:
break
if output:
output = output.strip()
if output.startswith("Requirement already satisfied"):
continue
if output.startswith("Using cached"):
continue
if output.startswith("Looking in indexes"):
continue
logger.info(output)
rc = process.poll()
def install_plugin(self, repo_url: str):
ppath = self.plugin_store_path
@@ -95,10 +116,13 @@ class PluginManager():
plugin_path = self.updator.update(repo_url)
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(repo_url)
self.check_plugin_dept_update()
ok, err = self.plugin_reload()
if not ok:
raise Exception(err)
return plugin_path
# ok, err = self.plugin_reload()
# if not ok:
# raise Exception(err)
def download_from_repo_url(self, target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:]
@@ -158,7 +182,7 @@ class PluginManager():
logger.info(f"正在加载插件 {root_dir_name} ...")
# self.check_plugin_dept_update(cached_plugins, root_dir_name)
self.check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p])
@@ -227,10 +251,12 @@ class PluginManager():
# remove the temp dir
remove_dir(temp_dir)
self.check_plugin_dept_update()
ok, err = self.plugin_reload()
if not ok:
raise Exception(err)
# ok, err = self.plugin_reload()
# if not ok:
# raise Exception(err)
def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
metadata = None
+9 -8
View File
@@ -1,5 +1,5 @@
import os
import sys
import asyncio
import json
import time
import tiktoken
@@ -11,11 +11,10 @@ from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from util.io import download_image_by_url
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
@@ -154,7 +153,7 @@ class ProviderOpenAIOfficial(Provider):
将图片转换为 base64
'''
if image_url.startswith("http"):
image_url = await gu.download_image_by_url(image_url)
image_url = await download_image_by_url(image_url)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode()
@@ -359,7 +358,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1
raise e
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
@@ -369,7 +368,9 @@ class ProviderOpenAIOfficial(Provider):
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
await asyncio.sleep(1)
except NotFoundError as e:
raise e
except Exception as e:
retry += 1
if retry >= 3:
@@ -381,7 +382,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(traceback.format_exc())
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
await asyncio.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
@@ -452,7 +453,7 @@ class ProviderOpenAIOfficial(Provider):
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
await asyncio.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
+2 -2
View File
@@ -1,4 +1,4 @@
VERSION = '3.3.8'
VERSION = '3.3.9'
DEFAULT_CONFIG = {
"qqbot": {
@@ -72,4 +72,4 @@ DEFAULT_CONFIG = {
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
}
}
+1
View File
@@ -44,6 +44,7 @@ class Context:
self.ext_tasks: List[Task] = []
self.command_manager = None
self.running = True
# useless
self.reply_prefix = ""
+45 -165
View File
@@ -1,9 +1,7 @@
from model.provider.provider import Provider
import json
import util.general_utils as gu
import time
import textwrap
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
@@ -22,14 +20,11 @@ class FuncNotFoundError(Exception):
class FuncCall():
def __init__(self, provider) -> None:
def __init__(self, provider: Provider) -> None:
self.func_list = []
self.provider = provider
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None:
if name == None or func_args == None or desc == None or func_obj == None:
raise FuncCallJsonFormatError(
"name, func_args, desc must be provided.")
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
params = {
"type": "object", # hardcore here
"properties": {}
@@ -47,7 +42,7 @@ class FuncCall():
}
self.func_list.append(self._func)
def func_dump(self, intent: int = 2) -> str:
def func_dump(self) -> str:
_l = []
for f in self.func_list:
_l.append({
@@ -55,7 +50,7 @@ class FuncCall():
"parameters": f["parameters"],
"description": f["description"],
})
return json.dumps(_l, indent=intent, ensur_ascii=False)
return json.dumps(_l, ensure_ascii=False)
def get_func(self) -> list:
_l = []
@@ -70,64 +65,36 @@ class FuncCall():
})
return _l
def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None):
async def func_call(self, question: str, func_definition: str, session_id: str=None):
funccall_prompt = """
我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。
下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。
- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
```
{
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
"func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。
{
"res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
"name": str, // 函数的名字
"args_type": {
"arg1": str, // 函数的参数的类型
"arg2": str,
...
},
"args": {
"arg1": any, // 函数的参数
"arg2": any,
...
}
},
... // 可能在这个问题中会有多个函数调用
],
}
```
- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。
- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
提供的函数是:
TOOLS:
可用的函数列表:
"""
{func_definition}
prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n"
prompt += f"""
用户的提问是:
```
{question}
```
"""
LIMIT:
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
# if is_task:
# # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**"
# prompt += task_prompt
EXAMPLE:
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
# provider.forget()
用户的提问是:{question}
""")
_c = 0
while _c < 3:
try:
res = self.provider.text_chat(prompt=prompt, session_id=session_id)
res = await self.provider.text_chat(prompt, session_id)
print(res)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
print(res)
res = json.loads(res)
break
except Exception as e:
@@ -136,112 +103,25 @@ class FuncCall():
raise e
if "The message you submitted was too long" in str(e):
raise e
if 'res' in res and not res['res']:
return "", False
invoke_func_res = ""
if "func_call" in res and len(res["func_call"]) > 0:
task_list = res["func_call"]
invoke_func_res_list = []
for res in task_list:
# 说明有函数调用
func_name = res["name"]
# args_type = res["args_type"]
args = res["args"]
# 调用函数
# func = eval(func_name)
func_target = None
for func in self.func_list:
if func["name"] == func_name:
func_target = func["func_obj"]
break
if func_target == None:
raise FuncNotFoundError(
f"Request function {func_name} not found.")
t_res = str(func_target(**args))
invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n"
invoke_func_res_list.append(invoke_func_res)
gu.log(f"[FUNC| {func_name} invoked]",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(str(t_res))
if is_summary:
# 生成返回结果
after_prompt = """
有以下内容:"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
用户的提问是:
```""" + question + """```
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
```json
{
"res": string, // 回答的内容
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
}
```
- 如果func_call_again为trueres请你设为空值,否则请你填写回答的内容。"""
_c = 0
while _c < 5:
try:
res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
# 截取```之间的内容
gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
print(res)
gu.log(
"DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
if res.find('```') != -1:
res = res[res.find('```json') +
7: res.rfind('```')]
gu.log("REVGPT after_func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
after_prompt_res = res
after_prompt_res = json.loads(after_prompt_res)
break
except Exception as e:
_c += 1
if _c == 5:
raise e
if "The message you submitted was too long" in str(e):
# 如果返回的内容太长了,那么就截取一部分
time.sleep(3)
invoke_func_res = invoke_func_res[:int(
len(invoke_func_res) / 2)]
after_prompt = """
函数返回以下内容:"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
用户的提问是:
```""" + question + """```
- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
```json
{
"res": string, // 回答的内容
"func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
}
```
- 如果func_call_again为trueres请你设为空值,否则请你填写回答的内容。"""
else:
raise e
if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]:
# 如果需要重新调用函数
# 重新调用函数
gu.log("REVGPT func_call_again",
bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
res = self.func_call(question, func_definition)
return res, True
gu.log("REVGPT func callback:",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(after_prompt_res["res"])
return after_prompt_res["res"], True
else:
return str(invoke_func_res_list), True
else:
# print(res["res"])
return res["res"], False
tool_call_result = []
for tool in res:
# 说明有函数调用
func_name = tool["name"]
args = tool["args"]
# 调用函数
tool_callable = None
for func in self.func_list:
if func["name"] == func_name:
tool_callable = func["func_obj"]
break
if not tool_callable:
raise FuncNotFoundError(
f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
+21 -15
View File
@@ -1,13 +1,13 @@
import traceback
import random
import json
import asyncio
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from openai._exceptions import *
from util.agent.func_call import FuncCall
from util.websearch.config import HEADERS, USER_AGENTS
from util.websearch.bing import Bing
@@ -100,9 +100,9 @@ async def fetch_website_content(url):
return ret
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
'''
official_fc: 使用官方 function-calling
@param official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
@@ -127,9 +127,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = ""
if official_fc:
# we use official function-calling
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
try:
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
except BadRequestError as e:
# seems dont support function-calling
logger.error(f"error: {e}. Try to use local function-calling implementation")
return await web_search(prompt, provider, session_id, official_fc=False)
if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}")
logger.debug(f"function-calling: {result}")
func_obj = None
for i in new_func_call.func_list:
if i["name"] == result.name:
@@ -152,30 +157,31 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
args = {
'question': prompt,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
function_invoked_ret, has_func = await new_func_call.func_call(**args)
if not has_func:
return await provider.text_chat(prompt, session_id)
except BaseException as e:
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
logger.error(traceback.format_exc())
return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
if has_func:
await provider.forget(session_id=session_id, )
await provider.forget(session_id=session_id)
summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
2. 简单地发表你对这个问题的简略看法。
2. 简单地发表你对这个问题的看法。
# 例子
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
# 限制
1. 限制在 200 字以内
1. 限制在 200-300 字;
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
-30
View File
@@ -1,30 +0,0 @@
import time
import asyncio
import requests
import json
import sys
import psutil
from type.types import Context
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot')
def run_monitor(global_object: Context):
'''
监测机器性能
- Bot 内存使用量
- CPU 占用率
'''
start_time = time.time()
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
+3
View File
@@ -65,6 +65,9 @@ class MetricUploader():
except BaseException as e:
pass
await asyncio.sleep(30*60)
def increment_platform_stat(self, platform_name: str):
self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1
def clear(self):
self.platform_stats.clear()
+4 -2
View File
@@ -30,9 +30,11 @@ class AstrBotUpdator(RepoZipUpdator):
except psutil.NoSuchProcess:
pass
def _reboot(self, delay: int = None):
if delay: time.sleep(delay)
def _reboot(self, delay: int = None, context = None):
# if delay: time.sleep(delay)
py = sys.executable
context.running = False
time.sleep(3)
self.terminate_child_processes()
py = py.replace(" ", "\\ ")
try: