Compare commits

...

48 Commits

Author SHA1 Message Date
Soulter f8ab40eb39 chore: 上传管理面板package.json 2025-01-09 22:25:46 +08:00
Soulter 55e9233b93 docs: v3.4.3 changelog 2025-01-09 22:19:11 +08:00
Soulter b7277b51fd feat: 管理面板支持显示不在metadata中的配置 2025-01-09 22:03:53 +08:00
Soulter 1fa9111b2b perf: 进一步防止llm递归调用 2025-01-09 22:03:22 +08:00
Soulter 90a9e496d9 feat: 适配器类插件支持设置默认配置模板 2025-01-09 19:45:18 +08:00
Soulter 2a7dce1eb0 chore: clean code 2025-01-09 16:34:39 +08:00
Soulter 0c0841cc03 fix: websearch 在 cmd_config 中失效的问题 2025-01-09 16:33:58 +08:00
Soulter 4c9fe016bf fix: test_pipeline 2025-01-09 16:00:43 +08:00
Soulter acc90f140c chore: bump dashboard_release_url 2025-01-09 15:50:24 +08:00
Soulter 68a7bc3930 Merge pull request #232 from Soulter/feat-python-interpreter
初步实现代码执行器
2025-01-09 15:43:40 +08:00
Soulter 12ea64be0e fix: dashboard command bug 2025-01-09 15:42:04 +08:00
Soulter 7f30a673f7 fix: 修复 qq_official 无法发图 2025-01-09 15:20:54 +08:00
Soulter 897e100c32 Merge pull request #234 from Soulter/233-gemini-native-support
支持通过 Google GenAI 访问 Gemini 模型
2025-01-09 14:23:44 +08:00
Soulter 0d4ad5cb31 fix: 修复 APScheduler 任务错过后不执行的问题 2025-01-09 14:23:07 +08:00
Soulter b124bd0d0e feat: 支持通过 Google GenAI 访问 Gemini 模型 2025-01-09 14:05:48 +08:00
Soulter 6bc2f84602 Update README.md
qingcloud 在新网的账户余额不足导致原域名无法续费
2025-01-09 10:35:02 +08:00
Soulter d787a28c40 feat: 支持使用 /dashboard update 更新管理面板 2025-01-09 00:59:28 +08:00
Soulter 6b078a5731 cd: build dashboard files automatically 2025-01-09 00:57:48 +08:00
Soulter 17dddbfe21 chore: 禁用插件 2025-01-08 23:34:54 +08:00
Soulter 3ff3c9e144 perf: 检测到docker不可用时自动禁用本插件 2025-01-08 23:32:49 +08:00
Soulter f5a37d82cc Merge branch 'master' into feat-python-interpreter 2025-01-08 23:13:52 +08:00
Soulter d3d428dc9d fix: 管理面板支持禁用/启用插件 2025-01-08 23:04:03 +08:00
Soulter 8dc8c5b5dc feat: 支持对插件禁用/启用 2025-01-08 22:28:20 +08:00
Soulter e6b06f914b perf: provider 偏好项记忆 2025-01-08 20:46:34 +08:00
Soulter 4dc502a8b6 fix: 修复事件监听器会让wakestage失效的问题 2025-01-08 20:24:01 +08:00
Soulter b1d1a13d5f perf: 支持图片输入 2025-01-08 19:56:03 +08:00
Soulter 75cc4cac5a perf: 代码执行器添加部分控制指令,添加更多可用库 2025-01-08 13:26:16 +08:00
Soulter 1b7e4fbbdc perf: 退出时关闭 aiohttp client session 2025-01-08 12:43:34 +08:00
Soulter 9789e2f6c1 perf: 代码执行器请求llm不持久化历史记录 2025-01-08 02:12:35 +08:00
Soulter b8fb0bee24 feat: 初步实现代码执行器 #210 2025-01-08 02:10:27 +08:00
Soulter 419f77e245 Update README.md 2025-01-07 20:56:25 +08:00
Soulter 59b1c3473b Merge pull request #230 from Soulter/feat-dify
接入 Dify
2025-01-07 20:14:33 +08:00
Soulter 6db58ca375 perf: 优化在prompt为空的情况下不请求provider 2025-01-07 20:01:47 +08:00
Soulter 4832b342b0 Merge branch 'master' into feat-dify 2025-01-07 19:59:54 +08:00
Soulter 6cec542402 feat: 初步接入 Dify 2025-01-07 19:56:18 +08:00
Soulter 9644791783 feat: kdb 2024-12-30 18:06:09 +08:00
Soulter 5031c307d1 update: readme 2024-12-26 23:39:29 +08:00
Soulter aa49539e3e chore: fix test 2024-12-26 23:33:40 +08:00
Soulter 7b4118493b chore: fix test 2024-12-26 23:15:10 +08:00
Soulter d1cc9ba4ce chore: update test workflow 2024-12-26 23:09:11 +08:00
Soulter e0e92139d7 fix: test workflow 2024-12-26 23:07:50 +08:00
Soulter 62039392bb chore: fix test workflow 2024-12-26 23:06:30 +08:00
Soulter b72c69892e test: dashboard test 2024-12-26 22:59:17 +08:00
Soulter e6205e9aad ci: update workflow 2024-12-25 17:18:29 +08:00
Soulter b8a6fb1720 chore: update tests 2024-12-25 12:50:29 +08:00
Soulter 7c06d82f27 perf: plugin manager 重复 reload 释放资源 2024-12-25 12:50:29 +08:00
Soulter d92cb0f500 perf: 当没有provider时直接返回 2024-12-25 12:50:29 +08:00
Soulter 7fa72f2fe9 perf: adapt glm-4v-flash 2024-12-24 14:08:20 +08:00
66 changed files with 12476 additions and 156 deletions
+12 -1
View File
@@ -2,6 +2,7 @@ on:
push:
tags:
- 'v*'
workflow_dispatch:
name: Auto Release
@@ -14,6 +15,15 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Dashboard Build
run: |
cd dashboard
npm install
npm run build
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
@@ -21,4 +31,5 @@ jobs:
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
+15 -9
View File
@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
push
push:
branches:
- master
paths-ignore:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
mkdir data
mkdir data/plugins
mkdir data/config
mkdir temp
- name: Run tests
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
+36 -3
View File
@@ -14,9 +14,10 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
</a>
<a href="https://astrbot.soulter.top/">查看文档</a>
<a href="https://astrbot.lwl.lol/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
@@ -24,7 +25,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ✨ 多消息平台部署
1. QQ 群、QQ 频道、微信、Telegram。
1. QQ 群、QQ 频道、微信个人号、Telegram。
2. 支持文本转图片,Markdown 渲染。
## ✨ 多 LLM 配置
@@ -33,7 +34,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
2. 支持 OneAPI 等分发平台。
3. 支持 LLMTuner 载入微调模型。
4. 支持 Ollama 载入自部署模型。
4. 支持网页搜索(Web Search)。
4. 支持网页搜索(Web Search、自然语言待办提醒
## ✨ 管理面板
@@ -42,6 +43,38 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
3. 简单的信息统计
4. 插件管理
## ✨ 支持 Dify
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流!
## ✨ Demo
<div align='center'>
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/caadf2bd-a0ee-43d0-a95e-566d63e3e34d" height=330>
<img src="https://github.com/user-attachments/assets/b418f281-e920-49db-9fe1-d6a13ce28a84" height=350>
_✨ 管理面板 ✨_
</div>
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
+2
View File
@@ -2,6 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
@@ -10,4 +11,5 @@ __all__ = [
"personalities",
"html_renderer",
"llm_tool",
"sp"
]
+7 -1
View File
@@ -1,6 +1,7 @@
import os
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
@@ -8,5 +9,10 @@ os.makedirs("data", exist_ok=True)
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
sp = SharedPreferences() # 简单的偏好设置存储
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+47 -6
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.2"
VERSION = "3.4.3"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -50,7 +50,8 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": ""
"plugin_repo_mirror": "",
"knowledge_db": {},
}
@@ -230,10 +231,10 @@ CONFIG_METADATA_2 = {
},
},
"provider_group": {
"name": "大语言模型",
"name": "服务提供商",
"metadata": {
"provider": {
"description": "大语言模型配置",
"description": "服务提供商配置",
"type": "list",
"config_template": {
"openai": {
@@ -256,7 +257,7 @@ CONFIG_METADATA_2 = {
"model": "llama3.1-8b",
},
},
"gemini": {
"gemini(OpenAI兼容)": {
"id": "gemini_default",
"type": "openai_chat_completion",
"enable": True,
@@ -266,6 +267,16 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
},
},
"gemini(googlegenai原生)": {
"id": "gemini_default",
"type": "googlegenai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"model_config": {
"model": "gemini-1.5-flash",
},
},
"deepseek": {
"id": "deepseek_default",
"type": "openai_chat_completion",
@@ -278,7 +289,7 @@ CONFIG_METADATA_2 = {
},
"zhipu": {
"id": "zhipu_default",
"type": "openai_chat_completion",
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
@@ -295,6 +306,15 @@ CONFIG_METADATA_2 = {
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
},
"dify": {
"id": "dify_app_default",
"type": "dify",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
}
},
"items": {
@@ -366,6 +386,27 @@ CONFIG_METADATA_2 = {
"top_p": {"description": "Top P值", "type": "float"},
},
},
"dify_api_key": {
"description": "API Key",
"type": "string",
"hint": "Dify API Key。此项必填。",
},
"dify_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
},
"dify_api_type": {
"description": "Dify 应用类型",
"type": "string",
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
"options": ["chat", "agent", "workflow"],
},
"dify_workflow_output_key": {
"description": "Dify Workflow 输出变量名",
"type": "string",
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
}
},
},
"provider_settings": {
+18 -5
View File
@@ -16,6 +16,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
@@ -29,7 +30,10 @@ class AstrBotCoreLifecycle:
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config['log_level'])
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,12 +41,19 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
@@ -81,6 +92,8 @@ class AstrBotCoreLifecycle:
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for task in self.curr_tasks:
try:
+14 -1
View File
@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url(本地路径)
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +453,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
@@ -0,0 +1,60 @@
'''
Dify 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
class DifyRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider.meta().type != "dify":
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
event.set_extra("provider_request", req)
if not req.prompt:
return
try:
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
yield # rick roll
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
return
@@ -1,3 +1,6 @@
'''
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
@@ -15,10 +18,13 @@ class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
@@ -38,6 +44,9 @@ class LLMRequestSubStage(Stage):
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt:
return
# 执行请求 LLM 前事件。
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
@@ -48,7 +57,9 @@ class LLMRequestSubStage(Stage):
logger.error(traceback.format_exc())
try:
logger.debug(f"请求 LLM{req.__dict__}")
logger.debug(f"提供商请求 Payload: {req.__dict__}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
@@ -82,7 +93,7 @@ class LLMRequestSubStage(Stage):
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event):
async for _ in self.process(event, _nested=True):
yield
except BaseException as e:
@@ -1,3 +1,6 @@
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
@@ -24,7 +27,7 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
@@ -36,7 +39,7 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
+20 -7
View File
@@ -3,6 +3,7 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from .method.dify_request import DifyRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
@@ -20,11 +21,15 @@ class ProcessStage(Stage):
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
self.dify_request_sub_stage = DifyRequestSubStage()
await self.dify_request_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
@@ -36,10 +41,18 @@ class ProcessStage(Stage):
yield
else:
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
'''当没有发送操作'''
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
async for _ in self.llm_request_sub_stage.process(event):
yield
# 调用提供商相关请求
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):
yield
case _:
async for _ in self.llm_request_sub_stage.process(event):
yield
+3 -1
View File
@@ -22,4 +22,6 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
await handler.handler(event)
event.clear_result()
-1
View File
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
try:
ready_to_call = handler(event, **params)
except TypeError as e:
print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
@@ -47,6 +47,7 @@ class WakingCheckStage(Stage):
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
@@ -60,11 +61,13 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""
# 检查插件的 handler filter
+2 -1
View File
@@ -35,7 +35,8 @@ class AstrMessageEvent(abc.ABC):
self.platform_meta = platform_meta
self.session_id = session_id
self.role = "member"
self.is_wake = False
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
+3 -1
View File
@@ -2,4 +2,6 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
description: str # 平台的描述
default_config_tmpl: dict = None # 平台的默认配置模板
+13 -2
View File
@@ -7,15 +7,26 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
'''用于注册平台适配器的带参装饰器
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
@@ -1,3 +1,4 @@
import os
import time
import asyncio
import logging
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
@@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
return plain_text, image_base64, image_file_path
@@ -2,6 +2,7 @@ import sys
import time
import uuid
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
+4 -4
View File
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Dict
from .func_tool_manager import FuncCall
@@ -32,9 +32,9 @@ class ProviderRequest():
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
completion_text: str = ""
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = None
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
@@ -14,6 +14,7 @@ class FuncTool:
parameters: Dict
description: str
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
active: bool = True
'''是否激活'''
@@ -100,10 +101,29 @@ class FuncCall:
}
)
return _l
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
tools.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f["name"],
@@ -169,3 +189,10 @@ class FuncCall:
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
+29 -6
View File
@@ -5,7 +5,7 @@ from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger
from astrbot.core import logger, sp
class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
@@ -18,6 +18,11 @@ class ProviderManager():
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
@@ -29,9 +34,15 @@ class ProviderManager():
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
case "dify":
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
async def initialize(self):
@@ -39,19 +50,31 @@ class ProviderManager():
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id']:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}")
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0:
if len(self.provider_insts) > 0 and not self.curr_provider_inst:
self.curr_provider_inst = self.provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何提供商适配器。")
def get_insts(self):
return self.provider_insts
return self.provider_insts
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
@@ -0,0 +1,131 @@
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.conversation_ids = {}
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
result = ""
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(image_path, user=session_id)
if 'id' not in file_response:
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
continue
files_payload.append({
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response['id'],
})
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
logger.debug(files_payload)
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id
},
user=session_id,
files=files_payload
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
self.conversation_ids.pop(session_id, None)
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()
@@ -0,0 +1,287 @@
import traceback
import base64
import json
import aiohttp
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
class SimpleGoogleGenAIClient():
def __init__(self, api_key: str, api_base: str):
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession()
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=10) as resp:
response = await resp.json()
models = []
for model in response["models"]:
if 'generateContent' in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str="gemini-1.5-flash",
system_instruction: str="",
tools: dict=None
):
payload = {}
if system_instruction:
payload["system_instruction"] = {
"parts": {"text": system_instruction}
}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
logger.debug(f"payload: {payload}")
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
async with self.client.post(request_url, json=payload, timeout=10) as resp:
response = await resp.json()
return response
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
class ProviderGoogleGenAI(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None)
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
system_instruction = ""
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
google_genai_conversation.append({
"role": "user",
"parts": [{"text": message["content"]}]
})
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append({"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
}})
google_genai_conversation.append({
"role": "user",
"parts": parts
})
elif message["role"] == "assistant":
google_genai_conversation.append({
"role": "model",
"parts": [{"text": message["content"]}]
})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool
)
logger.debug(f"result: {result}")
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
if 'text' in candidate:
llm_response.completion_text += candidate['text']
elif 'functionCall' in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [*contexts, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
@@ -0,0 +1,73 @@
import traceback
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i+1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**model_cfgs
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
@@ -0,0 +1,25 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding():
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=api_base
)
self.model = model
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
embedding = await self.client.embeddings.create(
input=text,
model=self.model
)
return embedding.data[0].embedding
+92
View File
@@ -0,0 +1,92 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
class KnowledgeDBManager():
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = "data/knowledge_db/"
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
async def create_knowledge_db(self, name: str, config: Dict):
'''
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
'''
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]['strategy']:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks
+8
View File
@@ -0,0 +1,8 @@
from typing import List
class Store():
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass
+39
View File
@@ -0,0 +1,39 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None)
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding
)
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding,
n_results=top_n,
where=metadata_filter
)
return results['documents'][0]
+28 -3
View File
@@ -1,6 +1,7 @@
from asyncio import Queue
from typing import List, TypedDict, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -14,6 +15,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class StarCommand(TypedDict):
full_command_name: str
@@ -38,11 +40,22 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
def __init__(self,
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
for star in star_registry:
@@ -73,7 +86,7 @@ class Context:
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_str=func_obj.__module__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
@@ -94,6 +107,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -105,6 +124,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -125,7 +150,7 @@ class Context:
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
+3
View File
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
# if len(self.handler_params) == 0 and len(ls) > 1:
# # 一定程度避免 LLM 聊天时误判为指令
# return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
+2 -2
View File
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
+3
View File
@@ -32,6 +32,9 @@ class StarMetadata:
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
activated: bool = True
'''是否被激活'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+17 -8
View File
@@ -1,11 +1,12 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
class StarHandlerRegistry(List):
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -16,9 +17,18 @@ class StarHandlerRegistry(List):
super().append(handler)
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]:
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
return [handler for handler in self if handler.event_type == event_type]
if only_activated:
return [
handler
for handler in self
if handler.event_type == event_type and
star_map[handler.handler_module_path] and
star_map[handler.handler_module_path].activated
]
else:
return [handler for handler in self if handler.event_type == event_type]
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
@@ -26,8 +36,7 @@ class StarHandlerRegistry(List):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_str == module_name]
return [handler for handler in self if handler.handler_module_path == module_name]
star_handlers_registry = StarHandlerRegistry()
@@ -55,7 +64,7 @@ class StarHandlerMetadata():
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
+101 -18
View File
@@ -1,6 +1,7 @@
import inspect
import functools
import os
import sys
import traceback
import yaml
import logging
@@ -8,15 +9,14 @@ from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from astrbot.core import logger, sp
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from astrbot.core.provider.register import llm_tools
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
class PluginManager:
def __init__(
@@ -27,6 +27,7 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self # 就这样吧,不想改了
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
@@ -101,7 +102,7 @@ class PluginManager:
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend(self.config.pip_install_arg)
args.extend([self.config.pip_install_arg])
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
@@ -136,15 +137,29 @@ class PluginManager:
return metadata
def reload(self):
async def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入 Star 模块,并尝试实例化 Star 类
for plugin_module in plugin_modules:
try:
@@ -171,21 +186,24 @@ class PluginManager:
if path in star_map:
# 通过装饰器的方式注册插件
star_metadata = star_map[path]
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
metadata = star_map[path]
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path)
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == star_metadata.module_path:
func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls)
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
@@ -209,6 +227,13 @@ class PluginManager:
star_map[path] = metadata
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
@@ -225,10 +250,11 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
# reload the plugin
await self.reload()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -237,10 +263,26 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
del star_map[plugin.module_path]
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
async def update_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -250,6 +292,46 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = True
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
@@ -262,3 +344,4 @@ class PluginManager:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
+1 -2
View File
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
+78
View File
@@ -0,0 +1,78 @@
import json
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession()
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
async def chat_messages(
self,
inputs: Dict,
query: str,
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
) -> AsyncGenerator[Dict[str, Any], None]:
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
async def workflow_run(
self,
inputs: Dict,
user: str,
response_mode: str = "streaming",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
):
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
async def file_upload(
self,
file_path: str,
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+11 -1
View File
@@ -5,6 +5,7 @@ import socket
import time
import aiohttp
import base64
import zipfile
from PIL import Image
@@ -96,7 +97,9 @@ async def download_file(url: str, path: str):
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
async with session.get(url, timeout=20) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
@@ -123,3 +126,10 @@ def get_local_ip_addresses():
finally:
s.close()
return ip
async def download_dashboard():
'''下载管理面板文件'''
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
@@ -22,6 +22,9 @@ class ParameterValidationMixin:
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+33
View File
@@ -0,0 +1,33 @@
import json
import os
class SharedPreferences:
def __init__(self, path="data/shared_preferences.json"):
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
with open(self.path, "r") as f:
return json.load(f)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
def get(self, key, default=None):
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
+1 -1
View File
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+7
View File
@@ -7,6 +7,7 @@ from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -121,6 +122,12 @@ class ConfigRoute(Route):
async def _get_astrbot_config(self):
config = self.config
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
return {
"metadata": CONFIG_METADATA_2,
"config": config
+30 -8
View File
@@ -16,7 +16,9 @@ class PluginRoute(Route):
'/plugin/install-upload': ('POST', self.install_plugin_upload),
'/plugin/update': ('POST', self.update_plugin),
'/plugin/uninstall': ('POST', self.uninstall_plugin),
'/plugin/market_list': ('GET', self.get_online_plugins)
'/plugin/market_list': ('GET', self.get_online_plugins),
'/plugin/off': ('POST', self.off_plugin),
'/plugin/on': ('POST', self.on_plugin)
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -42,7 +44,8 @@ class PluginRoute(Route):
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved
"reserved": plugin.reserved,
"activated": plugin.activated
}
_plugin_resp.append(_t)
return Response().ok(_plugin_resp).__dict__
@@ -53,7 +56,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +71,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -80,7 +81,7 @@ class PluginRoute(Route):
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
self.plugin_manager.uninstall_plugin(plugin_name)
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
@@ -93,9 +94,30 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def off_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
logger.info(f"停用插件 {plugin_name}")
return Response().ok(None, "停用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/off: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def on_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
logger.info(f"启用插件 {plugin_name}")
return Response().ok(None, "启用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+6 -2
View File
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
-1
View File
@@ -18,7 +18,6 @@ class AstrBotDashboard():
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
logger.info(f"Dashboard data path: {self.data_path}")
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
+14
View File
@@ -0,0 +1,14 @@
# What's Changed
1. 修复了 reminder 插件可能不会触发回调的问题。
2. 修复了 telegram 插件不可用的问题。
3. 修复了 qq_official 无法发图的问题。
4. 修复事件监听器会让 WakeStage 失效的问题。
5. 修复 websearch 在 cmd_config 中失效的问题。
3. 支持通过 Google GenAI 访问 Gemini 模型,而不需要使用 Gemini 对 OpenAI 的兼容 API。详见文档。
4. 支持对插件禁用/启用。/plugin off/on <plugin_name>
5. 支持基于 Docker 的沙箱化代码执行器。(Beta 测试)详见文档。
6. 支持接入 Dify LLMOps 平台。详见文档。
7. 适配器类插件支持设置默认配置模板。
8. 优化了部分指令的持久化记忆。如 /tool 的禁用、/provider 的选择都将持久化保存,每次启动时不需要重新设置。
9. 优化了 glm-4v-flash 模型。其只支持一张图。
+1 -3
View File
@@ -1,4 +1,2 @@
node_modules/
.DS_Store
package-lock.json
package.json
.DS_Store
+9981
View File
File diff suppressed because it is too large Load Diff
+59
View File
@@ -0,0 +1,59 @@
{
"name": "astrbot-dashboard",
"version": "1.0.0",
"private": true,
"author": "CodedThemes",
"scripts": {
"dev": "vite --host",
"build": "vue-tsc --noEmit && vite build",
"build-stage": "vue-tsc --noEmit && vite build --base=/vue/free/stage/",
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050",
"typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
},
"dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4",
"@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0",
"axios": "^1.6.2",
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
"vite-plugin-vuetify": "1.0.2",
"vue": "3.3.4",
"vue-router": "4.2.4",
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
"xterm": "^5.3.0",
"xterm-addon-fit": "^0.8.0",
"yup": "1.2.0"
},
"devDependencies": {
"@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3",
"@types/chance": "1.1.3",
"@types/node": "20.5.7",
"@vitejs/plugin-vue": "4.3.3",
"@vue/eslint-config-prettier": "8.0.0",
"@vue/eslint-config-typescript": "11.0.3",
"@vue/tsconfig": "0.4.0",
"eslint": "8.48.0",
"eslint-plugin-vue": "9.17.0",
"prettier": "3.0.2",
"sass": "1.66.1",
"sass-loader": "13.3.2",
"typescript": "5.1.6",
"vite": "4.4.9",
"vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9"
}
}
@@ -10,7 +10,7 @@
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
<div style="width: 100%;">
<div style="width: 100%;" v-if="metadata[metadataKey].items[key]">
<v-select v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
variant="outlined" :items="metadata[metadataKey].items[key]?.options"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense :disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
@@ -46,6 +46,11 @@
</div>
</div>
<div style="width: 100%;" v-else>
<!-- metadata 中没有 key -->
<v-text-field v-model="iterable[key]" :label="key" variant="outlined" dense></v-text-field>
</div>
<div
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
+33 -1
View File
@@ -29,7 +29,9 @@ import axios from 'axios';
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
</div>
<span v-else>保留插件</span>
<!-- <span v-else>保留插件</span> -->
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
<v-btn variant="plain" v-else @click="pluginOn(extension)">启用</v-btn>
</div>
</ExtensionCard>
</v-col>
@@ -329,6 +331,36 @@ export default {
this.toast(err, "error");
});
},
pluginOn(extension) {
axios.post('/api/plugin/on',
{
name: extension.name
}).then((res) => {
if (res.data.status === "error") {
this.toast(res.data.message, "error");
return;
}
this.toast(res.data.message, "success");
this.getExtensions();
}).catch((err) => {
this.toast(err, "error");
});
},
pluginOff(extension) {
axios.post('/api/plugin/off',
{
name: extension.name
}).then((res) => {
if (res.data.status === "error") {
this.toast(res.data.message, "error");
return;
}
this.toast(res.data.message, "success");
this.getExtensions();
}).catch((err) => {
this.toast(err, "error");
});
},
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
+14 -18
View File
@@ -1,4 +1,3 @@
import os
import asyncio
import sys
@@ -8,6 +7,8 @@ import zipfile
from astrbot.dashboard import AstrBotDashBoardLifecycle
from astrbot.core import db_helper
from astrbot.core import logger, LogManager, LogBroker
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard
# add parent path to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -39,25 +40,20 @@ def check_env():
async def check_dashboard_files():
'''下载管理面板文件'''
if os.path.exists("data/dist"):
return
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
logger.info("开始下载管理面板文件...")
async with aiohttp.ClientSession() as session:
async with session.get(dashboard_release_url) as resp:
if resp.status != 200:
logger.error(f"下载管理面板文件失败: {resp.status}")
with open("data/dashboard.zip", "wb") as f:
f.write(await resp.read())
logger.info("管理面板文件下载完成。")
ok = True
if not ok:
logger.critical("下载管理面板文件失败")
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
if f.read() != VERSION:
logger.warning("检测到管理面板有更新。可以使用 /dashboard update 命令更新。")
return
# unzip
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
logger.info("开始下载管理面板文件...")
try:
await download_dashboard()
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return
logger.info("管理面板下载完成。")
if __name__ == "__main__":
+77 -16
View File
@@ -3,8 +3,9 @@ import datetime
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import personalities
from astrbot.api import personalities, sp
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.core.utils.io import download_dashboard
from typing import Union
@@ -16,6 +17,8 @@ class Main(star.Star):
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -42,6 +45,7 @@ class Main(star.Star):
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
/dashboard update: 更新管理面板
[大模型]
/provider: 查看、切换大模型提供商
@@ -87,24 +91,41 @@ class Main(star.Star):
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper: str = None):
if oper is None:
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
if oper1 is None:
plugin_list_info = "已加载的插件:\n"
for plugin in self.context.get_all_stars():
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n"
if plugin_list_info.strip() == "":
plugin_list_info = "没有加载任何插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
else:
plugin = self.context.get_registered_star(oper)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
if oper1 == "off":
# 禁用插件
if oper2 is None:
event.set_result(MessageEventResult().message("/plugin off <插件名> 禁用插件。"))
return
await self.context._star_manager.turn_off_plugin(oper2)
event.set_result(MessageEventResult().message(f"插件 {oper2} 已禁用。"))
elif oper1 == "on":
# 启用插件
if oper2 is None:
event.set_result(MessageEventResult().message("/plugin on <插件名> 启用插件。"))
return
await self.context._star_manager.turn_on_plugin(oper2)
event.set_result(MessageEventResult().message(f"插件 {oper2} 已启用。"))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
# 获取插件帮助
plugin = self.context.get_registered_star(oper1)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
@@ -167,8 +188,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})"
if self.provider == llm:
id_ = llm.meta().id
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
if self.context.get_using_provider().meta().id == id_:
ret += " (当前使用)"
ret += "\n"
@@ -178,9 +200,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1]
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_provider_inst = provider
sp.put("curr_provider", id_)
event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
@@ -289,7 +314,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
"""))
""").use_t2i(False))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
@@ -318,6 +343,13 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
name="自定义人格", prompt=ps)
message.set_result(
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard update")
async def update_dashboard(self, event: AstrMessageEvent):
yield event.plain_result("正在尝试更新管理面板...")
await download_dashboard()
yield event.plain_result("管理面板更新完成。")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
@@ -337,4 +369,33 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
event.stop_event()
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.on_llm_request()
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"
+382
View File
@@ -0,0 +1,382 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import astrbot.api.star as star
import aiodocker
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
PROMPT = """
## Task
You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image, send_file
def fabonacci(n):
if n <= 1:
return n
else:
return fabonacci(n-1) + fabonacci(n-2)
result = fabonacci(10)
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
```
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
plt.plot(x, y)
plt.savefig("output/sin_x.png")
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
send_image("output/sin_x.png") # send_image is a function to send image to user
send_text("If you need more information, please let me know :)")
```
{extra_prompt}
"""
DEFAULT_CONFIG = {
"sandbox": {
"image": "soulter/astrbot-code-interpreter-sandbox",
"docker_mirror": "", # cjie.eu.org
}
}
PATH = "data/config/python_interpreter.json"
@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1")
class Main(star.Star):
'''基于 Docker 沙箱的 Python 代码执行器'''
def __init__(self, context: star.Context) -> None:
self.context = context
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.workplace_path = os.path.join(self.curr_dir, "workplace")
self.shared_path = os.path.join(self.curr_dir, "shared")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
'''存放用户上传的文件'''
# 加载配置
if not os.path.exists(PATH):
self.config = DEFAULT_CONFIG
self._save_config()
else:
with open(PATH, "r") as f:
self.config = json.load(f)
async def initialize(self):
ok = await self.is_docker_available()
if not ok:
logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter")
async def file_upload(self, file_path: str):
'''
上传图像文件到 S3
'''
ext = os.path.splitext(file_path)[1]
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
with open(file_path, "rb") as f:
file = f.read()
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return s3_file_url
async def is_docker_available(self) -> bool:
'''Check if docker is available'''
try:
docker = aiodocker.Docker()
await docker.version()
return True
except aiodocker.exceptions.DockerError as e:
logger.error(f"检查 Docker 可用性时出现问题: {e}")
return False
async def get_image_name(self) -> str:
'''Get the image name'''
if self.config["sandbox"]["docker_mirror"]:
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
return self.config["sandbox"]["image"]
async def _save_config(self):
with open(PATH, "w") as f:
json.dump(self.config, f)
async def gen_magic_code(self) -> str:
return uuid.uuid4().hex[:8]
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
'''Download image from url to workplace_path'''
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as resp:
if resp.status != 200:
return ""
image_path = os.path.join(workplace_path, f"{filename}.jpg")
with open(image_path, 'wb') as f:
f.write(await resp.read())
return f"{filename}.jpg"
async def tidy_code(self, code: str) -> str:
'''Tidy the code'''
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code, re.DOTALL)
if match is None:
raise ValueError("The code is not in the code block.")
return match.group(1)
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''处理消息'''
for comp in event.message_obj.message:
if isinstance(comp, File):
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
logger.debug(f"User uploaded file: {comp.file}")
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
pass
@pi.command("mirror")
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
'''Docker 镜像地址'''
if not url:
yield event.plain_result(f"""当前 Docker 镜像地址: {self.config['sandbox']['docker_mirror']}
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
""")
else:
self.config["sandbox"]["docker_mirror"] = url
await self._save_config()
yield event.plain_result("设置 Docker 镜像地址成功。")
@pi.command("repull")
async def pi_repull(self, event: AstrMessageEvent):
'''重新拉取沙箱镜像'''
docker = aiodocker.Docker()
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
await docker.images.delete(image_name, force=True)
except aiodocker.exceptions.DockerError:
pass
await docker.images.pull(image_name)
yield event.plain_result("重新拉取沙箱镜像成功。")
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
'''
if not await self.is_docker_available():
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
plain_text = event.message_str
# 创建必要的工作目录和幻术码
magic_code = await self.gen_magic_code()
workplace_path = os.path.join(self.workplace_path, magic_code)
output_path = os.path.join(workplace_path, "output")
os.makedirs(workplace_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
# 图片
images = []
idx = 1
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
if image_url.startswith("http"):
image_path = await self.download_image(image_url, workplace_path, f"img_{idx}")
if image_path:
images.append(image_path)
idx += 1
# 文件
files = []
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
# cp
file_name = os.path.basename(file_path)
shutil.copy(file_path, os.path.join(workplace_path, file_name))
files.append(file_name)
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
# 整理额外输入
extra_inputs = ""
if images:
extra_inputs += f"User provided images: {images}\n"
if files:
extra_inputs += f"User provided files: {files}\n"
obs = ""
n = 5
for i in range(n):
if i > 0:
logger.info(f"Try {i+1}/{n}")
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}")
logger.debug("code interpreter llm gened code:" + llm_response.completion_text)
# 整理代码并保存
code_clean = await self.tidy_code(llm_response.completion_text)
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
f.write(code_clean)
# 启动容器
docker = aiodocker.Docker()
# 检查有没有image
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
except aiodocker.exceptions.DockerError:
# 拉取镜像
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
await docker.images.pull(image_name)
yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")
container = await docker.containers.run({
"Image": image_name,
"Cmd": ["python", "exec.py"],
"Memory": 512 * 1024 * 1024,
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{self.shared_path}:/astrbot_sandbox/shared:ro",
f"{output_path}:/astrbot_sandbox/output:rw",
f"{workplace_path}:/astrbot_sandbox:rw",
]
},
"Env": [
f"MAGIC_CODE={magic_code}"
],
"AutoRemove": True
})
logger.debug(f"Container {container.id} created.")
logs = await self.run_container(container)
logger.debug(f"Container {container.id} finished.")
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
match = re.match(pattern, log)
if match:
ok = True
if match.group(1) == "TEXT":
yield event.plain_result(match.group(2))
elif match.group(1) == "IMAGE":
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending file: {file_path}")
file_s3_url = await self.file_upload(file_path)
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_s3_url)]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log \
or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
break
else:
return
yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。")
async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]:
'''Run the container and get the output'''
try:
await container.wait(timeout=timeout)
logs = await container.log(stdout=True, stderr=True)
return logs
except asyncio.TimeoutError:
logger.warning(f"Container {container.id} timeout.")
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()
+18
View File
@@ -0,0 +1,18 @@
import os
def _get_magic_code():
'''防止注入攻击'''
return os.getenv("MAGIC_CODE")
def send_text(text: str):
print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}")
def send_image(image_path: str):
if not os.path.exists(image_path):
raise Exception(f"Image file not found: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
def send_file(file_path: str):
if not os.path.exists(file_path):
raise Exception(f"File not found: {file_path}")
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")
+27 -4
View File
@@ -31,9 +31,21 @@ class Main(star.Star):
if "datetime" in reminder:
if self.check_is_outdated(reminder):
continue
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
self.scheduler.add_job(
self._reminder_callback,
trigger='date',
args=[reminder["text"], reminder],
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
misfire_grace_time=60
)
elif "cron" in reminder:
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"]))
self.scheduler.add_job(
self._reminder_callback,
trigger='cron',
args=[reminder["text"], reminder],
misfire_grace_time=60,
**self._parse_cron_expr(reminder["cron"])
)
def check_is_outdated(self, reminder: dict):
'''Check if the reminder is outdated.'''
@@ -75,14 +87,25 @@ class Main(star.Star):
if cron_expression:
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
self.reminder_data[event.unified_msg_origin].append(d)
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d])
self.scheduler.add_job(
self._reminder_callback,
'cron',
misfire_grace_time=60,
**self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]
)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
d = { "text": text, "datetime": datetime_str }
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled)
self.scheduler.add_job(
self._reminder_callback,
'date',
args=[event.unified_msg_origin, d],
run_date=datetime_scheduled,
misfire_grace_time=60
)
reminder_time = datetime_str
await self._save_data()
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
+9
View File
@@ -22,6 +22,15 @@ class Main(star.Star):
self.sogo_search = Sogo()
self.google = Google()
async def initialize(self):
websearch = self.context.get_config()['provider_settings']['web_search']
if websearch:
self.context.activate_llm_tool("web_search")
self.context.activate_llm_tool("fetch_url")
else:
self.context.deactivate_llm_tool("web_search")
self.context.deactivate_llm_tool("fetch_url")
async def _tidy_text(self, text: str) -> str:
'''清理文本,去除空格、换行符等'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
+2 -1
View File
@@ -15,4 +15,5 @@ colorlog
aiocqhttp
pyjwt
apscheduler
docstring_parser
docstring_parser
aiodocker
+148
View File
@@ -0,0 +1,148 @@
import pytest
import os
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import LogBroker
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
test_client = app.test_client()
response = await test_client.post('/api/auth/login', json={
"username": "wrong",
"password": "password"
})
data = await response.get_json()
assert data['status'] == 'error'
response = await test_client.post('/api/auth/login', json={
"username": core_lifecycle_td.astrbot_config['dashboard']['username'],
"password": core_lifecycle_td.astrbot_config['dashboard']['password']
})
data = await response.get_json()
assert data['status'] == 'ok' and 'token' in data['data']
header['Authorization'] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/stat/get')
assert response.status_code == 401
response = await test_client.get('/api/stat/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok' and 'platform' in data['data']
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get('/api/plugin/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件市场
response = await test_client.get('/api/plugin/market_list', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件安装
response = await test_client.post('/api/plugin/install', json={
"url": "https://github.com/Soulter/astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# 插件更新
response = await test_client.post('/api/plugin/update', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件卸载
response = await test_client.post('/api/plugin/uninstall', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/update/check', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'success'
@pytest.mark.asyncio
async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post('/api/update/do', headers=header, json={
"version": "latest"
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'error' # 已经是最新版本
response = await test_client.post('/api/update/do', headers=header, json={
"version": "v3.4.0",
"reboot": False
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
assert os.path.exists("data/astrbot_release/astrbot")
+48
View File
@@ -0,0 +1,48 @@
import os
import sys
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
class _version_info():
def __init__(self, major, minor):
self.major = major
self.minor = minor
def test_check_env(monkeypatch):
version_info_correct = _version_info(3, 10)
version_info_wrong = _version_info(3, 9)
monkeypatch.setattr(sys, 'version_info', version_info_correct)
with mock.patch('os.makedirs') as mock_makedirs:
check_env()
mock_makedirs.assert_any_call("data/config", exist_ok=True)
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
with pytest.raises(SystemExit):
check_env()
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
monkeypatch.setattr(os.path, 'exists', lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
async def read(self):
return b'content'
return MockResponse()
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
async def mock_aexit(obj, exc_type, exc, tb):
return
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
+226
View File
@@ -0,0 +1,226 @@
import pytest
import logging
import os
import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"]
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456"
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
cfg['admins_id'] = ["123456"]
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
cfg['provider'] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
# await plugin_manager.reload()
asyncio.run(plugin_manager.reload())
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
'''测试唤醒'''
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)
+80
View File
@@ -0,0 +1,80 @@
import pytest
import os
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
from astrbot.core.star.context import Context
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm is not None
assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
'''测试插件安装和重载'''
os.makedirs("data/plugins", exist_ok=True)
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
# install plugin which is not exists
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
# update
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation