Compare commits

...

60 Commits

Author SHA1 Message Date
Soulter 8995e62e73 🐛fix: 更新v-slot类型定义以增强类型安全性 2025-02-23 20:18:00 +08:00
Soulter 316147a8db v3.4.31 2025-02-23 20:11:39 +08:00
Soulter 1fdcfc7a30 Merge pull request #587 from Raven95676/master
🐛fix: 修复aiocqhttp_platform_adapter文件相关判断逻辑
2025-02-23 19:57:50 +08:00
Soulter 8e2c633cd4 feat: 前端支持以列表展示正式版和开发版的列表 2025-02-23 19:53:55 +08:00
渡鸦95676 786b0e4a54 Update aiocqhttp_platform_adapter.py
else尾随空格
2025-02-23 18:16:39 +08:00
Raven95676 c38c1c3c35 🐛fix: 修复aiocqhttp_platform_adapter文件相关判断逻辑 2025-02-23 18:05:45 +08:00
Soulter 7d856756f4 🐛 fix: 修复 gemini 请求时出现多次不支持函数工具调用最后 429 的问题 2025-02-23 17:24:37 +08:00
Soulter f0d1d365e0 Merge branch 'refactor-hot-load' 2025-02-23 17:04:36 +08:00
Soulter 8e2d666ff8 feat: 优化关于页面和配置页面样式,添加重启按钮功能 2025-02-23 16:57:48 +08:00
Soulter 38d7be1d5f feat: 优化提示框样式并更新关于页面内容 2025-02-23 16:29:57 +08:00
Soulter 431e2fad72 feat: 支持插件禁止默认的llm调用 #579 2025-02-23 16:10:32 +08:00
Soulter b3b63be8fc Merge pull request #584 from Soulter/refactor-hot-load
🍺 refactor: 支持更大范围的热重载以及管理面板将平台和提供商配置独立化
2025-02-23 15:56:04 +08:00
Soulter 071fc7d6ef feat: 调整适配器类型显示样式并添加API Base信息 2025-02-23 15:52:30 +08:00
Soulter 2a37f7edac feat: 在聊天页面添加粘贴图片的快捷键提示 2025-02-23 15:41:34 +08:00
Soulter c656ad5e2c feat: 消息平台和服务提供商页面支持显示日志 2025-02-23 15:27:05 +08:00
Soulter da14a89490 🍺 refactor: 支持更大范围的热重载以及管理面板将平台和提供商配置独立化 2025-02-23 12:54:25 +08:00
Soulter cf22eae467 fix: save config 2025-02-22 23:20:25 +08:00
Soulter b199bddb0b feat: 适配多节点的转发消息(OneBot V11) 2025-02-22 21:07:57 +08:00
崔永亮 2188ea82de feat: 支持 AstrBot 更新使用 Github 加速地址 2025-02-22 18:17:34 +08:00
Soulter 1fa13d0177 Merge pull request #577 from Soulter/perf-autoScroll-switch
perf: 添加控制台关闭自动滚动按钮
2025-02-22 17:16:52 +08:00
崔永亮 ed508af424 perf: 添加控制台关闭自动滚动按钮 2025-02-22 17:10:53 +08:00
Fridemn 5df26864d5 Merge pull request #574 from Soulter/perf-port-check
🎈 perf: 启动时检查端口占用
2025-02-22 17:01:53 +08:00
崔永亮 837111b17e perf: 填加具体占用进程显示 2025-02-22 16:23:50 +08:00
崔永亮 a6b363b433 🎈 perf: 启动时检查端口占用 2025-02-22 16:10:46 +08:00
Soulter 2807e1e892 feat: add template of FastGPT 2025-02-22 15:43:14 +08:00
Soulter 0a2abd8214 Merge pull request #572 from Soulter/feat-dashscope
支持阿里云百炼应用智能体、工作流
2025-02-22 15:04:46 +08:00
Soulter 8beb7acdb1 feat: 支持为 dify 和 dashscope 提供商设置默认固定变量 #552 2025-02-22 14:48:18 +08:00
Soulter 466c80b94d feat: 阿里云百炼应用工作流支持自定义动态变量 #552 2025-02-22 14:32:37 +08:00
Soulter 36c0cfc9a9 feat: 支持阿里云百炼应用智能体、工作流
#552
2025-02-22 14:08:51 +08:00
Soulter 35ba1b3345 fix: gewechat verify code 2025-02-22 11:37:34 +08:00
Soulter d00821d1c7 Update README.md 2025-02-22 10:07:18 +08:00
Soulter 6c1b3f242b Merge pull request #568 from Raven95676/master
🐛 fix: 修复webchat未处理base64的问题
2025-02-22 01:07:20 +08:00
Raven95676 9f9da1e0c9 🐛 fix: 修复webchat未处理base64的问题 2025-02-21 23:39:53 +08:00
崔永亮 14fb4b70bd feat: 支持 gewechat 设置验证码 #448 2025-02-21 23:08:23 +08:00
崔永亮 b1049540a4 feat: claude 支持纯图片 2025-02-21 22:26:31 +08:00
Fridemn 5e2909df33 Merge pull request #559 from Rt39/feat-claude-api
添加对Anthropic Claude API的支持
2025-02-21 21:12:52 +08:00
崔永亮 c122dad21f feat: 添加自定义api base 2025-02-21 21:07:59 +08:00
Rt39 48ae686602 feat: add claude template 2025-02-20 23:58:10 -05:00
Rt39 bf2c3a1a81 fix: 根据Codacy Production / Codacy Static Code Analysis修改格式问题 2025-02-20 21:15:07 -05:00
Rt39 96e7a93886 feat: 添加对Claude API的支持 2025-02-20 19:59:16 -05:00
Soulter dba1ed1e19 v3.4.30 2025-02-21 01:31:36 +08:00
Soulter a24514876b fix: 修复 dify 无法使用事件钩子的问题以及出现 GeneratorExit 的问题 #533 #264 2025-02-21 01:14:13 +08:00
Soulter 466a1c1c41 🐛 fix: 修复某些情况下导致插件报错 AttributeError 的问题 #549 2025-02-21 00:38:08 +08:00
Soulter a2d5e9f40f feat: add xAI template 2025-02-20 16:34:32 +08:00
Soulter 1bbff1d161 v3.4.29 2025-02-19 20:05:33 +08:00
Soulter 0948bae99b feat: 添加代码执行器 Docker 宿主机绝对路径配置及相关功能
Co-authored-by: Bocity <haolovej@vip.qq.com>
2025-02-19 19:56:31 +08:00
Soulter 850db41596 feat: gemini source 初步支持对 API Key 进行负载均衡请求 #534 2025-02-19 19:06:37 +08:00
Soulter 7bafc87e2b 🐛 fix: 修复部分单指令失效的问题 2025-02-19 19:04:23 +08:00
Soulter 1a0de02a15 fix: 尝试修复gewechat群聊用户名出现unknown 2025-02-19 17:07:11 +08:00
Soulter 6d5d278624 fix: 尝试修复 gewechat 微信群聊情况下可能导致 unknown 的问题 #537 2025-02-19 16:42:30 +08:00
Soulter 3b4cc48fa0 👌 perf: 开启对话隔离的群聊以及私聊下,非op可以可以使用 /del 和 /reset #519 2025-02-19 16:22:42 +08:00
Soulter c908461088 Merge pull request #543 from Soulter/refactor-command-group
更换为预编译指令的方式处理指令组指令并且让事件钩子也支持 yield 的方式发送消息
2025-02-19 15:54:26 +08:00
Soulter 53d1398d30 fix: 修复子指令组不能被调用的问题 2025-02-19 15:53:01 +08:00
Soulter 782c0367d0 feat: 事件钩子支持 yield 方式发送消息 2025-02-19 15:29:10 +08:00
Soulter 4678222e9b 👌 refactor: 更换为预编译指令的方式处理指令组指令 2025-02-19 14:55:14 +08:00
Soulter f71dc3e4be 🐛 fix: reminder time zone issue 2025-02-19 00:15:14 +08:00
Soulter f6233893bd 🐛 fix: 修复 reminder rm失败 #529 2025-02-19 00:10:18 +08:00
Soulter 6427bcf130 👌perf: 查询模型列表时,可以显示当前使用的模型名称 #523 2025-02-17 22:35:45 +08:00
Soulter 8fa41b706c Merge pull request #522 from yuanxinlyx/fix-keyerror-ls-command
fix: resolve KeyError when current conversation is not in paginated list
2025-02-17 21:45:40 +08:00
YuanxinLu 4706c4438d fix: resolve KeyError when current conversation is not in paginated list 2025-02-17 03:15:59 +08:00
64 changed files with 2179 additions and 692 deletions
+4 -1
View File
@@ -17,10 +17,13 @@ addons/plugins
tests/astrbot_plugin_openai
chroma
node_modules/
dashboard/node_modules/
dashboard/dist/
.DS_Store
package-lock.json
package.json
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
-4
View File
@@ -148,10 +148,6 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
</div>
## Sponsors
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
+67 -3
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.28"
VERSION = "3.4.31"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -154,7 +154,8 @@ CONFIG_METADATA_2 = {
"id": {
"description": "ID",
"type": "string",
"hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。",
"obvious_hint": True,
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
},
"type": {
"description": "适配器类型",
@@ -409,6 +410,29 @@ CONFIG_METADATA_2 = {
"model": "gpt-4o-mini",
},
},
"xAI": {
"id": "xai",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {
"model": "grok-2-latest",
},
},
"anthropic(claude)": {
"id": "claude",
"type": "anthropic_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
},
},
"ollama": {
"id": "ollama_default",
"type": "openai_chat_completion",
@@ -504,6 +528,25 @@ CONFIG_METADATA_2 = {
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
},
"dashscope": {
"id": "dashscope",
"type": "dashscope",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
"dashscope_app_id": "",
"variables": {},
"timeout": 60,
},
"fastgpt": {
"id": "fastgpt",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
},
"whisper(API)": {
@@ -542,6 +585,26 @@ CONFIG_METADATA_2 = {
},
},
"items": {
# "variables": {
# "description": "工作流固定输入变量",
# "type": "object",
# "obvious_hint": True,
# "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
# },
# "fastgpt_app_type": {
# "description": "应用类型",
# "type": "string",
# "hint": "FastGPT 应用的应用类型。",
# "options": ["agent", "workflow", "plugin"],
# "obvious_hint": True,
# },
"dashscope_app_type": {
"description": "应用类型",
"type": "string",
"hint": "阿里云百炼应用的应用类型。",
"options": ["agent", "agent-arrange", "dialog-workflow", "task-workflow"],
"obvious_hint": True,
},
"timeout": {
"description": "超时时间",
"type": "int",
@@ -568,7 +631,8 @@ CONFIG_METADATA_2 = {
"id": {
"description": "ID",
"type": "string",
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
"obvious_hint": True,
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
},
"type": {
"description": "模型提供商类型",
+5 -9
View File
@@ -63,9 +63,6 @@ class AstrBotCoreLifecycle:
await self.provider_manager.initialize()
'''根据配置实例化各个 Provider'''
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize()
'''初始化消息事件流水线调度器'''
@@ -74,19 +71,18 @@ class AstrBotCoreLifecycle:
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.start_time = int(time.time())
self.curr_tasks: List[asyncio.Task] = []
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
def _load(self):
platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
+1 -1
View File
@@ -58,7 +58,7 @@ class LogManager:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_formatter = colorlog.ColoredFormatter(
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
fmt='%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
datefmt='%H:%M:%S',
log_colors=log_color_config
)
+26 -11
View File
@@ -30,11 +30,19 @@ from enum import Enum
from pydantic.v1 import BaseModel
class ComponentType(Enum):
Plain = "Plain"
Face = "Face"
Record = "Record"
Video = "Video"
At = "At"
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
Video = "Video" # 视频
At = "At" # At
Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
@@ -43,18 +51,12 @@ class ComponentType(Enum):
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
Image = "Image"
Reply = "Reply"
RedBag = "RedBag"
Poke = "Poke"
Forward = "Forward"
Node = "Node"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -362,6 +364,18 @@ class Node(BaseMessageComponent):
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
class Nodes(BaseMessageComponent):
type: ComponentType = "Nodes"
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_)
def toDict(self):
return {
"messages": [node.toDict() for node in self.nodes]
}
class Xml(BaseMessageComponent):
@@ -451,6 +465,7 @@ ComponentTypes = {
"poke": Poke,
"forward": Forward,
"node": Node,
"nodes": Nodes,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
@@ -28,4 +28,3 @@ class ContentSafetyCheckStage(Stage):
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
event.continue_event()
@@ -1,66 +0,0 @@
'''
Dify 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
class DifyRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
return
if provider.meta().type != "dify":
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
event.set_extra("provider_request", req)
if not req.prompt:
return
req.session_id = event.unified_msg_origin
try:
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
yield # rick roll
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
return
@@ -13,6 +13,7 @@ from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
class LLMRequestSubStage(Stage):
@@ -54,7 +55,7 @@ class LLMRequestSubStage(Stage):
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = conversation_id
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -64,11 +65,12 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件。
# 执行请求 LLM 前事件钩子
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
@@ -82,10 +84,11 @@ class LLMRequestSubStage(Stage):
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
# 执行 LLM 响应后的事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
@@ -154,6 +157,6 @@ class LLMRequestSubStage(Stage):
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.session_id,
req.conversation.cid,
history=contexts_to_save
)
@@ -28,10 +28,8 @@ class StarRequestSubStage(Stage):
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
logger.debug(f"执行插件 handler {handler.handler_full_name}")
logger.debug(f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
+6 -13
View File
@@ -3,7 +3,6 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from .method.dify_request import DifyRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
@@ -21,9 +20,6 @@ class ProcessStage(Stage):
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
self.dify_request_sub_stage = DifyRequestSubStage()
await self.dify_request_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
@@ -45,22 +41,19 @@ class ProcessStage(Stage):
else:
yield
# 调用提供商相关请求
# 调用 LLM 相关请求
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
if not event._has_send_oper and event.is_at_or_wake_command and not event.call_llm:
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):
yield
case _:
async for _ in self.llm_request_sub_stage.process(event):
yield
async for _ in self.llm_request_sub_stage.process(event):
yield
@@ -73,8 +73,6 @@ class RateLimitStage(Stage):
timestamps.append(now)
return event.continue_event()
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
"""
移除时间窗口外的时间戳。
+7 -2
View File
@@ -1,6 +1,7 @@
import random
import asyncio
import math
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
@@ -9,6 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
@@ -88,7 +90,10 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
try:
logger.debug(f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
event.clear_result()
+11 -3
View File
@@ -10,6 +10,7 @@ from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
@register_stage
class ResultDecorateStage(Stage):
@@ -47,7 +48,7 @@ class ResultDecorateStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
if result is None or not result.chain:
return
# 回复时检查内容安全
@@ -59,10 +60,17 @@ class ResultDecorateStage(Stage):
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
await handler.handler(event)
try:
logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。")
except BaseException:
logger.error(traceback.format_exc())
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()
if result is None:
+6 -3
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import abc
import inspect
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
@@ -36,16 +37,18 @@ class Stage(abc.ABC):
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
**params
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, **params)
ready_to_call = handler(event, *args, **kwargs)
except TypeError as e:
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
logger.debug(str(e))
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call:
+7 -26
View File
@@ -5,6 +5,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@@ -76,34 +77,17 @@ class WakingCheckStage(Stage):
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需满足 AND 逻辑关系
passed = True
child_command_handler_md = None
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
if len(handler.event_filters) == 0:
# 不可能有这种情况, 也不允许有这种情况
continue
if 'sub_command' in handler.extras_configs:
# 如果是子指令
continue
for filter in handler.event_filters:
try:
if isinstance(filter, CommandGroupFilter):
"""如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata"""
ok, child_command_handler_md = filter.filter(
event, self.ctx.astrbot_config
)
if not ok:
passed = False
else:
handler = child_command_handler_md # handler 覆盖
break
elif isinstance(filter, PermissionTypeFilter):
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
else:
@@ -111,19 +95,15 @@ class WakingCheckStage(Stage):
passed = False
break
except Exception as e:
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
# yield
await event.send(
MessageEventResult().message(
f"插件 {handler.handler_full_name} 报错:{e}"
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
event.stop_event()
passed = False
break
if passed:
if permission_not_pass:
if self.no_permission_reply:
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
@@ -138,6 +118,7 @@ class WakingCheckStage(Stage):
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
event.clear_extra()
event.set_extra("activated_handlers", activated_handlers)
+11 -2
View File
@@ -57,7 +57,8 @@ class AstrMessageEvent(abc.ABC):
self._has_send_oper = False
'''是否有过至少一次发送操作'''
self.call_llm = False
'''是否在此消息事件中禁止默认的 LLM 请求'''
# back_compability
self.platform = platform_meta
@@ -242,7 +243,15 @@ class AstrMessageEvent(abc.ABC):
'''
if self._result is None:
return False # 默认是继续传播
return self._result.is_stopped()
return self._result.is_stopped()
def should_call_llm(self, call_llm: bool):
'''
是否在此消息事件中禁止默认的 LLM 请求。
只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。
'''
self.call_llm = call_llm
def get_result(self) -> MessageEventResult:
'''
+92 -30
View File
@@ -1,3 +1,5 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .platform import Platform
from typing import List
@@ -11,43 +13,103 @@ class PlatformManager():
self.platform_insts: List[Platform] = []
'''加载的 Platform 的实例'''
self._inst_map = {}
self.platforms_config = config['platform']
self.settings = config['platform_settings']
self.event_queue = event_queue
try:
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}")
async def initialize(self):
'''初始化所有平台适配器'''
for platform in self.platforms_config:
if not platform['enable']:
continue
if platform['type'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = platform_cls_map[platform['type']]
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
await self.load_platform(platform)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")))
async def load_platform(self, platform_config: dict):
'''实例化一个平台'''
if not platform_config['enable']:
return
logger.info(f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...")
# 动态导入
try:
match platform_config['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}")
if platform_config['type'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。")
return
cls_type = platform_cls_map[platform_config['type']]
inst = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config['id']] = inst
self.platform_insts.append(inst)
asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform")))
async def _task_wrapper(self, task: asyncio.Task):
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def reload(self, platform_config: dict):
# 还未实现完成,不要调用此方法
if platform_config['id'] in self._inst_map:
# 正在运行
if getattr(self._inst_map[platform_config['id']], 'terminate', None):
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
await self._inst_map[platform_config['id']].terminate()
logger.info(f"{platform_config['id']} 平台适配器已终止。")
del self._inst_map[platform_config['id']]
self.platform_insts.remove(self._inst_map[platform_config['id']])
else:
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
# 再启动新的实例
await self.load_platform(platform_config)
else:
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
config_ids = [platform['id'] for platform in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
if getattr(self._inst_map[key], 'terminate', None):
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
await self._inst_map[key].terminate()
logger.info(f"{key} 平台适配器已终止。")
del self._inst_map[key]
self.platform_insts.remove(self._inst_map[key])
else:
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
# 再启动新的实例
await self.load_platform(platform_config)
def get_insts(self):
return self.platform_insts
+6
View File
@@ -20,6 +20,12 @@ class Platform(abc.ABC):
'''
raise NotImplementedError
async def terminate(self):
'''
终止一个平台的运行实例。
'''
pass
@abc.abstractmethod
def meta(self) -> PlatformMetadata:
'''
@@ -1,7 +1,7 @@
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -45,15 +45,25 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Music)):
if isinstance(seg, (Node, Nodes)):
# 转发消息不能和普通消息混在一起发送
send_one_by_one = True
break
if send_one_by_one:
for seg in message.chain:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
if isinstance(seg, Nodes):
# 带有多个节点的合并转发消息
payload = seg.toDict()
if self.get_group_id():
payload['group_id'] = self.get_group_id()
await self.bot.call_action('send_group_forward_msg', **payload)
else:
payload['user_id'] = self.get_sender_id()
await self.bot.call_action('send_private_forward_msg', **payload)
else:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
else:
await self.bot.send(self.message_obj.raw_message, ret)
@@ -16,7 +16,7 @@ from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
@register_platform_adapter("aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
@@ -32,6 +32,8 @@ class AiocqhttpAdapter(Platform):
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
)
self.stop = False
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
@@ -146,8 +148,11 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
elif t == 'file':
if m['data']['url'] and m['data']['url'].startswith("http"):
if m['data'].get('url') and m['data'].get('url').startswith("http"):
# Lagrange
logger.info("guessing lagrange")
@@ -159,6 +164,8 @@ class AiocqhttpAdapter(Platform):
"file": path,
"name": file_name
}
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
else:
try:
@@ -173,13 +180,17 @@ class AiocqhttpAdapter(Platform):
"file": ret['file'],
"name": ret['file_name']
}
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
else:
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = event
@@ -230,11 +241,15 @@ class AiocqhttpAdapter(Platform):
return bot
async def terminate(self):
self.stop = True
await asyncio.sleep(1)
def meta(self) -> PlatformMetadata:
return self.metadata
async def shutdown_trigger_placeholder(self):
while not self._event_queue.closed:
while not self._event_queue.closed and not self.stop:
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
@@ -248,4 +263,4 @@ class AiocqhttpAdapter(Platform):
bot=self.bot
)
self.commit_event(message_event)
self.commit_event(message_event)
@@ -5,6 +5,7 @@ import quart
import base64
import datetime
import re
import os
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api import logger, sp
@@ -51,6 +52,10 @@ class SimpleGewechatClient():
self.event_queue = event_queue
self.multimedia_downloader = None
self.userrealnames = {}
self.stop = False
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
@@ -118,10 +123,25 @@ class SimpleGewechatClient():
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') \
.replace('在群聊中发了一张图片', '') # 真实昵称
# 解析用户真实名字
user_real_name = "unknown"
if abm.group_id:
if abm.group_id not in self.userrealnames or user_id not in self.userrealnames[abm.group_id]:
# 获取群成员列表,并且缓存
if abm.group_id not in self.userrealnames:
self.userrealnames[abm.group_id] = {}
member_list = await self.get_chatroom_member_list(abm.group_id)
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
if member_list and 'memberList' in member_list:
for member in member_list['memberList']:
self.userrealnames[abm.group_id][member['wxid']] = member['nickName']
if user_id in self.userrealnames[abm.group_id]:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0]
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
@@ -213,7 +233,7 @@ class SimpleGewechatClient():
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed:
while not self.event_queue.closed and not self.stop:
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
@@ -285,8 +305,25 @@ class SimpleGewechatClient():
"uuid": qr_uuid,
"appId": appid
})
verify_flag = False
while retry_cnt > 0:
retry_cnt -= 1
# 需要验证码
if verify_flag or os.path.exists("data/temp/gewe_code"):
with open("data/temp/gewe_code", "r") as f:
code = f.read().strip()
if not code:
logger.warning("未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
await asyncio.sleep(5)
continue
payload['captchCode'] = code
logger.info(f"使用验证码: {code}")
try:
os.remove("data/temp/gewe_code")
except:
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
@@ -295,17 +332,26 @@ class SimpleGewechatClient():
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
status = json_blob['data']['status']
nickname = json_blob['data'].get('nickName', '')
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
ret = json_blob['ret']
msg = ''
if json_blob['data'] and 'msg' in json_blob['data']:
msg = json_blob['data']['msg']
if ret == 500 and '安全验证码' in msg:
logger.warning("此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
verify_flag = True
else:
logger.warning(f"未知状态: {status}")
status = json_blob['data']['status']
nickname = json_blob['data'].get('nickName', '')
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
@@ -313,6 +359,23 @@ class SimpleGewechatClient():
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
'''API'''
async def get_chatroom_member_list(self, chatroom_wxid: str):
payload = {
"appId": self.appid,
"chatroomId": chatroom_wxid
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
return json_blob['data']
async def post_text(self, to_wxid, content: str, ats: str = ""):
payload = {
"appId": self.appid,
@@ -47,6 +47,10 @@ class GewechatPlatformAdapter(Platform):
"基于 gewechat 的 Wechat 适配器",
)
async def terminate(self):
self.client.stop = True
await asyncio.sleep(1)
@override
def run(self):
self.client = SimpleGewechatClient(
@@ -1,5 +1,6 @@
import os
import uuid
import base64
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
@@ -31,6 +32,11 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file.startswith("base64://"):
base64_str = comp.file[9:]
image_data = base64.b64decode(base64_str)
with open(path, "wb") as f:
f.write(image_data)
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
@@ -102,6 +102,29 @@ class FuncCall:
)
return _l
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", [])
}
}
tools.append(tool)
return tools
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
+148 -126
View File
@@ -1,11 +1,9 @@
import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger, sp
@@ -16,6 +14,14 @@ class ProviderManager():
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
self.astrbot_config = config
self.selected_provider_id = sp.get("curr_provider")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
self.provider_enabled = self.provider_settings.get("enable", False)
self.stt_enabled = self.provider_stt_settings.get("enable", False)
self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
@@ -75,14 +81,15 @@ class ProviderManager():
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.inst_map = {}
'''Provider 实例映射. key: provider_id, value: Provider 实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
@@ -90,7 +97,6 @@ class ProviderManager():
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
# kdb(experimental)
@@ -99,141 +105,157 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
self.loaded_ids[provider_cfg['id']] = True
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
await self.load_provider(provider_config)
provider_metadata = provider_cls_map[provider_config['type']]
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
async def load_provider(self, provider_config: dict):
if not provider_config['enable']:
return
logger.info(f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ...")
# 动态导入
try:
match provider_config['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "anthropic_chat_completion":
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
return
except Exception as e:
logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因")
return
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
return
provider_metadata = provider_cls_map[provider_config['type']]
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
if not self.curr_stt_provider_inst and self.stt_enabled:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
if not self.curr_tts_provider_inst and self.tts_enabled:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if self.selected_provider_id == provider_config['id'] and self.provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
if not self.curr_provider_inst and self.provider_enabled:
self.curr_provider_inst = inst
self.inst_map[provider_config['id']] = inst
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config['id'])
if provider_config['enable']:
await self.load_provider(provider_config)
# 和配置文件保持同步
config_ids = [provider['id'] for provider in self.providers_config]
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
def get_insts(self):
return self.provider_insts
async def terminate_provider(self, provider_id: str):
if provider_id in self.inst_map:
logger.info(f"终止 {provider_id} 提供商适配器 ...")
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
if getattr(self.inst_map[provider_id], 'terminate', None):
await self.inst_map[provider_id].terminate()
logger.info(f"{provider_id} 提供商适配器已终止。")
del self.inst_map[provider_id]
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
@@ -0,0 +1,189 @@
from typing import List
from mimetypes import guess_type
from anthropic import AsyncAnthropic
from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器")
class ProviderAnthropic(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True,
default_persona: Personality = None
) -> None:
# Skip OpenAI's __init__ and call Provider's __init__ directly
Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = AsyncAnthropic(
api_key=self.chosen_api_key,
timeout=self.timeout,
base_url=self.base_url
)
self.set_model(provider_config['model_config']['model'])
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_anthropic_style()
if tool_list:
payloads['tools'] = tool_list
completion = await self.client.messages.create(
**payloads,
stream=False
)
assert isinstance(completion, Message)
logger.debug(f"completion: {completion}")
if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。")
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
content = completion.content[-1]
llm_response = LLMResponse("assistant")
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
# tools call (function calling)
args_ls = []
func_name_ls = []
func_name_ls.append(content.name)
args_ls.append(content.input)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
if not prompt:
prompt = "<image>"
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
payloads = {
"messages": context_query,
**model_config
}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads['system'] = system_prompt
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
response = await self.client.messages.create(
messages=context_query,
**model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.raw_completion = response
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
else:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''组装上下文,支持文本和图片'''
if not image_urls:
return {"role": "user", "content": text}
content = []
content.append({"type": "text", "text": text})
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1] if "base64," in image_data else image_data
}
})
return {"role": "user", "content": content}
@@ -0,0 +1,128 @@
import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
contexts_no_img = await self._remove_image_from_context(contexts)
context_query = [*contexts_no_img, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id=self.app_id,
api_key=self.api_key,
messages=context_query,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id=self.app_id,
promtp=prompt,
api_key=self.api_key,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
)
output_text = response.output.get("text", "")
return LLMResponse(role="assistant", completion_text=output_text)
async def forget(self, session_id):
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
async def terminate(self):
pass
+55 -46
View File
@@ -32,6 +32,7 @@ class ProviderDify(Provider):
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query")
self.variables: dict = provider_config.get("variables", {})
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
self.timeout = provider_config.get("timeout", 120)
@@ -72,55 +73,63 @@ class ProviderDify(Provider):
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
try:
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
except Exception as e:
logger.error(f"Dify 请求失败:{str(e)}")
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}")
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload,
timeout=self.timeout
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
+51 -30
View File
@@ -1,5 +1,6 @@
import base64
import aiohttp
import random
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
@@ -138,8 +139,7 @@ class ProviderGoogleGenAI(Provider):
"role": "model",
"parts": [{"text": message["content"]}]
})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
@@ -194,33 +194,50 @@ class ProviderGoogleGenAI(Provider):
**model_config
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
elif "Function calling is not enabled" in str(e):
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
raise e
retry = 10
keys = self.api_keys.copy()
chosen_key = random.choice(keys)
for i in range(retry):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
elif "Function calling is not enabled" in str(e):
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
break
elif "429" in str(e) or "API key not valid" in str(e):
keys.remove(chosen_key)
if len(keys) > 0:
chosen_key = random.choice(keys)
logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...")
continue
else:
logger.error(f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...")
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
else:
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
raise e
return llm_response
@@ -265,4 +282,8 @@ class ProviderGoogleGenAI(Provider):
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
return ''
async def terminate(self):
await self.client.client.close()
logger.info("Google GenAI 适配器已终止。")
@@ -48,8 +48,10 @@ class ProviderOpenAIOfficial(Provider):
base_url=provider_config.get("api_base", None),
timeout=self.timeout
)
self.set_model(provider_config['model_config']['model'])
model_config = provider_config.get("model_config", {})
model = model_config.get("model", "unknown")
self.set_model(model)
async def get_models(self):
try:
+57 -19
View File
@@ -1,20 +1,20 @@
import re
import inspect
from typing import List
from typing import List, Any, Type, Dict
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
from .custom_filter import CustomFilter
from ..star_handler import StarHandlerMetadata
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter, ParameterValidationMixin):
class CommandFilter(HandlerFilter):
'''标准指令过滤器'''
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None):
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None, parent_command_names: List[str] = [""]):
self.command_name = command_name
self.alias = alias if alias else set()
self.parent_command_names = parent_command_names
if handler_md:
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
@@ -26,6 +26,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
result += f"{k}({v.__name__}),"
else:
result += f"{k}({type(v).__name__})={v},"
result = result.rstrip(",")
return result
def init_handler_md(self, handle_md: StarHandlerMetadata):
@@ -54,6 +55,39 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
if not custom_filter.filter(event, cfg):
return False
return True
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
'''将参数列表 params 根据 param_type 转换为参数字典。
'''
result = {}
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
if i >= len(params):
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
# 是类型
raise ValueError(f"必要参数缺失。该指令完整参数: {self.print_types()}")
else:
# 是默认值
result[param_name] = param_type_or_default_val
else:
# 尝试强制转换
try:
if param_type_or_default_val is None:
if params[i].isdigit():
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(f"参数 {param_name} 类型错误。完整参数: {self.print_types()}")
return result
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
@@ -61,27 +95,31 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
if not self.custom_filter_ok(event, cfg):
return False
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0] and ls[0] not in self.alias:
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full):].strip()
ok = True
break
if not ok:
return False
# if len(self.handler_params) == 0 and len(ls) > 1:
# # 一定程度避免 LLM 聊天时误判为指令
# return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 分割为列表
ls = message_str.split(" ")
# 去除空字符串
ls = [param for param in ls if param]
params = {}
try:
params = self.validate_and_convert_params(ls, self.handler_params)
except ValueError as e:
raise e
+33 -37
View File
@@ -11,11 +11,12 @@ from ..star_handler import StarHandlerMetadata
# 指令组受到 wake_prefix 的制约。
class CommandGroupFilter(HandlerFilter):
def __init__(self, group_name: str, alias: set = None):
def __init__(self, group_name: str, alias: set = None, parent_group: CommandGroupFilter = None):
self.group_name = group_name
self.alias = alias if alias else set()
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
self.sub_command_filters.append(sub_command_filter)
@@ -23,6 +24,24 @@ class CommandGroupFilter(HandlerFilter):
def add_custom_filter(self, custom_filter: CustomFilter):
self.custom_filter_list.append(custom_filter)
def get_complete_command_names(self) -> List[str]:
'''遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。'''
parent_cmd_names = self.parent_group.get_complete_command_names() if self.parent_group else []
if not parent_cmd_names:
# 根节点
return [self.group_name] + list(self.alias)
result = []
candidates = [self.group_name] + list(self.alias)
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
return result
# 以树的形式打印出来
def print_cmd_tree(self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
@@ -43,6 +62,10 @@ class CommandGroupFilter(HandlerFilter):
result += f" ({cmd_th})"
else:
result += " (无参数指令)"
if sub_filter.handler_md and sub_filter.handler_md.desc:
result += f": {sub_filter.handler_md.desc}"
result += "\n"
elif isinstance(sub_filter, CommandGroupFilter):
custom_filter_pass = True
@@ -61,46 +84,19 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False, None
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name and ls[0] not in self.alias:
return False, None
# 改写 message_str
ls = ls[1:]
# event.message_str = " ".join(ls)
# event.message_str = event.message_str.strip()
parsing_command = " ".join(ls)
parsing_command = parsing_command.strip()
event.set_extra("parsing_command", parsing_command)
return False
# 判断当前指令组的自定义过滤器
if not self.custom_filter_ok(event, cfg):
return False, None
return False
if parsing_command == "":
# 当前还是指令组
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
child_command_handler_md = None
for sub_filter in self.sub_command_filters:
if isinstance(sub_filter, CommandFilter):
if sub_filter.filter(event, cfg):
child_command_handler_md = sub_filter.get_handler_md()
return True, child_command_handler_md
elif isinstance(sub_filter, CommandGroupFilter):
ok, handler = sub_filter.filter(event, cfg)
if ok:
child_command_handler_md = handler
return True, child_command_handler_md
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False
+9 -20
View File
@@ -54,14 +54,12 @@ def get_handler_or_create(
def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs):
'''注册一个 Command.
'''
# print("command: ", command_name, args, kwargs)
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
new_command = CommandFilter(sub_command, alias, None)
parent_command_names = command_name.parent_group.get_complete_command_names()
new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names)
command_name.parent_group.add_sub_command_filter(new_command)
else:
# 裸指令
@@ -73,10 +71,7 @@ def register_command(command_name: str = None, sub_command: str = None, alias: s
kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
handler_md.event_filters.append(new_command)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
@@ -142,25 +137,19 @@ def register_command_group(
):
'''注册一个 CommandGroup
'''
# print("commandgroup: ", command_group_name,args, kwargs)
new_group = None
add_to_event_filters = False
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
new_group = CommandGroupFilter(sub_command, alias)
new_group = CommandGroupFilter(sub_command, alias, parent_group=command_group_name.parent_group)
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
new_group = CommandGroupFilter(command_group_name, alias)
add_to_event_filters = True
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
@@ -168,8 +157,8 @@ def register_command_group(
class RegisteringCommandable():
'''用于指令组级联注册'''
group = register_command_group
command = register_command
group: CommandGroupFilter = register_command_group
command: CommandFilter = register_command
custom_filter = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
+8 -4
View File
@@ -32,9 +32,6 @@ class AstrBotUpdator(RepoZipUpdator):
pass
def _reboot(self, delay: int = 3):
if os.environ.get('TEST_MODE', 'off') == 'on':
logger.info("测试模式下不会重启。")
return
py = sys.executable
time.sleep(delay)
self.terminate_child_processes()
@@ -47,8 +44,11 @@ class AstrBotUpdator(RepoZipUpdator):
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
async def get_releases(self) -> list:
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
async def update(self, reboot = False, latest = True, version = None):
async def update(self, reboot = False, latest = True, version = None, proxy = ""):
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
file_url = None
@@ -70,6 +70,10 @@ class AstrBotUpdator(RepoZipUpdator):
raise Exception("commit hash 长度不正确,应为 40")
logger.info(f"正在尝试更新到指定 commit: {version}")
file_url = "https://github.com/Soulter/AstrBot/archive/" + version + ".zip"
if proxy:
proxy = proxy.removesuffix("/")
file_url = f"{proxy}/{file_url}"
try:
await download_file(file_url, "temp.zip")
@@ -1,36 +0,0 @@
import inspect
from typing import List, Dict, Any, Type
class ParameterValidationMixin:
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
'''将参数列表 params 根据 param_type 转换为参数字典。
'''
result = {}
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
if i >= len(params):
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
# 是类型
raise ValueError(f"参数 {param_name} 缺失")
else:
# 是默认值
result[param_name] = param_type_or_default_val
else:
# 尝试强制转换
try:
if param_type_or_default_val is None:
if params[i].isdigit():
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(f"参数 {param_name} 类型错误")
return result
+13 -20
View File
@@ -34,10 +34,19 @@ class RepoZipUpdator():
result = await response.json()
if not result:
return []
if latest:
ret = self.github_api_release_parser([result[0]])
else:
ret = self.github_api_release_parser(result)
# if latest:
# ret = self.github_api_release_parser([result[0]])
# else:
# ret = self.github_api_release_parser(result)
ret = []
for release in result:
ret.append({
"version": release['name'],
"published_at": release['published_at'],
"body": release['body'],
"tag_name": release['tag_name'],
"zipball_url": release['zipball_url']
})
except BaseException:
raise Exception("解析版本信息失败")
return ret
@@ -49,17 +58,10 @@ class RepoZipUpdator():
'''
ret = []
for release in releases:
version = release['name']
commit_hash = ''
# 规范是: v3.0.7.xxxxxx,其中xxxxxx为 commit hash
_t = version.split(".")
if len(_t) == 4:
commit_hash = _t[3]
ret.append({
"version": release['name'],
"published_at": release['published_at'],
"body": release['body'],
"commit_hash": commit_hash,
"tag_name": release['tag_name'],
"zipball_url": release['zipball_url']
})
@@ -114,15 +116,6 @@ class RepoZipUpdator():
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
# 镜像站点
# match self.repo_mirror:
# case 'https://github-mirror.us.kg/':
# release_url = self.repo_mirror + release_url
# case "https://ghp.ci/":
# release_url = self.repo_mirror + release_url
# case _:
# pass
if proxy:
release_url = f"{proxy}/{release_url}"
+104 -2
View File
@@ -1,4 +1,5 @@
import typing
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
@@ -61,7 +62,7 @@ def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.Li
group_meta = group.get("metadata")
if not group_meta:
continue
logger.info(f"验证配置: 组 {key} ...")
# logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
@@ -77,6 +78,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
else:
errors, post_config = validate_config(post_config, config.schema, is_core)
except BaseException as e:
logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
@@ -90,6 +92,14 @@ class ConfigRoute(Route):
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
'/config/platform/new': ('POST', self.post_new_platform),
'/config/platform/update': ('POST', self.post_update_platform),
'/config/platform/delete': ('POST', self.post_delete_platform),
'/config/provider/new': ('POST', self.post_new_provider),
'/config/provider/update': ('POST', self.post_update_provider),
'/config/provider/delete': ('POST', self.post_delete_provider)
}
self.register_routes()
@@ -118,7 +128,99 @@ class ConfigRoute(Route):
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
async def post_new_platform(self):
new_platform_config = await request.json
self.config['platform'].append(new_platform_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.load_platform(new_platform_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增平台配置成功~").__dict__
async def post_new_provider(self):
new_provider_config = await request.json
self.config['provider'].append(new_provider_config)
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.load_provider(new_provider_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "新增服务提供商配置成功~").__dict__
async def post_update_platform(self):
update_platform_config = await request.json
platform_id = update_platform_config.get("id", None)
new_config = update_platform_config.get("config", None)
if not platform_id or not new_config:
return Response().error("参数错误").__dict__
for i, platform in enumerate(self.config['platform']):
if platform['id'] == platform_id:
self.config['platform'][i] = new_config
break
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新平台配置成功~").__dict__
async def post_update_provider(self):
update_provider_config = await request.json
provider_id = update_provider_config.get("id", None)
new_config = update_provider_config.get("config", None)
if not provider_id or not new_config:
return Response().error("参数错误").__dict__
for i, provider in enumerate(self.config['provider']):
if provider['id'] == provider_id:
self.config['provider'][i] = new_config
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新成功,已经实时生效~").__dict__
async def post_delete_platform(self):
platform_id = await request.json
platform_id = platform_id.get("id")
for i, platform in enumerate(self.config['platform']):
if platform['id'] == platform_id:
del self.config['platform'][i]
break
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除平台配置成功~").__dict__
async def post_delete_provider(self):
provider_id = await request.json
provider_id = provider_id.get("id")
for i, provider in enumerate(self.config['provider']):
if provider['id'] == provider_id:
del self.config['provider'][i]
break
else:
return Response().error("未找到对应服务提供商").__dict__
try:
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.provider_manager.terminate_provider(provider_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def _get_astrbot_config(self):
config = self.config
+8 -2
View File
@@ -113,11 +113,17 @@ class PluginRoute(Route):
for filter in handler.event_filters: # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
info["type"] = "指令"
info["cmd"] = filter.command_name
info["cmd"] = f"{filter.parent_command_names[0]} {filter.command_name}"
info["cmd"] = info["cmd"].strip()
if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0:
info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
elif isinstance(filter, CommandGroupFilter):
info["type"] = "指令组"
info["cmd"] = filter.group_name
info["cmd"] = filter.get_complete_command_names()[0]
info["cmd"] = info["cmd"].strip()
info["sub_command"] = filter.print_cmd_tree(filter.sub_command_filters)
if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0:
info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
elif isinstance(filter, RegexFilter):
info["type"] = "正则匹配"
info["cmd"] = filter.regex_str
+1 -1
View File
@@ -16,7 +16,7 @@ class StatRoute(Route):
'/stat/get': ('GET', self.get_stat),
'/stat/version': ('GET', self.get_version),
'/stat/start-time': ('GET', self.get_start_time),
'/stat/restart-core': ('GET', self.restart_core)
'/stat/restart-core': ('POST', self.restart_core)
}
self.db_helper = db_helper
self.register_routes()
+15 -1
View File
@@ -13,6 +13,7 @@ class UpdateRoute(Route):
super().__init__(context)
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/releases': ('GET', self.get_releases),
'/update/do': ('POST', self.update_project),
'/update/dashboard': ('POST', self.update_dashboard),
'/update/pip-install': ('POST', self.install_pip_package)
@@ -46,6 +47,14 @@ class UpdateRoute(Route):
except Exception as e:
logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)")
return Response().error(e.__str__()).__dict__
async def get_releases(self):
try:
ret = await self.astrbot_updator.get_releases()
return Response().ok(ret).__dict__
except Exception as e:
logger.error(f"/api/update/releases: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def update_project(self):
data = await request.json
@@ -56,8 +65,13 @@ class UpdateRoute(Route):
version = ''
else:
latest = False
proxy: str = data.get("proxy", None)
if proxy:
proxy = proxy.removesuffix("/")
try:
await self.astrbot_updator.update(latest=latest, version=version)
await self.astrbot_updator.update(latest=latest, version=version, proxy=proxy)
if latest:
try:
+55
View File
@@ -2,6 +2,9 @@ import logging
import jwt
import asyncio
import os
import socket
import sys
import psutil
from astrbot.core.config.default import VERSION
from quart import Quart, request, jsonify, g
from quart.logging import default_handler
@@ -67,6 +70,47 @@ class AstrBotDashboard():
await asyncio.sleep(1)
logger.info("管理面板已关闭。")
def check_port_in_use(self, port: int) -> bool:
"""
跨平台检测端口是否被占用
"""
try:
# 创建 IPv4 TCP Socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# 设置超时时间
sock.settimeout(2)
result = sock.connect_ex(('127.0.0.1', port))
sock.close()
# result 为 0 表示端口被占用
return result == 0
except Exception as e:
logger.warning(f"检查端口 {port} 时发生错误: {str(e)}")
# 如果出现异常,保守起见认为端口可能被占用
return True
def get_process_using_port(self, port: int) -> str:
"""获取占用端口的进程详细信息"""
try:
for conn in psutil.net_connections(kind='inet'):
if conn.laddr.port == port:
try:
process = psutil.Process(conn.pid)
# 获取详细信息
proc_info = [
f"进程名: {process.name()}",
f"PID: {process.pid}",
f"执行路径: {process.exe()}",
f"工作目录: {process.cwd()}",
f"启动命令: {' '.join(process.cmdline())}"
]
return "\n ".join(proc_info)
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
return f"无法获取进程详细信息(可能需要管理员权限): {str(e)}"
return "未找到占用进程"
except Exception as e:
return f"获取进程信息失败: {str(e)}"
def run(self):
try:
ip_addr = get_local_ip_addresses()
@@ -76,6 +120,17 @@ class AstrBotDashboard():
port = self.core_lifecycle.astrbot_config['dashboard'].get("port", 6185)
if isinstance(port, str):
port = int(port)
if self.check_port_in_use(port):
process_info = self.get_process_using_port(port)
logger.error(f"错误:端口 {port} 已被占用\n"
f"占用信息: \n {process_info}\n"
f"请确保:\n"
f"1. 没有其他 AstrBot 实例正在运行\n"
f"2. 端口 {port} 没有被其他程序占用\n"
f"3. 如需使用其他端口,请修改配置文件")
raise Exception(f"端口 {port} 已被占用")
display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n"
display += f" ➜ 本地: http://localhost:{port}\n"
+14
View File
@@ -0,0 +1,14 @@
# What's Changed
1. ✨ 新增: gemini source 初步支持对 API Key 进行负载均衡请求 #534
2. ✨ 新增: 开启对话隔离的群聊以及私聊下,非 op 可以可以使用 /del 和 /reset #519
3. ✨ 新增: 事件钩子支持 yield 方式发送消息
4. ⚡ 优化: 查询模型列表时,可以显示当前使用的模型名称 #523
5. ⚡ 优化: 更换为预编译指令的方式处理指令组指令
6. 🐛 修复: resolve KeyError when current conversation is not in paginated list
7. 🐛 修复: 修复指令组的情况下,Permission Filter 对子指令失效的问题
8. 🐛 修复: 🐛 fix: 修复 reminder rm失败 #529
9. 🐛 修复: 🐛 fix: reminder 时区问题 #529
10. 🐛 修复: 修复 Dify 下无法主动回复的问题 #494
11. 🐛 修复: 添加代码执行器 Docker 宿主机绝对路径配置及相关功能以修复 Docker 下无法使用代码执行器的问题 #525
12. 🐛 修复: gewechat 微信群聊情况下可能导致 unknown 的问题 #537
+5
View File
@@ -0,0 +1,5 @@
# What's Changed
1. ‼️🐛 修复: 修复某些情况下导致插件报错 AttributeError 的问题 #549
2. ✨ 新增: add xAI template
3. 🐛 修复: 修复 dify 无法使用事件钩子的问题以及出现 GeneratorExit 的问题 #533 #264
+18
View File
@@ -0,0 +1,18 @@
# What's Changed
> 提示:改动范围较大
1. ✨ 新增: 添加对 Anthropic Claude 的支持 by @Rt39
2. ✨ 新增: 支持阿里云百炼应用(dashscope)智能体、工作流 #552 by @Soulter
3. ✨ 新增: 支持 AstrBot 更新使用 Github 加速地址 by @Fridemn
4. ✨ 新增: 适配多节点的转发消息,添加新的消息段 `Nodes`
5. ✨ 新增: 支持在管理面板重启(设置页)
6. ✨ 新增: 前端支持以列表展示正式版和开发版的列表
7. ✨ 新增: 支持插件禁止默认的llm调用(event.should_call_llm()#579
8. 🍺 重构: 支持更大范围的热重载以及管理面板将平台和提供商配置独立化 by @Soulter
9. ⚡ 优化: 启动时检查端口占用 by @Fridemn
10. ⚡ 优化: 添加控制台关闭自动滚动按钮 by @Fridemn
11. ⚡ 优化: 在聊天页面添加粘贴图片的快捷键提示 #557
12. 🐛 修复: 修复 webchat 未处理 base64 的问题 by @Raven95676
13. 🐛 修复: 修复 aiocqhttp_platform_adapter 文件相关判断逻辑 by @Raven95676
14. ‼️🐛 修复: 修复 gemini 请求时出现多次不支持函数工具调用最后 429 的问题
@@ -6,8 +6,8 @@
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;"
v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
style="margin-bottom: 16px" :text="metadata[metadataKey].items[key]?.hint"
:title="'💡 关于' + metadata[metadataKey].items[key]?.description" type="info" variant="tonal">
style="margin-bottom: 8px" :text="metadata[metadataKey].items[key]?.hint"
:title="'💡 ' + metadata[metadataKey].items[key]?.description" type="info" variant="tonal" color="primary">
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
@@ -66,8 +66,8 @@
</div>
<div v-else>
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint"
style="margin-bottom: 16px" :text="metadata[metadataKey]?.hint"
:title="'💡 关于' + metadata[metadataKey]?.description" type="info" variant="tonal">
style="margin-bottom: 8px" :text="metadata[metadataKey]?.hint"
:title="'💡 ' + metadata[metadataKey]?.description" type="info" variant="tonal" color="primary">
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
@@ -4,7 +4,7 @@ import { useCommonStore } from '@/stores/common';
<template>
<div id="term"
style="background-color: #1e1e1e; padding: 16px; border-radius: 8px; overflow-y:scroll">
style="background-color: #1e1e1e; padding: 16px; border-radius: 8px; overflow-y:auto">
</div>
</template>
@@ -13,6 +13,7 @@ export default {
name: 'ConsoleDisplayer',
data() {
return {
autoScroll: true, // 默认开启自动滚动
logColorAnsiMap: {
'\u001b[1;34m': 'color: #0000FF; font-weight: bold;', // bold_blue
'\u001b[1;36m': 'color: #00FFFF; font-weight: bold;', // bold_cyan
@@ -54,6 +55,9 @@ export default {
}
},
methods: {
toggleAutoScroll() {
this.autoScroll = !this.autoScroll;
},
printLog(log) {
// append 一个 span 标签到 termblock 的方式
let ele = document.getElementById('term')
@@ -66,11 +70,13 @@ export default {
break
}
}
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace;'
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;'
span.classList.add('fade-in')
span.innerText = log
ele.appendChild(span)
ele.scrollTop = ele.scrollHeight
if (this.autoScroll) {
ele.scrollTop = ele.scrollHeight
}
}
},
}
@@ -1,7 +1,7 @@
<template>
<div class="list-config-item">
<h3>{{ label }}</h3>
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: scroll;" >
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: auto;" >
<v-list-item v-for="(item, index) in items" :key="index">
<v-list-item-content style="display: flex; justify-content: space-between;">
<v-list-item-title>
@@ -4,17 +4,6 @@
<v-card-title>正在等待 AstrBot 重启...</v-card-title>
<v-card-text>
<v-progress-linear indeterminate color="primary"></v-progress-linear>
<div style="margin-top: 16px;">
<div class="py-12 text-center" v-if="newStartTime != -1">
<v-icon class="mb-6" color="success" icon="mdi-check-circle-outline" size="128"></v-icon>
<p>重启成功</p>
</div>
<small v-if="startTime != -1" style="display: block;">当前实例标识{{ startTime }}</small>
<small v-if="newStartTime != -1" style="display: block;">检查到新实例{{ newStartTime }}即将自动刷新页面</small>
<small v-if="status" style="display: block;">{{ status }}</small>
<small style="display: block;">尝试次数{{ cnt }} / 60</small>
</div>
</v-card-text>
</v-card>
</v-dialog>
@@ -73,11 +62,9 @@ export default {
if (this.newStartTime !== this.startTime) {
this.newStartTime = newStartTime
console.log('wfr: restarted')
setTimeout(() => {
this.visible = false
// reload
window.location.reload()
}, 2000)
this.visible = false
// reload
window.location.reload()
}
return this.newStartTime
}
@@ -19,6 +19,18 @@ let botCurrVersion = ref('');
let dashboardHasNewVersion = ref(false);
let dashboardCurrentVersion = ref('');
let version = ref('');
let releases = ref([]);
let devCommits = ref([]); // 新增的 ref
let tab = ref(0);
let releasesHeader = [
{ title: '标签', key: 'tag_name' },
{ title: '发布时间', key: 'published_at' },
{ title: '内容', key: 'body' },
{ title: '源码地址', key: 'zipball_url' },
{ title: '操作', key: 'switch' }
];
const open = (link: string) => {
window.open(link, '_blank');
@@ -83,10 +95,46 @@ function checkUpdate() {
});
}
function getReleases() {
axios.get('/api/update/releases')
.then((res) => {
// releases.value = res.data.data;
// 更新 published_at 的时间为本地时间
releases.value = res.data.data.map((item: any) => {
item.published_at = new Date(item.published_at).toLocaleString();
return item;
})
})
.catch((err) => {
console.log(err);
});
}
function getDevCommits() {
fetch('https://api.github.com/repos/Soulter/AstrBot/commits', {
headers: {
'Host': 'api.github.com',
'Referer': 'https://api.github.com'
}
})
.then(response => response.json())
.then(data => {
devCommits.value = data.map((commit: any) => ({
sha: commit.sha,
date: new Date(commit.commit.author.date).toLocaleString(),
message: commit.commit.message
}));
})
.catch(err => {
console.log(err);
});
}
function switchVersion(version: string) {
updateStatus.value = '正在切换版本...';
axios.post('/api/update/do', {
version: version
version: version,
proxy: localStorage.getItem('selectedGitHubProxy') || ''
})
.then((res) => {
updateStatus.value = res.data.message;
@@ -150,10 +198,10 @@ commonStore.getStartTime();
</div>
<v-dialog v-model="updateStatusDialog" width="700">
<v-dialog v-model="updateStatusDialog" width="1000">
<template v-slot:activator="{ props }">
<v-btn @click="checkUpdate" class="text-primary mr-4" color="lightprimary" variant="flat" rounded="sm"
v-bind="props">
<v-btn @click="checkUpdate(); getReleases(); getDevCommits();" class="text-primary mr-4" color="lightprimary"
variant="flat" rounded="sm" v-bind="props">
更新 🔄
</v-btn>
</template>
@@ -163,40 +211,80 @@ commonStore.getStartTime();
</v-card-title>
<v-card-text>
<v-container>
<h3 class="mb-4">升级到项目最新版本</h3>
<small>当前版本 {{ botCurrVersion }}</small>
<div class="mb-4">
<small>会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker 部署也可以重新拉取镜像或者使用 <a
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
<small>跳到旧版本或者切换到某个版本不会重新下载管理面板文件这可能会造成部分数据显示错误您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a>
找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可当然前端源代码在 dashboard 目录下你也可以自己使用 npm install npm build 构建</small>
</div>
<p>{{ updateStatus }}</p>
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
:disabled="!hasNewVersion">
更新到最新版本
<v-tabs v-model="tab">
<v-tab value="0">正式版</v-tab>
<v-tab value="1">开发版(master 分支)</v-tab>
</v-tabs>
<v-tabs-window v-model="tab">
<!-- 发行版 -->
<v-tabs-window-item key="0" v-show="tab == 0">
<small>当前版本 {{ botCurrVersion }}</small>
<p>{{ updateStatus }}</p>
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
:disabled="!hasNewVersion">
更新到最新版本
</v-btn>
<div class="mb-4">
<small>`更新到最新版本` 按钮会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker 部署也可以重新拉取镜像或者使用 <a
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
</div>
<v-data-table :headers="releasesHeader" :items="releases" item-key="name">
<template v-slot:item.body="{ item }: { item: { body: string } }">
<v-tooltip :text="item.body">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" rounded="xl" variant="tonal" color="primary" size="small">查看</v-btn>
</template>
</v-tooltip>
</template>
<template v-slot:item.switch="{ item }: { item: { tag_name: string } }">
<v-btn @click="switchVersion(item.tag_name)" rounded="xl" variant="plain" color="primary">
切换
</v-btn>
</template>
</v-data-table>
</v-tabs-window-item>
<!-- 开发版 -->
<v-tabs-window-item key="1" v-show="tab == 1">
<div style="margin-top: 16px;">
<v-data-table
:headers="[{ title: 'SHA', key: 'sha' }, { title: '日期', key: 'date' }, { title: '信息', key: 'message' }, { title: '操作', key: 'switch' }]"
:items="devCommits" item-key="sha">
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
切换
</v-btn>
</template>
</v-data-table>
</div>
</v-tabs-window-item>
</v-tabs-window>
<h3 class="mb-4">手动输入版本号或 Commit SHA</h3>
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
variant="outlined"></v-text-field>
<div class="mb-4">
<small> v3.3.16 (不带 SHA) 42e5ec5d80b93b6bfe8b566754d45ffac4c3fe0b</small>
<br>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录点击右边的 copy
即可复制</small></a>
</div>
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
确定切换
</v-btn>
<v-divider></v-divider>
<div style="margin-top: 16px;">
<h3 class="mb-4">切换到项目指定版本或指定提交</h3>
<div class="mb-4">
<small>跳到旧版本不会重新下载管理面板文件这可能会造成部分数据显示错误您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a>
找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可</small>
</div>
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
variant="outlined"></v-text-field>
<div class="mb-4">
<small> v3.3.16 (不带 SHA) 42e5ec5d80b93b6bfe8b566754d45ffac4c3fe0b</small>
<br>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录点击右边的 copy
即可复制</small></a>
</div>
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
确定切换
</v-btn>
</div>
<v-divider></v-divider>
<div style="margin-top: 16px;">
<h3 class="mb-4">更新管理面板到最新版本</h3>
<h3 class="mb-4">单独更新管理面板到最新版本</h3>
<div class="mb-4">
<small>当前版本 {{ dashboardCurrentVersion }}</small>
<br>
@@ -267,4 +355,4 @@ commonStore.getStartTime();
</v-card>
</v-dialog>
</v-app-bar>
</template>
</template>
@@ -21,7 +21,17 @@ const sidebarItem: menu[] = [
to: '/dashboard/default'
},
{
title: '配置文件',
title: '消息平台',
icon: 'mdi-message-processing',
to: '/platforms',
},
{
title: '服务提供商',
icon: 'mdi-creation',
to: '/providers',
},
{
title: '配置',
icon: 'mdi-cog',
to: '/config',
},
+10 -1
View File
@@ -16,12 +16,21 @@ const MainRoutes = {
path: '/extension',
component: () => import('@/views/ExtensionPage.vue')
},
{
name: 'Platforms',
path: '/platforms',
component: () => import('@/views/PlatformPage.vue')
},
{
name: 'Providers',
path: '/providers',
component: () => import('@/views/ProviderPage.vue')
},
{
name: 'Configs',
path: '/config',
component: () => import('@/views/ConfigPage.vue')
},
{
name: 'Default',
path: '/dashboard/default',
+28 -5
View File
@@ -1,18 +1,41 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 0; height: 100%;">
<v-card-text style="padding: 0; height: 100%; overflow-y: auto;">
<div
style="display: flex; justify-content: center; align-items: center; height: 100%; flex-direction: column;">
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" style="height: 300px;">
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo" class="fade-in">
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo" class="fade-in">
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo"
class="fade-in">
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo"
class="fade-in">
</div>
<h1 class="mt-8">AstrBot</h1>
<span style="color: #777;" class="mt-4">By <a href="https://soulter.top">Soulter</a> And <a href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a></span>
<span class="mt-2" style="color: #777;">A project out of interests and loves </span>
<v-btn class="text-primary mt-16" @click="open('https://github.com/Soulter/AstrBot')"
<span style="color: #777; margin-left: 32px; margin-right: 32px" class="mt-4">By <a
href="https://soulter.top">Soulter</a>, <a
href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a>
and <a href="https://github.com/Soulter/AstrBot_Plugins_Collection/graphs/contributors">AstrBot
Plugin Authors</a>
</span>
<!-- Copy-paste in your Readme.md file -->
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=light">
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=light
">
<!-- Made with [OSS Insight](https://ossinsight.io/) -->
<v-btn class="text-primary mt-8" @click="open('https://github.com/Soulter/AstrBot')"
color="lightprimary" variant="flat" rounded="sm">
Star 这个项目! 🌟
</v-btn>
+7 -1
View File
@@ -56,7 +56,7 @@ marked.setOptions({
<div style="margin-top: 8px; color: #aaa;">
<span>输入</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">help</span>
<span>获取帮助 😊</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
@@ -65,6 +65,12 @@ marked.setOptions({
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">K</span>
<span>开始语音 🎤</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>按</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">Ctrl + V</span>
<span>粘贴图片 🏞️</span>
</div>
</div>
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
+5
View File
@@ -44,6 +44,11 @@ import config from '@/config';
</v-expansion-panel-title>
<v-expansion-panel-text v-if="metadata[key]['metadata'][key2]?.config_template">
<!-- 带有 config_template 的配置项 -->
<v-alert style="margin-top: 16px; margin-bottom: 16px" color="primary" variant="tonal" v-if="key2 === 'platform' || key2 === 'provider'">
😄 消息平台适配器和服务提供商的配置已经迁移至更方便的独立页面推荐前往左栏配置哦
</v-alert>
<v-tabs style="margin-top: 16px;" align-tabs="left" color="deep-purple-accent-4" v-model="config_template_tab">
<v-tab v-if="metadata[key]['metadata'][key2]?.tmpl_display_title" v-for="(item, index) in config_data[key2]" :key="index" :value="index">
{{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }}
+43 -28
View File
@@ -9,34 +9,42 @@ import axios from 'axios';
<div
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
<h4>控制台</h4>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<div class="d-flex align-center">
<v-switch
v-model="autoScrollDisabled"
:label="autoScrollDisabled ? '自动滚动已关闭' : '自动滚动已开启'"
hide-details
density="compact"
style="margin-right: 16px;"
></v-switch>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
</div>
<ConsoleDisplayer style="height: calc(100vh - 160px); " />
<ConsoleDisplayer ref="consoleDisplayer" style="height: calc(100vh - 160px); " />
</div>
</template>
<script>
@@ -47,6 +55,7 @@ export default {
},
data() {
return {
autoScrollDisabled: false,
pipDialog: false,
pipInstallPayload: {
package: '',
@@ -56,7 +65,13 @@ export default {
status: ''
}
},
watch: {
autoScrollDisabled(val) {
if (this.$refs.consoleDisplayer) {
this.$refs.consoleDisplayer.autoScroll = !val;
}
}
},
methods: {
pipInstall() {
this.loading = true;
+241
View File
@@ -0,0 +1,241 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 32px; height: 100%;">
<v-menu>
<template v-slot:activator="{ props }">
<v-btn class="flex-grow-1" variant="tonal" @click="new_platform_dialog = true" size="large"
rounded="lg" v-bind="props" color="primary">
<template v-slot:default>
<v-icon>mdi-plus</v-icon>
新增平台适配器
</template>
</v-btn>
</template>
<v-list @update:selected="addFromDefaultConfigTmpl($event)">
<v-list-item
v-for="(item, index) in metadata['platform_group']['metadata']['platform'].config_template"
:key="index" rounded="xl" :value="index">
<v-list-item-title>{{ index }}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
<v-row style="margin-top: 16px;">
<v-col v-for="(platform, index) in config_data['platform']" :key="index" cols="12" md="6" lg="3">
<v-card class="fade-in"
style="margin-bottom: 16px; min-height: 200px; display: flex; justify-content: space-between; flex-direction: column;">
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h4">{{ platform.id }}</span>
<v-switch color="primary" hide-details density="compact" v-model="platform['enable']"
@update:modelValue="platformStatusChange(platform)"></v-switch>
</v-card-title>
<v-card-text>
<div>
<span style="font-size:12px">适配器类型: </span>
<v-chip size="small" color="primary" text>{{ platform.type }}</v-chip>
</div>
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn color="error" text @click="deletePlatform(platform.id);">
删除
</v-btn>
<v-btn color="blue-darken-1" text
@click="updatingMode = true; showPlatformCfg = true; newSelectedPlatformConfig = platform; newSelectedPlatformName = platform.id">
配置
</v-btn>
</v-card-actions>
</v-card>
</v-col>
</v-row>
<v-dialog v-model="showPlatformCfg" width="700">
<v-card>
<v-card-title>
<span class="text-h4">{{ newSelectedPlatformName }} 配置</span>
</v-card-title>
<v-card-text>
<AstrBotConfig :iterable="newSelectedPlatformConfig"
:metadata="metadata['platform_group']['metadata']" metadataKey="platform" />
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="newPlatform" :loading="loading">
保存
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<v-btn style="margin-top: 16px" class="flex-grow-1" variant="tonal" size="large" rounded="lg" color="gray" @click="showConsole = !showConsole">
<template v-slot:default>
<v-icon>mdi-console-line</v-icon>
{{ showConsole ? '隐藏' : '显示' }}日志
</template>
</v-btn>
<div v-if="showConsole" style="margin-top: 32px; ">
<ConsoleDisplayer style="background-color: #fff; height: 300px"></ConsoleDisplayer>
</div>
</v-card-text>
</v-card>
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
{{ save_message }}
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
export default {
name: 'PlatformPage',
components: {
AstrBotConfig,
WaitingForRestart,
ConsoleDisplayer
},
data() {
return {
config_data: {},
fetched: false,
metadata: {},
showPlatformCfg: false,
newSelectedPlatformName: '',
newSelectedPlatformConfig: {},
updatingMode: false,
loading: false,
save_message_snack: false,
save_message: "",
save_message_success: "",
showConsole: false,
}
},
mounted() {
this.getConfig();
},
methods: {
getConfig() {
// 获取配置
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
this.fetched = true
this.metadata = res.data.data.metadata;
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
addFromDefaultConfigTmpl(index) {
// 从默认配置模板中添加
console.log(index);
this.newSelectedPlatformName = index[0];
this.showPlatformCfg = true;
this.updatingMode = false;
this.newSelectedPlatformConfig = this.metadata['platform_group']['metadata']['platform'].config_template[index[0]];
},
newPlatform() {
// 新建或者更新平台
this.loading = true;
if (this.updatingMode) {
axios.post('/api/config/platform/update', {
id: this.newSelectedPlatformName,
config: this.newSelectedPlatformConfig
}).then((res) => {
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
this.updatingMode = false;
} else {
axios.post('/api/config/platform/new', this.newSelectedPlatformConfig).then((res) => {
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
},
deletePlatform(platform_id) {
// 删除平台
axios.post('/api/config/platform/delete', { id: platform_id }).then((res) => {
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
platformStatusChange(platform) {
// 平台状态改变
axios.post('/api/config/platform/update', {
id: platform.id,
config: platform
}).then((res) => {
this.getConfig();
this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
</style>
+240
View File
@@ -0,0 +1,240 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 32px; height: 100%;">
<v-menu>
<template v-slot:activator="{ props }">
<v-btn class="flex-grow-1" variant="tonal" @click="new_provider_dialog = true" size="large"
rounded="lg" v-bind="props" color="primary">
<template v-slot:default>
<v-icon>mdi-plus</v-icon>
新增服务提供商
</template>
</v-btn>
</template>
<v-list @update:selected="addFromDefaultConfigTmpl($event)">
<v-list-item
v-for="(item, index) in metadata['provider_group']['metadata']['provider'].config_template"
:key="index" rounded="xl" :value="index">
<v-list-item-title>{{ index }}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
<v-row style="margin-top: 16px;">
<v-col v-for="(provider, index) in config_data['provider']" :key="index" cols="12" md="6" lg="3">
<v-card class="fade-in" style="margin-bottom: 16px; min-height: 200px; display: flex; justify-content: space-between; flex-direction: column;">
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h4">{{ provider.id }}</span>
<v-switch color="primary" hide-details density="compact" v-model="provider['enable']"
@update:modelValue="providerStatusChange(provider)"></v-switch>
</v-card-title>
<v-card-text>
<div>
<span style="font-size:12px">适配器类型: </span> <v-chip size="small" color="primary" text>{{ provider.type }}</v-chip>
</div>
<div v-if="provider?.api_base" style="margin-top: 8px;">
<span style="font-size:12px">API Base: </span> <v-chip size="small" color="primary" text>{{ provider?.api_base }}</v-chip>
</div>
</v-card-text>
<v-card-actions class="d-flex justify-end">
<v-btn color="error" text @click="deleteprovider(provider.id);">
删除
</v-btn>
<v-btn color="blue-darken-1" text
@click="updatingMode = true; showproviderCfg = true; newSelectedproviderConfig = provider; newSelectedproviderName = provider.id">
配置
</v-btn>
</v-card-actions>
</v-card>
</v-col>
</v-row>
<v-dialog v-model="showproviderCfg" width="700">
<v-card>
<v-card-title>
<span class="text-h4">{{ newSelectedproviderName }} 配置</span>
</v-card-title>
<v-card-text>
<AstrBotConfig :iterable="newSelectedproviderConfig"
:metadata="metadata['provider_group']['metadata']" metadataKey="provider" />
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="newprovider" :loading="loading">
保存
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<v-btn style="margin-top: 16px" class="flex-grow-1" variant="tonal" size="large" rounded="lg" color="gray" @click="showConsole = !showConsole">
<template v-slot:default>
<v-icon>mdi-console-line</v-icon>
{{ showConsole ? '隐藏' : '显示' }}日志
</template>
</v-btn>
<div v-if="showConsole" style="margin-top: 32px; ">
<ConsoleDisplayer style="background-color: #fff; height: 300px"></ConsoleDisplayer>
</div>
</v-card-text>
</v-card>
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
{{ save_message }}
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
export default {
name: 'ProviderPage',
components: {
AstrBotConfig,
WaitingForRestart,
ConsoleDisplayer
},
data() {
return {
config_data: {},
fetched: false,
metadata: {},
showproviderCfg: false,
newSelectedproviderName: '',
newSelectedproviderConfig: {},
updatingMode: false,
loading: false,
save_message_snack: false,
save_message: "",
save_message_success: "",
showConsole: false,
}
},
mounted() {
this.getConfig();
},
methods: {
getConfig() {
// 获取配置
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
this.fetched = true
this.metadata = res.data.data.metadata;
}).catch((err) => {
save_message = err;
save_message_snack = true;
save_message_success = "error";
});
},
addFromDefaultConfigTmpl(index) {
// 从默认配置模板中添加
console.log(index);
this.newSelectedproviderName = index[0];
this.showproviderCfg = true;
this.updatingMode = false;
this.newSelectedproviderConfig = this.metadata['provider_group']['metadata']['provider'].config_template[index[0]];
},
newprovider() {
// 新建或者更新平台
this.loading = true;
if (this.updatingMode) {
axios.post('/api/config/provider/update', {
id: this.newSelectedproviderName,
config: this.newSelectedproviderConfig
}).then((res) => {
this.loading = false;
this.showproviderCfg = false;
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
this.updatingMode = false;
} else {
axios.post('/api/config/provider/new', this.newSelectedproviderConfig).then((res) => {
this.loading = false;
this.showproviderCfg = false;
this.getConfig();
}).catch((err) => {
this.loading = false;
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
},
deleteprovider(provider_id) {
// 删除平台
axios.post('/api/config/provider/delete', { id: provider_id }).then((res) => {
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
providerStatusChange(provider) {
// 平台状态改变
axios.post('/api/config/provider/update', {
id: provider.id,
config: provider
}).then((res) => {
this.getConfig();
// this.$refs.wfr.check();
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
</style>
+22 -2
View File
@@ -5,23 +5,39 @@
<v-list lines="two">
<v-list-subheader>网络</v-list-subheader>
<v-list-item subtitle="设置下载插件时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效" title="GitHub 加速地址">
<v-list-item subtitle="设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效" title="GitHub 加速地址">
<v-combobox variant="outlined" style="width: 100%; margin-top: 16px;" v-model="selectedGitHubProxy" :items="githubProxies"
label="选择 GitHub 加速地址">
</v-combobox>
</v-list-item>
<v-list-subheader>系统</v-list-subheader>
<v-list-item subtitle="重启 AstrBot" title="重启">
<v-btn style="margin-top: 16px;" color="error" @click="restartAstrBot">重启</v-btn>
</v-list-item>
</v-list>
</div>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
export default {
components: {
WaitingForRestart,
},
data() {
return {
githubProxies: [
@@ -35,7 +51,11 @@ export default {
}
},
methods: {
restartAstrBot() {
axios.post('/api/stat/restart-core').then(() => {
this.$refs.wfr.check();
})
}
},
mounted() {
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
-4
View File
@@ -125,10 +125,6 @@ class LongTermMemory:
else:
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats:
+61 -25
View File
@@ -83,9 +83,6 @@ AstrBot 指令:
/tool ls: 函数工具
/key: API Key(op)
/websearch: 网页搜索
[其他]
/set 变量名 值: 为会话定义变量(Dify 工作流输入)
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@@ -96,6 +93,7 @@ AstrBot 指令:
@tool.command("ls")
async def tool_ls(self, event: AstrMessageEvent):
'''查看函数工具列表'''
tm = self.context.get_llm_tool_manager()
msg = "函数工具:\n"
for tool in tm.func_list:
@@ -107,6 +105,7 @@ AstrBot 指令:
@tool.command("on")
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
'''启用一个函数工具'''
if self.context.activate_llm_tool(tool_name):
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 成功。"))
else:
@@ -114,6 +113,7 @@ AstrBot 指令:
@tool.command("off")
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
'''停用一个函数工具'''
if self.context.deactivate_llm_tool(tool_name):
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
else:
@@ -121,6 +121,7 @@ AstrBot 指令:
@tool.command("off_all")
async def tool_all_off(self, event: AstrMessageEvent):
'''停用所有函数工具'''
tm = self.context.get_llm_tool_manager()
for tool in tm.func_list:
self.context.deactivate_llm_tool(tool.name)
@@ -128,6 +129,7 @@ AstrBot 指令:
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
'''插件管理'''
if oper1 is None:
plugin_list_info = "已加载的插件:\n"
for plugin in self.context.get_all_stars():
@@ -189,6 +191,7 @@ AstrBot 指令:
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
'''开关文本转图片'''
config = self.context.get_config()
if config['t2i']:
config['t2i'] = False
@@ -201,6 +204,7 @@ AstrBot 指令:
@filter.command("tts")
async def tts(self, event: AstrMessageEvent):
'''开关文本转语音'''
config = self.context.get_config()
if config['provider_tts_settings']['enable']:
config['provider_tts_settings']['enable'] = False
@@ -213,6 +217,7 @@ AstrBot 指令:
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):
'''获取会话 ID 和 管理员 ID'''
sid = event.unified_msg_origin
user_id = str(event.get_sender_id())
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
@@ -222,6 +227,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str):
'''授权管理员。op <admin_id>'''
self.context.get_config()['admins_id'].append(admin_id)
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
@@ -229,6 +235,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str):
'''取消授权管理员。deop <admin_id>'''
try:
self.context.get_config()['admins_id'].remove(admin_id)
self.context.get_config().save_config()
@@ -340,16 +347,20 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
else:
event.set_result(MessageEventResult().message("无效的参数。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
'''重置 LLM 会话'''
is_unique_session = self.context.get_config()['platform_settings']['unique_session']
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限重置当前对话。"))
return
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
provider = self.context.get_using_provider()
print(provider.meta())
if provider and provider.meta().type == 'dify':
assert isinstance(provider, ProviderDify)
await provider.forget(message.unified_msg_origin)
@@ -393,6 +404,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
for model in models:
ret += f"\n{i}. {model}"
i += 1
curr_model = self.context.get_using_provider().get_model() or ""
ret += f"\n当前模型: [{curr_model}]"
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
@@ -418,7 +433,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(
MessageEventResult().message(f"切换模型到 {self.context.get_using_provider().get_model()}"))
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
'''查看对话记录'''
@@ -458,6 +472,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
provider = self.context.get_using_provider()
if provider and provider.meta().type == 'dify':
"""原有的Dify处理逻辑保持不变"""
ret = "Dify 对话列表:\n"
assert isinstance(provider, ProviderDify)
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
@@ -474,32 +489,45 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
return
size_per_page = 6
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
total_pages = len(conversations) // size_per_page
if len(conversations) % size_per_page != 0:
total_pages += 1
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
"""获取所有对话列表"""
conversations_all = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
"""计算总页数"""
total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
"""确保页码有效"""
page = max(1, min(page, total_pages))
"""分页处理"""
start_idx = (page - 1) * size_per_page
end_idx = start_idx + size_per_page
conversations_paged = conversations_all[start_idx:end_idx]
ret = "对话列表:\n---\n"
global_index = (page - 1) * size_per_page + 1
"""全局序号从当前页的第一个开始"""
global_index = start_idx + 1
"""生成所有对话的标题字典"""
_titles = {}
for conv in conversations:
for conv in conversations_all:
persona_id = conv.persona_id
if not persona_id and not persona_id == "[%None]":
if not persona_id or persona_id == "[%None]":
persona_id = self.context.provider_manager.selected_default_persona['name']
title = conv.title if conv.title else "新对话"
_titles[conv.cid] = title
"""遍历分页后的对话生成列表显示"""
for conv in conversations_paged:
persona_id = conv.persona_id
if not persona_id or persona_id == "[%None]":
persona_id = self.context.provider_manager.selected_default_persona['name']
title = _titles.get(conv.cid, "新对话")
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
global_index += 1
ret += "---\n"
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if curr_cid:
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
"""从所有对话的标题字典中获取标题"""
title = _titles.get(curr_cid, "新对话")
ret += f"\n当前对话: {title}({curr_cid[:4]})"
else:
ret += "\n当前对话: 无"
@@ -508,11 +536,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
ret += "\n会话隔离粒度: 个人"
else:
ret += "\n会话隔离粒度: 群聊"
ret += f"\n{page} 页 | 共 {total_pages}"
ret += "\n*输入 /ls 2 跳转到第 2 页"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
return
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
@@ -582,10 +611,14 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
message.set_result(MessageEventResult().message("重命名对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
'''删除当前对话'''
is_unique_session = self.context.get_config()['platform_settings']['unique_session']
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。"))
return
provider = self.context.get_using_provider()
if provider and provider.meta().type == 'dify':
@@ -604,7 +637,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
message.set_result(MessageEventResult().message("删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
@@ -753,6 +785,14 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
yield event.plain_result("已登出 gewechat,请重启 AstrBot")
return
@filter.command("gewe_code")
async def gewe_code(self, event: AstrMessageEvent, code: str):
'''保存 gewechat 验证码'''
with open("data/temp/gewe_code", "w", encoding='utf-8') as f:
f.write(code)
yield event.plain_result("验证码已保存。")
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''群聊记忆增强'''
@@ -928,10 +968,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
# def kdb(self):
# pass
# @kdb.command("off")
# async def off_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = False
# yield event.plain_result("知识库已关闭")
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
+37 -7
View File
@@ -85,7 +85,8 @@ DEFAULT_CONFIG = {
"sandbox": {
"image": "soulter/astrbot-code-interpreter-sandbox",
"docker_mirror": "", # cjie.eu.org
}
},
"docker_host_astrbot_abs_path": ""
}
PATH = "data/config/python_interpreter.json"
@@ -95,8 +96,14 @@ class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.workplace_path = os.path.join(self.curr_dir, "workplace")
self.shared_path = os.path.join(self.curr_dir, "shared")
self.shared_path = os.path.join("data", "py_interpreter_shared")
if not os.path.exists(self.shared_path):
# 复制 api.py 到 shared 目录
os.makedirs(self.shared_path, exist_ok=True)
shared_api_file = os.path.join(self.curr_dir, "shared", "api.py")
shutil.copy(shared_api_file, self.shared_path)
self.workplace_path = os.path.join("data", "py_interpreter_workplace")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
@@ -195,7 +202,16 @@ class Main(star.Star):
@filter.command_group("pi")
def pi(self):
pass
@pi.command("absdir")
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
'''设置 Docker 宿主机绝对路径'''
if not path:
yield event.plain_result(f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}")
else:
self.config["docker_host_astrbot_abs_path"] = path
self._save_config()
yield event.plain_result(f"设置 Docker 宿主机绝对路径成功: {path}")
@pi.command("mirror")
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
@@ -305,6 +321,20 @@ class Main(star.Star):
yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")
self.docker_host_astrbot_abs_path = self.config.get("docker_host_astrbot_abs_path", "")
if self.docker_host_astrbot_abs_path:
host_shared = os.path.join(self.docker_host_astrbot_abs_path, self.shared_path)
host_output = os.path.join(self.docker_host_astrbot_abs_path, output_path)
host_workplace = os.path.join(self.docker_host_astrbot_abs_path, workplace_path)
else:
host_shared = os.path.abspath(self.shared_path)
host_output = os.path.abspath(output_path)
host_workplace = os.path.abspath(workplace_path)
logger.debug(f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}")
container = await docker.containers.run({
"Image": image_name,
"Cmd": ["python", "exec.py"],
@@ -312,9 +342,9 @@ class Main(star.Star):
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{self.shared_path}:/astrbot_sandbox/shared:ro",
f"{output_path}:/astrbot_sandbox/output:rw",
f"{workplace_path}:/astrbot_sandbox:rw",
f"{host_shared}:/astrbot_sandbox/shared:ro",
f"{host_output}:/astrbot_sandbox/output:rw",
f"{host_workplace}:/astrbot_sandbox:rw",
]
},
"Env": [
+12 -4
View File
@@ -13,13 +13,13 @@ class Main(star.Star):
'''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`'''
def __init__(self, context: star.Context) -> None:
self.context = context
self.scheduler = AsyncIOScheduler()
self.scheduler = AsyncIOScheduler(timezone='Asia/Shanghai')
# set and load config
if not os.path.exists("data/astrbot-reminder.json"):
with open("data/astrbot-reminder.json", "w") as f:
with open("data/astrbot-reminder.json", "w", encoding='utf-8') as f:
f.write("{}")
with open("data/astrbot-reminder.json", "r") as f:
with open("data/astrbot-reminder.json", "r", encoding='utf-8') as f:
self.reminder_data = json.load(f)
self._init_scheduler()
@@ -64,7 +64,7 @@ class Main(star.Star):
async def _save_data(self):
'''Save the reminder data.'''
with open("data/astrbot-reminder.json", "w") as f:
with open("data/astrbot-reminder.json", "w", encoding='utf-8') as f:
json.dump(self.reminder_data, f, ensure_ascii=False)
def _parse_cron_expr(self, cron_expr: str):
@@ -175,10 +175,18 @@ class Main(star.Star):
else:
reminder = reminders.pop(index - 1)
job_id = reminder.get("id")
# self.reminder_data[event.unified_msg_origin] = reminder
users_reminders = self.reminder_data.get(event.unified_msg_origin, [])
for i, r in enumerate(users_reminders):
if r.get("id") == job_id:
users_reminders.pop(i)
try:
self.scheduler.remove_job(job_id)
except Exception as e:
logger.error(f"Remove job error: {e}")
yield event.plain_result(f"成功移除对应的待办事项。删除定时任务失败: {str(e)} 可能需要重启 AstrBot 以取消该提醒任务。")
await self._save_data()
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
+5 -1
View File
@@ -1,6 +1,7 @@
pydantic~=2.10.3
aiohttp
openai
anthropic
qq-botpy
chardet~=5.1.0
Pillow
@@ -17,7 +18,10 @@ apscheduler
docstring_parser
aiodocker
silk-python
psutil>=5.8.0
lark-oapi
ormsgpack
cryptography
cryptography
dashscope