Compare commits

..

22 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] a2fe0ec5a1 Add webhook signature verification for security
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:27:51 +00:00
copilot-swe-agent[bot] 6957ec713d Clean up unused imports in tests
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:24:18 +00:00
copilot-swe-agent[bot] d97c8b5b2b Add tests for GitHub webhook platform adapter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:23:22 +00:00
copilot-swe-agent[bot] d07a1ad5c9 Add GitHub webhook platform adapter with event handlers
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:20:33 +00:00
copilot-swe-agent[bot] d8e6dfbd6b Initial plan 2025-12-12 14:14:49 +00:00
Soulter 8a0b7717cc feat: supports webhook mode for Lark platform (#4016)
* feat: add Lark platform support with unified webhook configuration

* fix: update token verification logic in LarkWebhookServer

* feat: implement event deduplication and cleanup for Lark webhook events
2025-12-12 22:12:13 +08:00
Copilot 3b81fb4985 fix: mobile dialog close button visibility (#4010)
* Initial plan

* Fix mobile dialog close button visibility by adding max-height and scrollable content

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 16:02:24 +08:00
Soulter c09d57a820 refactor: improve UI layout and interaction for list item management (#4002)
* refactor: improve UI layout and interaction for list item management

* feat: enhance list configuration UI with batch import functionality

* feat: add internationalization support for list configuration UI
2025-12-11 18:55:56 +08:00
Soulter ec408a2aff fix: lark message timestamp 2025-12-11 18:20:50 +08:00
Soulter 417179a6b9 ci: add smoke test 2025-12-11 10:44:15 +08:00
Soulter fcd29445c7 refactor: remove unused current provider initialization in StarRequestSubStage 2025-12-11 10:36:33 +08:00
BiDuang 5f535001db fix: incorrect modalities enum of gemini api provider (#3993) 2025-12-10 20:27:51 +08:00
PaloMiku 750d245b16 docs: Update README with new Zread link and badges (#3992)
ZRead 是由智谱 AI 推出的 DeepWiki 类似平替品。
2025-12-10 20:22:56 +08:00
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00
Soulter aa6d07afcc refactor: move all internal commands from astrbot plugin to default_command plugin (#3960)
* refactor: move all internal commands from astrbot plugin to default_command plugin

* ruff check

* feat: add config

* ruff check
2025-12-08 22:17:32 +08:00
Soulter 2c36649874 feat: add Agent Runner test prompt dialog in ProviderPage (#3968) 2025-12-08 21:46:47 +08:00
Soulter c95735dcc0 docs: update readme 2025-12-08 12:05:57 +08:00
Soulter 03bb278f50 chore: ruff check 2025-12-08 11:00:43 +08:00
Soulter a5e0974da3 chore: ruff format 2025-12-08 00:36:56 +08:00
vmoranv f0fb447fbc feat: custom plugin api source manager (#3956)
* feat: custom plugin api source manager

* fix: rename plugin source file in a safer way

* chore: turned the way of saving plugin source to backend and refacted some components

* style: clean up whitespace and improve logging message formatting

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-08 00:32:50 +08:00
Soulter 37566182b0 feat: segment reply supports segmentation words (#3959)
* feat: segment reply supports segmentation words

* chore: ruff format

* feat: enhance segmented reply processing by refining word extraction logic

* ruff format
2025-12-08 00:27:17 +08:00
Soulter e460b411da chore: remove dev version from webui (#3951)
* chore: remove dev version

* chore: remove development version references from header localization files
2025-12-07 15:23:30 +08:00
144 changed files with 3510 additions and 1113 deletions
+58
View File
@@ -0,0 +1,58 @@
name: Smoke Test
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
smoke-test:
name: Run smoke tests
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV package manager
run: |
pip install uv
- name: Install dependencies
run: |
uv sync
timeout-minutes: 15
- name: Run smoke tests
run: |
uv run main.py &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
timeout-minutes: 2
+2
View File
@@ -20,6 +20,7 @@
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest"> <img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot"> <img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600"> <img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot"> <img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
@@ -206,6 +207,7 @@ pre-commit install
- 3 群:630166526 - 3 群:630166526
- 5 群:822130018 - 5 群:822130018
- 6 群:753075035 - 6 群:753075035
- 7 群:743746109
- 开发者群:975206796 - 开发者群:975206796
### Telegram 群组 ### Telegram 群组
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None llm_resp_result = None
async for llm_response in self._iter_llm_responses(): async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk: if llm_response.is_chunk:
if llm_response.result_chain: if llm_response.result_chain:
yield AgentResponse( yield AgentResponse(
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic from typing import Any, Generic
import jsonschema import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any] ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]): class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling.""" """A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function.""" """a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None handler_module_path: str | None = None
+5 -1
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool( async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext], context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str, method_name: str,
*args, *args,
**kwargs, **kwargs,
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
""" """
config_path: str
default_config: dict
schema: dict | None
def __init__( def __init__(
self, self,
config_path: str = ASTRBOT_CONFIG_PATH, config_path: str = ASTRBOT_CONFIG_PATH,
+58
View File
@@ -13,6 +13,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
"wecom", "wecom",
"wecom_ai_bot", "wecom_ai_bot",
"slack", "slack",
"lark",
] ]
# 默认配置 # 默认配置
@@ -42,7 +43,15 @@ DEFAULT_CONFIG = {
"interval": "1.5,3.5", "interval": "1.5,3.5",
"log_base": 2.6, "log_base": 2.6,
"words_count_threshold": 150, "words_count_threshold": 150,
"split_mode": "regex", # regex 或 words
"regex": ".*?[。?!~…]+|.+$", "regex": ".*?[。?!~…]+|.+$",
"split_words": [
"",
"",
"",
"~",
"",
], # 当 split_mode 为 words 时使用
"content_cleanup_rule": "", "content_cleanup_rule": "",
}, },
"no_permission_reply": True, "no_permission_reply": True,
@@ -157,6 +166,7 @@ DEFAULT_CONFIG = {
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量 "kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量 "kb_final_top_k": 5, # 知识库检索最终返回结果数量
"kb_agentic_mode": False, "kb_agentic_mode": False,
"disable_builtin_commands": False,
} }
@@ -268,6 +278,10 @@ CONFIG_METADATA_2 = {
"app_id": "", "app_id": "",
"app_secret": "", "app_secret": "",
"domain": "https://open.feishu.cn", "domain": "https://open.feishu.cn",
"lark_connection_mode": "socket", # webhook, socket
"webhook_uuid": "",
"lark_encrypt_key": "",
"lark_verification_token": "",
}, },
"钉钉(DingTalk)": { "钉钉(DingTalk)": {
"id": "dingtalk", "id": "dingtalk",
@@ -361,6 +375,28 @@ CONFIG_METADATA_2 = {
# "type": "string", # "type": "string",
# "options": ["fullscreen", "embedded"], # "options": ["fullscreen", "embedded"],
# }, # },
"lark_connection_mode": {
"description": "订阅方式",
"type": "string",
"options": ["socket", "webhook"],
"labels": ["长连接模式", "推送至服务器模式"],
},
"lark_encrypt_key": {
"description": "Encrypt Key",
"type": "string",
"hint": "用于解密飞书回调数据的加密密钥",
"condition": {
"lark_connection_mode": "webhook",
},
},
"lark_verification_token": {
"description": "Verification Token",
"type": "string",
"hint": "用于验证飞书回调请求的令牌",
"condition": {
"lark_connection_mode": "webhook",
},
},
"is_sandbox": { "is_sandbox": {
"description": "沙箱模式", "description": "沙箱模式",
"type": "bool", "type": "bool",
@@ -2661,6 +2697,11 @@ CONFIG_METADATA_3 = {
"description": "只 @ 机器人是否触发等待", "description": "只 @ 机器人是否触发等待",
"type": "bool", "type": "bool",
}, },
"disable_builtin_commands": {
"description": "禁用自带指令",
"type": "bool",
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
},
}, },
}, },
"whitelist": { "whitelist": {
@@ -2875,9 +2916,26 @@ CONFIG_METADATA_3 = {
"description": "分段回复字数阈值", "description": "分段回复字数阈值",
"type": "int", "type": "int",
}, },
"platform_settings.segmented_reply.split_mode": {
"description": "分段模式",
"type": "string",
"options": ["regex", "words"],
"labels": ["正则表达式", "分段词列表"],
},
"platform_settings.segmented_reply.regex": { "platform_settings.segmented_reply.regex": {
"description": "分段正则表达式", "description": "分段正则表达式",
"type": "string", "type": "string",
"condition": {
"platform_settings.segmented_reply.split_mode": "regex",
},
},
"platform_settings.segmented_reply.split_words": {
"description": "分段词列表",
"type": "list",
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
"condition": {
"platform_settings.segmented_reply.split_mode": "words",
},
}, },
"platform_settings.segmented_reply.content_cleanup_rule": { "platform_settings.segmented_reply.content_cleanup_rule": {
"description": "内容过滤正则表达式", "description": "内容过滤正则表达式",
+1 -1
View File
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行 # 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = [] extra_tasks = []
for task in self.star_context._register_tasks: for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks] tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_: for task in tasks_:
+2 -3
View File
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from deprecated import deprecated from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
echo=False, echo=False,
future=True, future=True,
) )
self.AsyncSessionLocal = sessionmaker( self.AsyncSessionLocal = async_sessionmaker(
self.engine, self.engine,
class_=AsyncSession, class_=AsyncSession,
expire_on_commit=False, expire_on_commit=False,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" not in conv.user_id: if ":" not in conv.user_id:
continue continue
session = MessageSesion.from_str(session_str=conv.user_id) session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" in conv.user_id: if ":" in conv.user_id:
continue continue
platform_id = "webchat" platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str conn.text_factory = str
return conn return conn
def _exec_sql(self, sql: str, params: tuple = None): def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn conn = self.conn
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close() c.close()
return Stats(platform, [], []) return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at), (user_id, cid, history, updated_at, created_at),
) )
def get_conversations(self, user_id: str) -> tuple: def get_conversations(self, user_id: str) -> list[Conversation]:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
+16 -15
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here. Note: In astrbot v4, we moved `platform` table to here.
""" """
__tablename__ = "platform_stats" # type: ignore __tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False) timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True): class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore __tablename__: str = "conversations"
inner_conversation_id: int = Field( inner_conversation_id: int | None = Field(
default=None,
primary_key=True, primary_key=True,
sa_column_kwargs={"autoincrement": True}, sa_column_kwargs={"autoincrement": True},
) )
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs. It can be used to customize the behavior of LLMs.
""" """
__tablename__ = "personas" # type: ignore __tablename__: str = "personas"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True): class Preference(SQLModel, table=True):
"""This class represents preferences for bots.""" """This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore __tablename__: str = "preferences"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages. or platform-specific messages.
""" """
__tablename__ = "platform_message_history" # type: ignore __tablename__: str = "platform_message_history"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it. Each session can have multiple conversations (对话) associated with it.
""" """
__tablename__ = "platform_sessions" # type: ignore __tablename__: str = "platform_sessions"
inner_id: int | None = Field( inner_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types. Attachments can be images, files, or other media types.
""" """
__tablename__ = "attachments" # type: ignore __tablename__: str = "attachments"
inner_attachment_id: int | None = Field( inner_attachment_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -261,17 +262,17 @@ class Personality(TypedDict):
v4.0.0 版本及之后推荐使用上面的 Persona 并且 mood_imitation_dialogs 字段已被废弃 v4.0.0 版本及之后推荐使用上面的 Persona 并且 mood_imitation_dialogs 字段已被废弃
""" """
prompt: str = "" prompt: str
name: str = "" name: str
begin_dialogs: list[str] = [] begin_dialogs: list[str]
mood_imitation_dialogs: list[str] = [] mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" """工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache # cache
_begin_dialogs_processed: list[dict] = [] _begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str = "" _mood_imitation_dialogs_processed: str
# ==== # ====
+4 -3
View File
@@ -3,6 +3,7 @@ import threading
import typing as T import typing as T
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
query = select(Attachment).where( query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = await session.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id col(Attachment.attachment_id) == attachment_id
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0 return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int: async def delete_attachments(self, attachment_ids: list[str]) -> int:
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount return result.rowcount
async def insert_persona( async def insert_persona(
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径 path (str): 保存索引的路径
""" """
if self.index is None:
return
faiss.write_index(self.index, self.path) faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self, self,
event_queue: Queue, event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler], pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None, astrbot_config_mgr: AstrBotConfigManager,
): ):
self.event_queue = event_queue # 事件队列 self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler # abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"]) self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event)) asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str): def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank # 5. Rerank
first_rerank = None first_rerank = None
for kb_id in kb_ids: for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"] rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if ( if (
vec_db vec_db
+8 -2
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel): class BaseMessageComponent(BaseModel):
type: ComponentType type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self): def toDict(self):
data = {} data = {}
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略 id: int | None = 0 # 忽略
name: str | None = "" # qq昵称 name: str | None = "" # qq昵称
uin: str | None = "0" # qq号 uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = [] content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略 seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略 time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d) ret["messages"].append(d)
return ret return ret
async def to_dict(self): async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []} ret = {"messages": []}
for node in self.nodes: for node in self.nodes:
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
if self.url: if self.url:
await self._download_file() await self._download_file()
if self.file_:
return os.path.abspath(self.file_) return os.path.abspath(self.file_)
return "" return ""
async def _download_file(self): async def _download_file(self):
"""下载文件""" """下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
if self.name: if self.name:
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self, self,
persona_id: str, persona_id: str,
system_prompt: str, system_prompt: str,
begin_dialogs: list[str] = None, begin_dialogs: list[str] | None = None,
tools: list[str] = None, tools: list[str] | None = None,
) -> Persona: ) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id): if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
check_text: str | None = None, check_text: str | None = None,
) -> None | AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""检查内容安全""" """检查内容安全"""
text = check_text if check_text else event.get_message_str() text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text) ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler( async def call_handler(
event: AstrMessageEvent, event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args, *args,
**kwargs, **kwargs,
) -> T.AsyncGenerator[T.Any, None]: ) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
) )
for handler in handlers: for handler in handlers:
try: try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug( logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
) )
@@ -16,7 +16,6 @@ from ..stage import Stage
class StarRequestSubStage(Stage): class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None: async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"] self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"] self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.ctx = ctx self.ctx = ctx
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
async def process( async def process(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra( activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers", "activated_handlers",
) )
+1 -1
View File
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
): ):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if ( if (
event.get_result() and not event.get_result().is_stopped() event.get_result() and not event.is_stopped()
) or not event.get_result(): ) or not event.get_result():
async for _ in self.agent_sub_stage.process(event): async for _ in self.agent_sub_stage.process(event):
yield yield
+4 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg: if not self.enable_seg:
return False return False
if self.only_llm_result and not event.get_result().is_llm_result(): if (result := event.get_result()) is None:
return False
if self.only_llm_result and result.is_llm_result():
return False return False
if event.get_platform_name() in [ if event.get_platform_name() in [
@@ -185,7 +187,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file: if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。 # 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file) component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component result.chain[idx] = component
# 检查消息链是否为空 # 检查消息链是否为空
try: try:
+48 -1
View File
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -53,7 +54,22 @@ class ResultDecorateStage(Stage):
self.only_llm_result = ctx.astrbot_config["platform_settings"][ self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply" "segmented_reply"
]["only_llm_result"] ]["only_llm_result"]
self.split_mode = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_mode", "regex")
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"] self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
self.split_words = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_words", ["", "", "", "~", ""])
if self.split_words:
escaped_words = sorted(
[re.escape(word) for word in self.split_words], key=len, reverse=True
)
self.split_words_pattern = re.compile(
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
)
else:
self.split_words_pattern = None
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][ self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply" "segmented_reply"
]["content_cleanup_rule"] ]["content_cleanup_rule"]
@@ -69,6 +85,28 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls() self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx) await self.content_safe_check_stage.initialize(ctx)
def _split_text_by_words(self, text: str) -> list[str]:
"""使用分段词列表分段文本"""
if not self.split_words_pattern:
return [text]
segments = self.split_words_pattern.findall(text)
result = []
for seg in segments:
if isinstance(seg, tuple):
content = seg[0]
if not isinstance(content, str):
continue
for word in self.split_words:
if content.endswith(word):
content = content[: -len(word)]
break
if content.strip():
result.append(content)
elif seg and seg.strip():
result.append(seg)
return result if result else [text]
async def process( async def process(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
@@ -93,6 +131,8 @@ class ResultDecorateStage(Stage):
for comp in result.chain: for comp in result.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
text += comp.text text += comp.text
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
async for _ in self.content_safe_check_stage.process( async for _ in self.content_safe_check_stage.process(
event, event,
check_text=text, check_text=text,
@@ -114,7 +154,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
) )
await handler.handler(event) await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug( logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
) )
@@ -161,6 +202,11 @@ class ResultDecorateStage(Stage):
# 不分段回复 # 不分段回复
new_chain.append(comp) new_chain.append(comp)
continue continue
# 根据 split_mode 选择分段方式
if self.split_mode == "words":
split_response = self._split_text_by_words(comp.text)
else: # regex 模式
try: try:
split_response = re.findall( split_response = re.findall(
self.regex, self.regex,
@@ -176,6 +222,7 @@ class ResultDecorateStage(Stage):
comp.text, comp.text,
re.DOTALL | re.MULTILINE, re.DOTALL | re.MULTILINE,
) )
if not split_response: if not split_response:
new_chain.append(comp) new_chain.append(comp)
continue continue
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER from . import STAGES_ORDER
from .context import PipelineContext from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event) await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None) await event.send(None)
logger.debug("pipeline 执行完毕。") logger.debug("pipeline 执行完毕。")
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
"ignore_at_all", "ignore_at_all",
False, False,
) )
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
async def process( async def process(
self, self,
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
EventType.AdapterMessageEvent, EventType.AdapterMessageEvent,
plugins_name=event.plugins_name, plugins_name=event.plugins_name,
): ):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
# filter 需满足 AND 逻辑关系 # filter 需满足 AND 逻辑关系
passed = True passed = True
permission_not_pass = False permission_not_pass = False
+4 -2
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str: def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)""" """获取消息发送者的名称。(可能会返回空字符串)"""
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value): def set_extra(self, key, value):
"""设置额外的信息。""" """设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
""" """
self.call_llm = call_llm self.call_llm = call_llm
def get_result(self) -> MessageEventResult: def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。""" """获取消息事件的结果。"""
return self._result return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self, self,
prompt: str, prompt: str,
func_tool_manager=None, func_tool_manager=None,
session_id: str = None, session_id: str = "",
image_urls: list[str] | None = None, image_urls: list[str] | None = None,
contexts: list | None = None, contexts: list | None = None,
system_prompt: str = "", system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。 session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id message_id: str # 消息id
group: Group # 群组 group: Group | None # 群组
sender: MessageMember # 发送者 sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串 message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return "" return ""
@group_id.setter @group_id.setter
def group_id(self, value: str): def group_id(self, value: str | None):
"""设置 group_id""" """设置 group_id"""
if value: if value:
if self.group: if self.group:
+8
View File
@@ -5,6 +5,7 @@ from asyncio import Queue
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .platform import Platform, PlatformStatus from .platform import Platform, PlatformStatus
from .register import platform_cls_map from .register import platform_cls_map
@@ -18,6 +19,7 @@ class PlatformManager:
self._inst_map: dict[str, dict] = {} self._inst_map: dict[str, dict] = {}
self.astrbot_config = config
self.platforms_config = config["platform"] self.platforms_config = config["platform"]
self.settings = config["platform_settings"] self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性; """NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
@@ -29,6 +31,8 @@ class PlatformManager:
"""初始化所有平台适配器""" """初始化所有平台适配器"""
for platform in self.platforms_config: for platform in self.platforms_config:
try: try:
if ensure_platform_webhook_config(platform):
self.astrbot_config.save_config()
await self.load_platform(platform) await self.load_platform(platform)
except Exception as e: except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}") logger.error(f"初始化 {platform} 平台适配器失败: {e}")
@@ -108,6 +112,10 @@ class PlatformManager:
from .sources.satori.satori_adapter import ( from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401 SatoriPlatformAdapter, # noqa: F401
) )
case "github_webhook":
from .sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.error( logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
+11 -3
View File
@@ -1,7 +1,7 @@
import abc import abc
import uuid import uuid
from asyncio import Queue from asyncio import Queue
from collections.abc import Awaitable from collections.abc import Coroutine
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -80,6 +80,13 @@ class Platform(abc.ABC):
if self._status == PlatformStatus.ERROR: if self._status == PlatformStatus.ERROR:
self._status = PlatformStatus.RUNNING self._status = PlatformStatus.RUNNING
def unified_webhook(self) -> bool:
"""是否正在使用统一 Webhook 模式"""
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("webhook_uuid")
)
def get_stats(self) -> dict: def get_stats(self) -> dict:
"""获取平台统计信息""" """获取平台统计信息"""
meta = self.meta() meta = self.meta()
@@ -97,10 +104,11 @@ class Platform(abc.ABC):
} }
if self.last_error if self.last_error
else None, else None,
"unified_webhook": self.unified_webhook(),
} }
@abc.abstractmethod @abc.abstractmethod
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。""" """得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError raise NotImplementedError
@@ -116,7 +124,7 @@ class Platform(abc.ABC):
self, self,
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法 异步方法
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" """平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str description: str
"""平台的描述""" """平台的描述"""
id: str | None = None id: str
"""平台的唯一标识符,用于配置中识别特定平台""" """平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata( pm = PlatformMetadata(
name=adapter_name, name=adapter_name,
description=desc, description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl, default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name, adapter_display_name=adapter_display_name,
logo_path=logo_path, logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp, bot: CQHttp,
event: Event | None, event: Event | None,
is_group: bool, is_group: bool,
session_id: str, session_id: str | None,
messages: list[dict], messages: list[dict],
): ):
# session_id 必须是纯数字字符串 # session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int): if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id, message=messages) await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id, int): elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id, message=messages) await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底 elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages) await bot.send(event=event, message=messages)
else: else:
@@ -4,7 +4,7 @@ import logging
import time import time
import uuid import uuid
from collections.abc import Awaitable from collections.abc import Awaitable
from typing import Any from typing import Any, cast
from aiocqhttp import CQHttp, Event from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed from aiocqhttp.exceptions import ActionFailed
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="aiocqhttp", name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件""" """OneBot V11 请求类事件"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"): if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象 @param event: 事件对象
@param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套 @param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套
""" """
assert event.sender is not None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group": if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id) abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A") abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private": elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err) await self.bot.send(event, err)
except BaseException as e: except BaseException as e:
logger.error(f"回复消息失败: {e}") logger.error(f"回复消息失败: {e}")
return None raise ValueError(err)
# 按消息段类型类型适配 # 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -417,7 +421,7 @@ class AiocqhttpAdapter(Platform):
async def shutdown_trigger_placeholder(self): async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait() await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭") logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
@@ -2,6 +2,7 @@ import asyncio
import os import os
import threading import threading
import uuid import uuid
from typing import cast
import aiohttp import aiohttp
import dingtalk_stream import dingtalk_stream
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
self.client_id = platform_config["client_id"] self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"] self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler): class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage): async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}") logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data) im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im) abm = await outer_self.convert_msg(im)
await self.handle_msg(abm) await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK" return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
self.client, self.client,
) )
self.client_ = client # 用于 websockets 的 client self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str: def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id: if not dingtalk_id:
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="dingtalk", name="dingtalk",
description="钉钉机器人官方 API 适配器", description="钉钉机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage() abm = AstrBotMessage()
abm.message = [] abm.message = []
abm.message_str = "" abm.message_str = ""
abm.timestamp = int(message.create_at / 1000) abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
if message.conversation_type == "2" if message.conversation_type == "2"
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick, nickname=message.sender_nick,
) )
abm.self_id = self._id_to_sid(message.chatbot_user_id) abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id abm.message_id = cast(str, message.message_id)
abm.raw_message = message abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE: if abm.type == MessageType.GROUP_MESSAGE:
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else: else:
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
message_type: str = message.message_type message_type: str = cast(str, message.message_type)
match message_type: match message_type:
case "text": case "text":
abm.message_str = message.text.content.strip() abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str)) abm.message.append(Plain(abm.message_str))
case "richText": case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content rtc: dingtalk_stream.RichTextContent = cast(
contents: list[dict] = rtc.rich_text_list dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents: for content in contents:
plains = "" plains = ""
if "text" in content: if "text" in content:
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture": elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file( f_path = await self.download_ding_file(
content["downloadCode"], content["downloadCode"],
message.robot_code, cast(str, message.robot_code),
"jpg", "jpg",
) )
abm.message.append(Image.fromFileSystem(f_path)) abm.message.append(Image.fromFileSystem(f_path))
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}", f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
) )
return None return ""
resp_data = await resp.json() resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"] download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path) await download_file(download_url, f_path)
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
) )
return None return ""
return (await resp.json())["data"]["accessToken"] return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage): async def handle_msg(self, abm: AstrBotMessage):
@@ -239,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
task.result() task.result()
except Exception as e: except Exception as e:
if "Graceful shutdown" in str(e): if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭") logger.info("钉钉适配器已被关闭")
return return
logger.error(f"钉钉机器人启动失败: {e}") logger.error(f"钉钉机器人启动失败: {e}")
@@ -250,8 +256,10 @@ class DingtalkPlatformAdapter(Platform):
def monkey_patch_close(): def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown") raise KeyboardInterrupt("Graceful shutdown")
if self.client_.websocket is not None:
self.client_.open_connection = monkey_patch_close self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown") await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set() self._shutdown_event.set()
def get_client(self): def get_client(self):
@@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import cast
import dingtalk_stream import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
segment.text, segment.text,
segment.text, segment.text,
self.message_obj.raw_message, cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
) )
elif isinstance(segment, Comp.Image): elif isinstance(segment, Comp.Image):
markdown_str = "" markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
"😄", "😄",
markdown_str, markdown_str,
self.message_obj.raw_message, cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
) )
logger.debug(f"send image: {ret}") logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys import sys
from collections.abc import Awaitable, Callable
import discord import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy) super().__init__(intents=intents, proxy=proxy)
# 回调函数 # 回调函数
self.on_message_received = None self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback = None self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False self._ready_once_fired = False
@override
async def on_ready(self): async def on_ready(self):
"""当机器人成功连接并准备就绪时触发""" """当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。") logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict: def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典""" """从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions is_mentioned = self.user in message.mentions
return { return {
"message": message, "message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict: def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典""" """从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return { return {
"interaction": interaction, "interaction": interaction,
"bot_id": str(self.user.id), "bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction", "type": "interaction",
} }
@override
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
"""当接收到消息时触发""" """当接收到消息时触发"""
if message.author.bot: if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__( def __init__(
self, self,
components: list[BaseMessageComponent] = None, components: list[BaseMessageComponent] | None = None,
timeout: float = None, timeout: float | None = None,
): ):
self.components = components or [] self.components = components or []
self.timeout = timeout self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio import asyncio
import re import re
import sys import sys
from typing import Any from typing import Any, cast
import discord import discord
from discord.abc import Messageable from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel from discord.channel import DMChannel
from astrbot import logger from astrbot import logger
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
) -> None: ) -> None:
super().__init__(platform_config, event_queue) super().__init__(platform_config, event_queue)
self.settings = platform_settings self.settings = platform_settings
self.client_self_id = None self.client_self_id: str | None = None
self.registered_handlers = [] self.registered_handlers = []
# 指令注册相关 # 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True) self.enable_command_register = self.config.get("discord_command_register", True)
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain, message_chain: MessageChain,
): ):
"""通过会话发送消息""" """通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用 # 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage() message_obj = AstrBotMessage()
if "_" in session.session_id: if "_" in session.session_id:
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id), user_id=str(self.client_self_id),
nickname=self.client.user.display_name, nickname=self.client.user.display_name,
) )
message_obj.self_id = self.client_self_id message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id message_obj.session_id = session.session_id
message_obj.message = message_chain.chain message_obj.message = message_chain.chain
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
"discord", "discord",
"Discord 适配器", "Discord 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
default_config_tmpl=self.config, default_config_tmpl=self.config,
support_streaming_message=False, support_streaming_message=False,
) )
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type( def _get_message_type(
self, self,
channel: Messageable, channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None, guild_id: int | None = None,
) -> MessageType: ) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型""" """根据 channel 对象和 guild_id 判断消息类型"""
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str: def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID""" """根据 channel 对象获取ID"""
return str(getattr(channel, "id", None)) return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage""" """将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"] message = data["message"]
content = message.content content = message.content
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = message_chain abm.message = message_chain
abm.raw_message = message abm.raw_message = message
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id) abm.session_id = str(message.channel.id)
abm.message_id = str(message.id) abm.message_id = str(message.id)
return abm return abm
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook, interaction_followup_webhook=followup_webhook,
) )
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令 # 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention # 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False is_mention = False
# User Mention # User Mention
if ( # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
self.client if self.client.user in raw_message.mentions:
and self.client.user
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True is_mention = True
# Role MentionBot 拥有的角色被提及) # Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"): if not is_mention and raw_message.role_mentions:
bot_member = None bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild: if raw_message.guild:
try: try:
bot_member = message.raw_message.guild.get_member( bot_member = raw_message.guild.get_member(
self.client.user.id, self.client.user.id,
) )
except Exception: except Exception:
bot_member = None bot_member = None
if bot_member and hasattr(bot_member, "roles"): if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles) bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions) mentioned_roles = set(raw_message.role_mentions)
if ( if (
bot_roles bot_roles
and mentioned_roles and mentioned_roles
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
): ):
is_mention = True is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态 # 如果是被@的消息,设置为唤醒状态
if is_slash_command or is_mention: if is_mention:
message_event.is_wake = True message_event.is_wake = True
message_event.is_at_or_wake_command = True message_event.is_at_or_wake_command = True
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = [Plain(text=message_str_for_filter)] abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id) abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id) abm.message_id = str(ctx.interaction.id)
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info( def _extract_command_info(
event_filter: Any, event_filter: Any,
handler_metadata: StarHandlerMetadata, handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None: ) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息""" """从事件过滤器中提取指令信息"""
cmd_name = None cmd_name = None
# is_group = False # is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import cast
import discord import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel() channel = await self._get_channel()
if not channel: if not channel:
return return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs) await channel.send(**kwargs)
except Exception as e: except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer) await self.send(buffer)
return await super().send_streaming(generator, use_fallback) return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None: async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象""" """获取当前事件对应的频道对象"""
try: try:
channel_id = int(self.session_id) channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord( async def _parse_to_discord(
self, self,
message: MessageChain, message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: ) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容""" """将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = [] content_parts = []
files = [] files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"add_reaction", "add_reaction",
): ):
await self.message_obj.raw_message.add_reaction(emoji) await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e: except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}") logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command == discord.InteractionType.application_command
) )
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
) )
def get_interaction_custom_id(self) -> str: def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id""" """获取交互组件的custom_id"""
if self.is_button_interaction(): if self.is_button_interaction():
try: try:
return self.message_obj.raw_message.data.get("custom_id", "") return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception: except Exception:
pass pass
return "" return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
): ):
return any( return any(
mention.id == int(self.message_obj.self_id) mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
) )
return False return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"clean_content", "clean_content",
): ):
return self.message_obj.raw_message.clean_content return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str return self.message_str
@@ -0,0 +1,315 @@
import asyncio
import hashlib
import hmac
from typing import Any, cast
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.platform import PlatformStatus
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .github_webhook_event import GitHubWebhookMessageEvent
@register_platform_adapter(
"github_webhook",
"GitHub Webhook 适配器",
support_streaming_message=False,
)
class GitHubWebhookPlatformAdapter(Platform):
"""GitHub Webhook 平台适配器
支持的事件:
- issues (created)
- issue_comment (created)
- pull_request (opened)
"""
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
self.webhook_secret = platform_config.get("webhook_secret", "")
self.shutdown_event = asyncio.Event()
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="github_webhook",
description="GitHub Webhook 适配器",
id=cast(str, self.config.get("id")),
)
async def run(self):
"""运行适配器"""
self.status = PlatformStatus.RUNNING
# 如果启用统一 webhook 模式
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.shutdown_event.wait()
else:
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
await self.shutdown_event.wait()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口
处理 GitHub webhook 事件
Args:
request: Quart 请求对象
Returns:
响应数据
"""
try:
# 获取事件类型
event_type = request.headers.get("X-GitHub-Event", "")
# 获取请求数据
payload = await request.json
# 验证 webhook 签名(如果配置了 secret
if self.webhook_secret:
if not await self._verify_signature(request, payload):
logger.warning("GitHub webhook 签名验证失败")
return {"error": "Invalid signature"}, 401
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
# 处理不同类型的事件
if event_type == "issues":
await self._handle_issue_event(payload)
elif event_type == "issue_comment":
await self._handle_issue_comment_event(payload)
elif event_type == "pull_request":
await self._handle_pull_request_event(payload)
elif event_type == "ping":
# GitHub webhook 验证事件
return {"message": "pong"}
else:
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
return {"status": "ok"}
except Exception as e:
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
return {"error": str(e)}, 500
async def _verify_signature(self, request: Any, payload: dict) -> bool:
"""验证 GitHub webhook 签名
Args:
request: Quart 请求对象
payload: 请求负载数据
Returns:
签名是否有效
"""
signature_header = request.headers.get("X-Hub-Signature-256", "")
if not signature_header:
# 如果没有签名头,检查是否有旧版本的签名
signature_header = request.headers.get("X-Hub-Signature", "")
if not signature_header:
return False
# 获取原始请求体
body = await request.get_data()
# 计算 HMAC
if signature_header.startswith("sha256="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha256,
).hexdigest()
received_signature = signature_header.replace("sha256=", "")
elif signature_header.startswith("sha1="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha1,
).hexdigest()
received_signature = signature_header.replace("sha1=", "")
else:
return False
# 使用 hmac.compare_digest 防止时序攻击
return hmac.compare_digest(expected_signature, received_signature)
async def _handle_issue_event(self, payload: dict):
"""处理 issue 事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created" and action != "opened":
return
issue = payload.get("issue", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"📝 新 Issue 创建\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {issue.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {issue.get('html_url', '')}\n"
f"内容:\n{issue.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issues",
payload,
)
)
async def _handle_issue_comment_event(self, payload: dict):
"""处理 issue 评论事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created":
return
issue = payload.get("issue", {})
comment = payload.get("comment", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"💬 新 Issue 评论\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"Issue: {issue.get('title', 'No title')}\n"
f"评论者: {sender.get('login', 'unknown')}\n"
f"链接: {comment.get('html_url', '')}\n"
f"内容:\n{comment.get('body', 'No comment')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issue_comment",
payload,
)
)
async def _handle_pull_request_event(self, payload: dict):
"""处理 pull request 事件"""
action = payload.get("action", "")
# 只处理打开事件
if action != "opened":
return
pr = payload.get("pull_request", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"🔀 新 Pull Request\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {pr.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {pr.get('html_url', '')}\n"
f"内容:\n{pr.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"pull_request",
payload,
)
)
def _create_message(
self,
message_text: str,
user_id: str,
nickname: str,
session_id: str,
) -> AstrBotMessage:
"""创建 AstrBotMessage 对象"""
abm = AstrBotMessage()
abm.type = MessageType.GROUP_MESSAGE
abm.self_id = self.client_self_id
abm.session_id = session_id
abm.message_id = ""
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
abm.message = [Plain(message_text)]
abm.message_str = message_text
abm.raw_message = message_text
return abm
async def terminate(self):
"""终止适配器运行"""
self.shutdown_event.set()
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
@@ -0,0 +1,22 @@
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from ...astr_message_event import AstrMessageEvent
class GitHubWebhookMessageEvent(AstrMessageEvent):
"""GitHub Webhook 消息事件"""
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
event_type: str,
event_data: dict,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.event_type = event_type
"""GitHub 事件类型: issues, issue_comment, pull_request"""
self.event_data = event_data
"""原始事件数据"""
@@ -2,10 +2,17 @@ import asyncio
import base64 import base64
import json import json
import re import re
import time
import uuid import uuid
from typing import Any, cast
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
from astrbot import logger from astrbot import logger
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
PlatformMetadata, PlatformMetadata,
) )
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent from .lark_event import LarkMessageEvent
from .server import LarkWebhookServer
@register_platform_adapter( @register_platform_adapter(
@@ -42,9 +51,13 @@ class LarkPlatformAdapter(Platform):
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN) self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get("lark_bot_name", "astrbot") self.bot_name = platform_config.get("lark_bot_name", "astrbot")
# socket or webhook
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
if not self.bot_name: if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event) await self.convert_msg(event)
@@ -57,6 +70,8 @@ class LarkPlatformAdapter(Platform):
.build() .build()
) )
self.do_v2_msg_event = do_v2_msg_event
self.client = lark.ws.Client( self.client = lark.ws.Client(
app_id=self.appid, app_id=self.appid,
app_secret=self.appsecret, app_secret=self.appsecret,
@@ -69,11 +84,48 @@ class LarkPlatformAdapter(Platform):
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build() lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
) )
self.webhook_server = None
if self.connection_mode == "webhook":
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
self.webhook_server.set_callback(self.handle_webhook_event)
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
event_id
for event_id, timestamp in self.event_id_timestamps.items()
if current_time - timestamp > 1800
]
for event_id in expired_keys:
del self.event_id_timestamps[event_id]
def _is_duplicate_event(self, event_id: str) -> bool:
"""检查事件是否重复
Args:
event_id: 事件ID
Returns:
True 表示重复事件False 表示新事件
"""
self._clean_expired_events()
if event_id in self.event_id_timestamps:
return True
self.event_id_timestamps[event_id] = time.time()
return False
async def send_by_session( async def send_by_session(
self, self,
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = { wrapped = {
"zh_cn": { "zh_cn": {
@@ -114,14 +166,25 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="lark", name="lark",
description="飞书机器人官方 API 适配器", description="飞书机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage() abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
if message.create_time:
abm.timestamp = int(message.create_time) // 1000
else:
abm.timestamp = int(time.time())
abm.message = [] abm.message = []
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
@@ -136,14 +199,28 @@ class LarkPlatformAdapter(Platform):
at_list = {} at_list = {}
if message.mentions: if message.mentions:
for m in message.mentions: for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.id is None:
continue
# 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
if m.name == self.bot_name: if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content) content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text": if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息 message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw) # at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分 # 拆分文本,去掉AT符号部分
@@ -168,27 +245,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls content_json_b = _ls
elif message.message_type == "image": elif message.message_type == "image":
content_json_b = [ content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}, {
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
] ]
if message.message_type in ("post", "image"): if message.message_type in ("post", "image"):
for comp in content_json_b: for comp in content_json_b:
if comp["tag"] == "at": if comp.get("tag") == "at":
abm.message.append(at_list[comp["user_id"]]) user_id = comp.get("user_id")
elif comp["tag"] == "text" and comp["text"].strip(): if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip())) abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img": elif comp.get("tag") == "img":
image_key = comp["image_key"] image_key = comp.get("image_key")
if not image_key:
continue
request = ( request = (
GetMessageResourceRequest.builder() GetMessageResourceRequest.builder()
.message_id(message.message_id) .message_id(cast(str, message.message_id))
.file_key(image_key) .file_key(image_key)
.type("image") .type("image")
.build() .build()
) )
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request) response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success(): if not response.success():
logger.error(f"无法下载飞书图片: {image_key}") logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read() image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode() image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64)) abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -196,6 +293,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message: for comp in abm.message:
if isinstance(comp, Comp.Plain): if isinstance(comp, Comp.Plain):
abm.message_str += comp.text abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id abm.message_id = message.message_id
abm.raw_message = message abm.raw_message = message
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -227,13 +337,61 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event) self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件
Args:
event_data: Webhook 事件数据
"""
try:
header = event_data.get("header", {})
event_id = header.get("event_id", "")
if event_id and self._is_duplicate_event(event_id):
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
return
event_type = header.get("event_type", "")
if event_type == "im.message.receive_v1":
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
data = (processor.type())(event_data)
processor.do(data)
else:
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self): async def run(self):
# self.client.start() if self.connection_mode == "webhook":
# Webhook 模式
if self.webhook_server is None:
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
return
webhook_uuid = self.config.get("webhook_uuid")
if webhook_uuid:
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
else:
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
else:
# 长连接模式
await self.client._connect() await self.client._connect()
async def terminate(self): async def webhook_callback(self, request: Any) -> Any:
await self.client._disconnect() """统一 Webhook 回调入口"""
logger.info("飞书(Lark) 适配器已被优雅地关闭") if not self.webhook_server:
return {"error": "Webhook server not initialized"}, 500
def get_client(self) -> lark.Client: return await self.webhook_server.handle_callback(request)
async def terminate(self):
if self.connection_mode == "socket":
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
def get_client(self) -> lark.ws.Client:
return self.client return self.client
def unified_webhook(self) -> bool:
return bool(
self.config.get("lark_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO from io import BytesIO
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "") file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"): elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file) image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"): elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://") base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue()) f.write(BytesIO(image_data).getvalue())
else: else:
file_path = comp.file file_path = comp.file if comp.file else ""
if image_file is None: if image_file is None:
if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb") image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = ( request = (
CreateImageRequest.builder() CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request) response = await lark_client.im.v1.image.acreate(request)
if not response.success(): if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}") logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key image_key = response.data.image_key
logger.debug(image_key) logger.debug(image_key)
ret.append(_stage) ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build() .build()
) )
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request) response = await self.bot.im.v1.message.areply(request)
if not response.success(): if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def react(self, emoji: str): async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = ( request = (
CreateMessageReactionRequest.builder() CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id) .message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
response = await self.bot.im.v1.message_reaction.acreate(request) response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success(): if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -0,0 +1,206 @@
"""飞书(Lark) Webhook 服务器实现
实现飞书事件订阅的 Webhook 模式支持:
1. 请求 URL 验证 (challenge 验证)
2. 事件加密/解密 (AES-256-CBC)
3. 签名校验 (SHA256)
4. 事件接收和处理
"""
import asyncio
import base64
import hashlib
import json
from collections.abc import Awaitable, Callable
from Crypto.Cipher import AES
from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode("utf8")
class LarkWebhookServer:
"""飞书 Webhook 服务器
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
"""初始化 Webhook 服务器
Args:
config: 飞书配置
event_queue: 事件队列
"""
self.app_id = config["app_id"]
self.app_secret = config["app_secret"]
self.encrypt_key = config.get("lark_encrypt_key", "")
self.verification_token = config.get("lark_verification_token", "")
self.event_queue = event_queue
self.callback: Callable[[dict], Awaitable[None]] | None = None
# 初始化加密工具
self.cipher = None
if self.encrypt_key:
self.cipher = AESCipher(self.encrypt_key)
def verify_signature(
self,
timestamp: str,
nonce: str,
encrypt_key: str,
body: bytes,
signature: str,
) -> bool:
"""验证签名
Args:
timestamp: 请求时间戳
nonce: 随机数
encrypt_key: 加密密钥
body: 请求体
signature: 签名
Returns:
签名是否有效
"""
# 拼接字符串: timestamp + nonce + encrypt_key + body
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
bytes_b = bytes_b1 + body
h = hashlib.sha256(bytes_b)
calculated_signature = h.hexdigest()
return calculated_signature == signature
def decrypt_event(self, encrypted_data: str) -> dict:
"""解密事件数据
Args:
encrypted_data: 加密的事件数据
Returns:
解密后的事件字典
"""
if not self.cipher:
raise ValueError("未配置 encrypt_key,无法解密事件")
decrypted_str = self.cipher.decrypt_string(encrypted_data)
return json.loads(decrypted_str)
async def handle_challenge(self, event_data: dict) -> dict:
"""处理 challenge 验证请求
Args:
event_data: 事件数据
Returns:
包含 challenge 的响应
"""
challenge = event_data.get("challenge", "")
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
return {"challenge": challenge}
async def handle_callback(self, request) -> tuple[dict, int] | dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应数据
"""
# 获取原始请求体
body = await request.get_data()
try:
event_data = await request.json
except Exception as e:
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
return {"error": "Invalid JSON"}, 400
if not event_data:
logger.error("[Lark Webhook] 请求体为空")
return {"error": "Empty request body"}, 400
# 如果配置了 encrypt_key,进行签名验证
if self.encrypt_key:
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
nonce = request.headers.get("X-Lark-Request-Nonce", "")
signature = request.headers.get("X-Lark-Signature", "")
if timestamp and nonce and signature:
if not self.verify_signature(
timestamp, nonce, self.encrypt_key, body, signature
):
logger.error("[Lark Webhook] 签名验证失败")
return {"error": "Invalid signature"}, 401
# 检查是否是加密事件
if "encrypt" in event_data:
try:
event_data = self.decrypt_event(event_data["encrypt"])
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
except Exception as e:
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
return {"error": "Decryption failed"}, 400
# 验证 token
if self.verification_token:
header = event_data.get("header", {})
if header:
token = header.get("token", "")
else:
token = event_data.get("token", "")
if token != self.verification_token:
logger.error("[Lark Webhook] Verification Token 不匹配。")
return {"error": "Invalid verification token"}, 401
# 处理 URL 验证 (challenge)
if event_data.get("type") == "url_verification":
return await self.handle_challenge(event_data)
# 调用回调函数处理事件
if self.callback:
try:
await self.callback(event_data)
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
return {"error": "Event processing failed"}, 500
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
"""设置事件回调函数
Args:
callback: 处理事件的异步函数
"""
self.callback = callback
@@ -1,7 +1,6 @@
import asyncio import asyncio
import os import os
import random import random
from collections.abc import Awaitable
from typing import Any from typing import Any
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict): if not isinstance(message.raw_message, dict):
message.raw_message = {} message.raw_message = {}
message.raw_message["poll"] = poll message.raw_message["poll"] = poll
message.poll = poll message.__setattr__("poll", poll)
except Exception: except Exception:
pass pass
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self, self,
session: MessageSession, session: MessageSession,
message_chain: MessageChain, message_chain: MessageChain,
) -> Awaitable[Any]: ) -> None:
if not self.api: if not self.api:
logger.error("[Misskey] API 客户端未初始化") logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain) return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os import os
import random import random
import uuid import uuid
from typing import cast
import aiofiles import aiofiles
import botpy import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval: if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload) ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1 stream_payload["index"] += 1
stream_payload["id"] = ret["id"] stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time() last_edit_time = asyncio.get_event_loop().time()
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None return None
source = self.message_obj.raw_message source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source, source,
( (
botpy.message.Message, botpy.message.Message,
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage, botpy.message.DirectMessage,
botpy.message.C2CMessage, botpy.message.C2CMessage,
), ),
) ):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
( (
plain_text, plain_text,
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
): ):
return None return None
payload = { payload: dict = {
"content": plain_text, "content": plain_text,
"msg_id": self.message_obj.message_id, "msg_id": self.message_obj.message_id,
} }
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None ret = None
match type(source): match source:
case botpy.message.GroupMessage: case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid, group_openid=source.group_openid,
**payload, **payload,
) )
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload, **payload,
) )
logger.debug(f"Message sent to C2C: {ret}") logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_message( ret = await self.bot.api.post_message(
channel_id=source.channel_id, channel_id=source.channel_id,
**payload, **payload,
) )
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer) await super().send(self.send_buffer)
self.send_buffer = None self.send_buffer = None
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type, "file_type": file_type,
"srv_send_msg": False, "srv_send_msg": False,
} }
result = None
if "openid" in kwargs: if "openid" in kwargs:
payload["openid"] = kwargs["openid"] payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs: elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"] payload["group_openid"] = kwargs["group_openid"]
route = Route( route = Route(
"POST", "POST",
"/v2/groups/{group_openid}/files", "/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"], group_openid=kwargs["group_openid"],
) )
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record( async def upload_group_and_c2c_record(
self, self,
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if result: if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media( return Media(
file_uuid=result.get("file_uuid"), file_uuid=result["file_uuid"],
file_info=result.get("file_info"), file_info=result["file_info"],
ttl=result.get("ttl", 0), ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
) )
except Exception as e: except Exception as e:
logger.error(f"上传请求错误: {e}") logger.error(f"上传请求错误: {e}")
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None, message_reference: message.Reference | None = None,
media: message.Media | None = None, media: message.Media | None = None,
msg_id: str | None = None, msg_id: str | None = None,
msg_seq: str = 1, msg_seq: int | None = 1,
event_id: str | None = None, event_id: str | None = None,
markdown: message.MarkdownPayload | None = None, markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None, keyboard: message.Keyboard | None = None,
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals() payload = locals()
payload.pop("self", None) payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid) route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod @staticmethod
async def _parse_to_qqofficial(message: MessageChain): async def _parse_to_qqofficial(message: MessageChain):
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path) image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"): elif i.file and i.file.startswith("base64://"):
image_base64 = i.file image_base64 = i.file
else: elif i.file:
image_base64 = file_to_base64(i.file) image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://") image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record): elif isinstance(i, Record):
if i.file: if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
import time import time
from typing import cast
import botpy import botpy
import botpy.message import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"] self.appid = platform_config["appid"]
self.secret = platform_config["secret"] self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"] self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"] qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"] guild_dm = platform_config["enable_guild_direct_message"]
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official", name="qq_official",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
@staticmethod @staticmethod
def _parse_from_qqofficial( def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage, message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType, message_type: MessageType,
): ):
abm = AstrBotMessage() abm = AstrBotMessage()
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.raw_message = message abm.raw_message = message
abm.message_id = message.id abm.message_id = message.id
abm.tag = "qq_official" # abm.tag = "qq_official"
msg: list[BaseMessageComponent] = [] msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance( if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message, message,
botpy.message.DirectMessage, botpy.message.DirectMessage,
): ):
try: if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id) abm.self_id = str(message.mentions[0].id)
except BaseException as _: else:
abm.self_id = "" abm.self_id = ""
plain_content = message.content.replace( plain_content = message.content.replace(
@@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Any from typing import Any, cast
import botpy import botpy
import botpy.message import botpy.message
@@ -36,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official_webhook", name="qq_official_webhook",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
async def run(self): async def run(self):
@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import cast
import quart import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13: if opcode == 13:
# validation # validation
signed = await self.webhook_validation(data) signed = await self.webhook_validation(cast(dict, data))
print(signed) print(signed)
return signed return signed
@@ -4,9 +4,11 @@ import hmac
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -66,7 +68,7 @@ class SlackWebhookClient:
""" """
try: try:
# 获取请求体和头部 # 获取请求体和头部
body = await req.get_data() body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8")) event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature # Verify Slack request signature
@@ -139,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler self.event_handler = event_handler
self.socket_client = None self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest): async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件""" """处理 Socket Mode 事件"""
try: try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件 # 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id) response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response) await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re import re
import time import time
import uuid import uuid
from collections.abc import Awaitable from typing import Any, cast
from typing import Any
import aiohttp import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="slack", name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}") logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = self.bot_self_id abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息 # 获取用户信息
user_id = event.get("user", "") user_id = event.get("user", "")
try: try:
user_info = await self.web_client.users_info(user=user_id) user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id) user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception: except Exception:
user_name = user_id user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "") channel_id = event.get("channel", "")
try: try:
channel_info = await self.web_client.conversations_info(channel=channel_id) channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"] is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im: if is_im:
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions: for mention in mentions:
try: try:
mentioned_user = await self.web_client.users_info(user=mention) mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"] user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get( user_name = user_data.get("real_name") or user_data.get(
"name", "name",
mention, mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
) )
raise Exception(f"下载文件失败: {resp.status}") raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]: async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id() self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}") logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -410,7 +409,7 @@ class SlackAdapter(Platform):
await self.socket_client.stop() await self.socket_client.stop()
if self.webhook_client: if self.webhook_client:
await self.webhook_client.stop() await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭") logger.info("Slack 适配器已被关闭")
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
@@ -428,3 +427,10 @@ class SlackAdapter(Platform):
def get_client(self): def get_client(self):
return self.web_client return self.web_client
def unified_webhook(self) -> bool:
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("slack_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -1,6 +1,7 @@
import asyncio import asyncio
import re import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
if isinstance(segment, Image): if isinstance(segment, Image):
# upload file # upload file
url = segment.url or segment.file url = segment.url or segment.file
if url.startswith("http"): if url and url.startswith("http"):
return { return {
"type": "image", "type": "image",
"image_url": url, "image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"}, "text": {"type": "mrkdwn", "text": "图片上传失败"},
} }
image_url = response["files"][0]["url_private"] image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}") logger.debug(f"Slack file upload response: {response}")
return { return {
"type": "image", "type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"}, "text": {"type": "mrkdwn", "text": "文件上传失败"},
} }
file_url = response["files"][0]["permalink"] file_url = cast(list, response["files"])[0]["permalink"]
return { return {
"type": "section", "type": "section",
"text": { "text": {
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
) )
members = [] members = []
for member_id in members_response["members"]: for member_id in cast(Iterable, members_response["members"]):
try: try:
user_info = await self.web_client.users_info(user=member_id) user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
members.append( members.append(
MessageMember( MessageMember(
user_id=member_id, user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息 # 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id)) members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"] channel_data = cast(dict, channel_info["channel"])
return Group( return Group(
group_id=channel_id, group_id=channel_id,
group_name=channel_data.get("name", ""), group_name=channel_data.get("name", ""),
@@ -424,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
if self.application.updater is not None: if self.application.updater is not None:
await self.application.updater.stop() await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭") logger.info("Telegram 适配器已被关闭")
except Exception as e: except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}") logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import re import re
from typing import Any, cast
import telegramify_markdown import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply, Reply,
) )
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent): class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name, "chat_id": user_name,
} }
if has_reply: if has_reply:
payload["reply_to_message_id"] = reply_message_id payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id: if message_thread_id:
payload["message_thread_id"] = message_thread_id payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
md_text = telegramify_markdown.markdownify( md_text = telegramify_markdown.markdownify(
chunk, chunk,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await client.send_message( await client.send_message(
text=md_text, text=md_text,
parse_mode="MarkdownV2", parse_mode="MarkdownV2",
**payload, **cast(Any, payload),
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.", f"MarkdownV2 send failed: {e}. Using plain text instead.",
) )
await client.send_message(text=chunk, **payload) await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload) await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name) await client.send_document(
await download_file(i.file, path) document=path, filename=name, **cast(Any, payload)
i.file = path )
await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload) await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE: if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text delta += i.text
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload) await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue continue
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await self.client.send_document( await self.client.send_document(
document=i.file, document=path,
filename=i.name, filename=name,
**payload, **cast(Any, payload),
) )
continue continue
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload) await self.client.send_voice(voice=path, **cast(Any, payload))
continue continue
else: else:
logger.warning(f"不支持的消息类型: {type(i)}") logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else: else:
# delta 长度一般不会大于 4096,因此这里直接发送 # delta 长度一般不会大于 4096,因此这里直接发送
try: try:
msg = await self.client.send_message(text=delta, **payload) msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta current_content = delta
except Exception as e: except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}") logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
markdown_text = telegramify_markdown.markdownify( markdown_text = telegramify_markdown.markdownify(
delta, delta,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await self.client.edit_message_text( await self.client.edit_message_text(
@@ -2,7 +2,7 @@ import asyncio
import os import os
import time import time
import uuid import uuid
from collections.abc import Awaitable, Callable from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
from astrbot import logger from astrbot import logger
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
abm.raw_message = data abm.raw_message = data
return abm return abm
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple): async def callback(data: tuple):
abm = await self.convert_message(data) abm = await self.convert_message(data)
await self.handle_msg(abm) await self.handle_msg(abm)
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id) await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
final_data = "" final_data = ""
@@ -4,6 +4,7 @@ import json
import os import os
import time import time
import traceback import traceback
from typing import cast
import aiohttp import aiohttp
import anyio import anyio
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
) )
self.base_url = f"http://{self.host}:{self.port}" self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码 self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join( self.credentials_file = os.path.join(
get_astrbot_data_path(), get_astrbot_data_path(),
"wechatpadpro_credentials.json", "wechatpadpro_credentials.json",
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
) )
await asyncio.sleep(5) await asyncio.sleep(5)
async def handle_websocket_message(self, message: str): async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。""" """处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}") logger.debug(f"收到 WebSocket 消息: {message}")
try: try:
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" """将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = raw_message abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id")) abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time") abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180: if int(time.time()) - abm.timestamp > 180:
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "") content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "") push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type") msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = "" abm.message_str = ""
abm.message = [] abm.message = []
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str, from_user_name: str,
to_user_name: str, to_user_name: str,
msg_id: int, msg_id: int,
): ) -> dict | None:
"""下载原始图片。""" """下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg" url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key} params = {"key": self.auth_key}
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息 # 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "") from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id") msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image( image_resp = await self._download_raw_image(
from_user_name, from_user_name,
to_user_name, to_user_name,
msg_id, msg_id,
) )
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = ( image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
) )
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0 bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id") new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser( data_parser = GeweDataParser(
content=content, content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
) )
voicemsg = data_parser._format_to_xml().find("voicemsg") voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0" bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0) length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice( voice_resp = await self.download_voice(
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid, bufid=bufid,
length=length, length=length,
) )
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data: if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data) voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -827,6 +843,7 @@ class WeChatPadProAdapter(Platform):
try: try:
if self.ws_handle_task: if self.ws_handle_task:
self.ws_handle_task.cancel() self.ws_handle_task.cancel()
if self._shutdown_event is not None:
self._shutdown_event.set() self._shutdown_event.set()
except Exception: except Exception:
pass pass
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list( async def get_contact_details_list(
self, self,
room_wx_id_list: list[str] = None, room_wx_id_list: list[str] | None = None,
user_names: list[str] = None, user_names: list[str] | None = None,
) -> dict | None: ) -> dict | None:
"""获取联系人详情列表。""" """获取联系人详情列表。"""
if room_wx_id_list is None: if room_wx_id_list is None:
@@ -2,7 +2,8 @@ import asyncio
import os import os
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -40,7 +41,7 @@ else:
class WecomServer: class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule( self.server.add_url_rule(
"/callback/command", "/callback/command",
@@ -60,7 +61,7 @@ class WecomServer:
config["corpid"].strip(), config["corpid"].strip(),
) )
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,7 +115,7 @@ class WecomServer:
logger.error("解密失败,签名异常,请检查配置。") logger.error("解密失败,签名异常,请检查配置。")
raise raise
else: else:
msg = parse_message(xml) msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject # inject
self.wechat_kf_api = WeChatKF(client=self.client) self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client) self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api self.client.__setattr__("kf", self.wechat_kf_api)
self.client.kf_message = self.wechat_kf_message_api self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage): async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if msg.type == "text": if isinstance(msg, TextMessage):
assert isinstance(msg, TextMessage)
abm.message_str = msg.content abm.message_str = msg.content
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)] abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "image": elif isinstance(msg, ImageMessage):
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "voice": elif isinstance(msg, VoiceMessage):
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
else: else:
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype") msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid") external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = msg abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -425,4 +422,4 @@ class WecomPlatformAdapter(Platform):
await self.server.server.shutdown() await self.server.server.shutdown()
except Exception as _: except Exception as _:
pass pass
logger.info("企业微信 适配器已被优雅地关闭") logger.info("企业微信 适配器已被关闭")
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf: if is_wechat_kf:
# 微信客服 # 微信客服
kf_message_api = getattr(self.client, "kf_message", None) kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api: if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。") logger.warning("未找到微信客服发送消息方法。")
return return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id() user_id = self.get_sender_id()
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod @staticmethod
async def _send( async def _send(
message_chain: MessageChain, message_chain: MessageChain | None,
stream_id: str, stream_id: str,
queue_mgr: WecomAIQueueMgr, queue_mgr: WecomAIQueueMgr,
streaming: bool = False, streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
"""发送消息""" """发送消息"""
raw = self.message_obj.raw_message raw = self.message_obj.raw_message
assert isinstance(raw, dict), ( assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
) )
stream_id = raw.get("stream_id", self.session_id) stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False): async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计""" """流式发送消息,参考webchat的send_streaming设计"""
@@ -1,7 +1,8 @@
import asyncio import asyncio
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -36,7 +37,7 @@ else:
class WeixinOfficialAccountServer: class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token") self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key") self.encoding_aes_key = config.get("encoding_aes_key")
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue self.event_queue = event_queue
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
raise raise
else: else:
msg = parse_message(xml) msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.config["secret"].strip(), self.config["secret"].strip(),
) )
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重 # 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future # msgid -> Future
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None) await self.convert_message(msg, None)
else: else:
if msg.id in self.wexin_event_workers: if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id] future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}") logger.debug(f"duplicate message id checked: {msg.id}")
else: else:
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
result = await asyncio.wait_for( result = await asyncio.wait_for(
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60, 60,
) # wait for 60s ) # wait for 60s
logger.debug(f"Got future result: {result}") logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None) self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def convert_message( async def convert_message(
self, self,
msg, msg,
future: asyncio.Future = None, future: asyncio.Future | None = None,
) -> AstrBotMessage | None: ) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if isinstance(msg, TextMessage): if isinstance(msg, TextMessage):
abm.message_str = msg.content abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)] abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "image": elif msg.type == "image":
assert isinstance(msg, ImageMessage) assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "voice": elif msg.type == "voice":
assert isinstance(msg, VoiceMessage) assert isinstance(msg, VoiceMessage)
@@ -306,14 +310,15 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
else: else:
logger.warning(f"暂未实现的事件: {msg.type}") logger.warning(f"暂未实现的事件: {msg.type}")
if future:
future.set_result(None) future.set_result(None)
return return
# 很不优雅 :( # 很不优雅 :(
@@ -344,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.server.server.shutdown() await self.server.server.shutdown()
except Exception as _: except Exception as _:
pass pass
logger.info("微信公众平台 适配器已被优雅地关闭") logger.info("微信公众平台 适配器已被关闭")
@@ -1,5 +1,6 @@
import asyncio import asyncio
import uuid import uuid
from typing import cast
from wechatpy import WeChatClient from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
message_obj = self.message_obj message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False) active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
# Split long text messages if needed # Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = TextReply( reply = TextReply(
content=chunk, content=chunk,
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = ImageReply( reply = ImageReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = VoiceReply( reply = VoiceReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy import copy
import json import json
import os import os
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import aiohttp import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list[dict], func_args: list[dict],
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool: ) -> FuncTool:
params = { params = {
"type": "object", # hard-coded here "type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list, func_args: list,
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None: ) -> None:
"""添加函数调用工具 """添加函数调用工具
+77 -21
View File
@@ -1,5 +1,6 @@
import asyncio import asyncio
import traceback import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +11,7 @@ from .entities import ProviderType
from .provider import ( from .provider import (
EmbeddingProvider, EmbeddingProvider,
Provider, Provider,
Providers,
RerankProvider, RerankProvider,
STTProvider, STTProvider,
TTSProvider, TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager: class ProviderManager:
def __init__( def __init__(
self, self,
@@ -48,7 +55,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例""" """加载的 Rerank Provider 的实例"""
self.inst_map: dict[ self.inst_map: dict[
str, str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, Providers,
] = {} ] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例""" """Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools self.llm_tools = llm_tools
@@ -123,15 +130,13 @@ class ProviderManager:
self.curr_provider_inst = prov self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global") sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None: async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例""" """根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id) return self.inst_map.get(provider_id)
def get_using_provider( def get_using_provider(
self, self, provider_type: ProviderType, umo=None
provider_type: ProviderType, ) -> Providers | None:
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。 """获取正在使用的提供商实例。
Args: Args:
@@ -191,7 +196,6 @@ class ProviderManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(e) logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get( selected_provider_id = sp.get(
"curr_provider", "curr_provider",
self.provider_settings.get("default_provider_id"), self.provider_settings.get("default_provider_id"),
@@ -210,15 +214,37 @@ class ProviderManager:
scope="global", scope="global",
scope_id="global", scope_id="global",
) )
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts: if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts: if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts: if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst = self.tts_provider_insts[0]
@@ -358,11 +384,16 @@ class ProviderManager:
provider_metadata.id = provider_config["id"] provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: match provider_metadata.provider_type:
case ProviderType.SPEECH_TO_TEXT:
# STT 任务 # STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings) inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.stt_provider_insts.append(inst) self.stt_provider_insts.append(inst)
@@ -377,15 +408,22 @@ class ProviderManager:
if not self.curr_stt_provider_inst: if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: case ProviderType.TEXT_TO_SPEECH:
# TTS 任务 # TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings) inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.tts_provider_insts.append(inst) self.tts_provider_insts.append(inst)
if self.provider_settings.get("provider_id") == provider_config["id"]: if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst self.curr_tts_provider_inst = inst
logger.info( logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
@@ -393,14 +431,18 @@ class ProviderManager:
if not self.curr_tts_provider_inst: if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: case ProviderType.CHAT_COMPLETION:
# 文本生成任务 # 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type( inst = cls_type(
provider_config, provider_config,
self.provider_settings, self.provider_settings,
) )
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.provider_insts.append(inst) self.provider_insts.append(inst)
@@ -415,16 +457,30 @@ class ProviderManager:
if not self.curr_provider_inst: if not self.curr_provider_inst:
self.curr_provider_inst = inst self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING: case ProviderType.EMBEDDING:
if not issubclass(cls_type, EmbeddingProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
)
inst = cls_type(provider_config, self.provider_settings) inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.embedding_provider_insts.append(inst) self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK: case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings) inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.rerank_provider_insts.append(inst) self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
)
self.inst_map[provider_config["id"]] = inst self.inst_map[provider_config["id"]] = inst
except Exception as e: except Exception as e:
+12 -1
View File
@@ -2,6 +2,7 @@ import abc
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
"""Provider Abstract Class""" """Provider Abstract Class"""
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误 - 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
""" """
... if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list): async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录""" """弹出 context 第一条非系统提示词对话记录"""
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0 self.last_sync_time = 0
self.timeout = Timeout(10.0) self.timeout = Timeout(10.0)
self.retry_count = 3 self.retry_count = 3
self.client = None self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout) self._client = AsyncClient(timeout=self.timeout)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _sync_time(self): async def _sync_time(self):
try: try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1: if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider): class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = ( self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
) )
self.client = None self._client: AsyncClient | None = None
self.token = None self.token = None
self.token_expire = 0 self.token_expire = 0
self.voice_params = { self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"), "volume": provider_config.get("azure_tts_volume", "100"),
} }
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient( self._client = AsyncClient(
headers={ headers={
"User-Agent": f"AstrBot/{VERSION}", "User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml", "Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _refresh_token(self): async def _refresh_token(self):
token_url = ( token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "") key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config) self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider: def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["): if key_value.lower().startswith("other["):
json_str = ""
try: try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL) match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match: if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns: Returns:
重排序结果列表 重排序结果列表
""" """
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents: if not documents:
logger.warning("文档列表为空,返回空结果") logger.warning("文档列表为空,返回空结果")
return [] return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "") self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = { kwargs = {
"model": model, "model": model,
"text": text, "messages": None,
"api_key": self.chosen_api_key, "api_key": self.chosen_api_key,
"voice": self.voice or "Cherry", "voice": self.voice or "Cherry",
"text": text,
} }
if not self.voice: if not self.voice:
logging.warning( logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg from pyffmpeg import FFmpeg
ff = FFmpeg() ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path) ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e: except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line # use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = { self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}", "Authorization": f"Bearer {self.chosen_api_key}",
} }
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_reference_id_by_character(self, character: str) -> str: async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id """获取角色的reference_id
Args: Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$" pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip())) return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict: async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询 # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip(): if self.reference_id and self.reference_id.strip():
# 验证reference_id格式 # 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes(): async for chunk in response.aiter_bytes():
f.write(chunk) f.write(chunk)
return path return path
text = await response.aread() body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}") raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
from google.genai.errors import APIError from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config self.provider_config = provider_config
self.provider_settings = provider_settings self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key") api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config.get("embedding_api_base") api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20)) timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000) http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model, model=self.model,
contents=text, contents=text,
) )
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return result.embeddings[0].values return result.embeddings[0].values
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}") raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
try: try:
result = await self.client.models.embed_content( result = await self.client.models.embed_content(
model=self.model, model=self.model,
contents=texts, contents=cast(types.ContentListUnion, text),
) )
return [embedding.values for embedding in result.embeddings] assert result.embeddings is not None
embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
embeddings.append(embedding.values)
return embeddings
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
+12 -10
View File
@@ -4,6 +4,7 @@ import json
import logging import logging
import random import random
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
@@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider):
) -> types.GenerateContentConfig: ) -> types.GenerateContentConfig:
"""准备查询配置""" """准备查询配置"""
if not modalities: if not modalities:
modalities = ["Text"] modalities = ["TEXT"]
# 流式输出不支持图片模态 # 流式输出不支持图片模态
if ( if (
self.provider_settings.get("streaming_response", False) self.provider_settings.get("streaming_response", False)
and "Image" in modalities and "IMAGE" in modalities
): ):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态") logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"] modalities = ["TEXT"]
tool_list = [] tool_list: list[types.Tool] | None = []
model_name = self.get_model() model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False) native_search = self.provider_config.get("gm_native_search", False)
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"), logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"), seed=payloads.get("seed"),
response_modalities=modalities, response_modalities=modalities,
tools=tool_list, tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None, safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=( thinking_config=(
types.ThinkingConfig( types.ThinkingConfig(
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content], content_cls: type[types.Content],
) -> None: ) -> None:
if contents and isinstance(contents[-1], content_cls): if contents and isinstance(contents[-1], content_cls):
assert contents[-1].parts is not None
contents[-1].parts.extend(part) contents[-1].parts.extend(part)
else: else:
contents.append(content_cls(parts=part)) contents.append(content_cls(parts=part))
@@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider):
None, None,
) )
modalities = ["Text"] modalities = ["TEXT"]
if self.provider_config.get("gm_resp_image_modal", False): if self.provider_config.get("gm_resp_image_modal", False):
modalities.append("Image") modalities.append("IMAGE")
conversation = self._prepare_conversation(payloads) conversation = self._prepare_conversation(payloads)
temperature = payloads.get("temperature", 0.7) temperature = payloads.get("temperature", 0.7)
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
) )
result = await self.client.models.generate_content( result = await self.client.models.generate_content(
model=self.get_model(), model=self.get_model(),
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
logger.debug(f"genai result: {result}") logger.debug(f"genai result: {result}")
@@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider):
logger.warning( logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态", f"{self.get_model()} 不支持多模态输出,降级为文本模态",
) )
modalities = ["Text"] modalities = ["TEXT"]
else: else:
raise raise
continue continue
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
) )
result = await self.client.models.generate_content_stream( result = await self.client.models.generate_content_stream(
model=self.get_model(), model=self.get_model(),
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
break break
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body) return json.dumps(dict_body)
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求""" """进行流式请求"""
try: try:
async with ( async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:]) data = json.loads(message[6:])
if "extra_info" in data: if "extra_info" in data:
continue continue
audio = data.get("data", {}).get("audio") audio: str | None = data.get("data", {}).get(
"audio"
)
if audio is not None: if audio is not None:
yield audio yield audio
except json.JSONDecodeError: except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model) embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model) embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data] return [item.embedding for item in embeddings.data]
def get_dim(self) -> int: def get_dim(self) -> int:
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(tool_call, str): if isinstance(tool_call, str):
# workaround for #1359 # workaround for #1359
tool_call = json.loads(tool_call) tool_call = json.loads(tool_call)
if tools is None:
# 工具集未提供
# Should be unreachable
raise Exception("工具集未提供")
for tool in tools.func_list: for tool in tools.func_list:
if ( if (
tool_call.type == "function" tool_call.type == "function"
@@ -7,6 +7,7 @@ import asyncio
import os import os
import re import re
from datetime import datetime from datetime import datetime
from typing import cast
from funasr_onnx import SenseVoiceSmall from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("stt_model")) self.set_model(provider_config["stt_model"])
self.model = None self.model = None
self.is_emotion = provider_config.get("is_emotion", False) self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
res = await loop.run_in_executor( res = await loop.run_in_executor(
None, # 使用默认的线程池 None, # 使用默认的线程池
lambda: self.model(audio_url, language="auto", use_itn=True), lambda: cast(SenseVoiceSmall, self.model)(
audio_url, language="auto", use_itn=True
),
) )
# res = self.model(audio_url, language="auto", use_itn=True) # res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
} }
if top_n is not None: if top_n is not None:
payload["top_n"] = top_n payload["top_n"] = top_n
assert self.client is not None
async with self.client.post( async with self.client.post(
f"{self.base_url}/v1/rerank", f"{self.base_url}/v1/rerank",
json=payload, json=payload,
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN), timeout=provider_config.get("timeout", NOT_GIVEN),
) )
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_audio_format(self, file_path): async def _get_audio_format(self, file_path):
# 定义要检测的头部字节 # 定义要检测的头部字节
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import uuid import uuid
from typing import cast
import whisper import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.model = None self.model = None
async def initialize(self): async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path) await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path audio_url = output_path
if not self.model:
raise RuntimeError("Whisper 模型未初始化")
result = await loop.run_in_executor(None, self.model.transcribe, audio_url) result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result["text"] return cast(str, result["text"])
@@ -1,6 +1,11 @@
from typing import cast
from xinference_client.client.restful.async_restful_client import ( from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client, AsyncClient as Client,
) )
from xinference_client.client.restful.async_restful_client import (
AsyncRESTfulRerankModelHandle,
)
from astrbot import logger from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False, False,
) )
self.client = None self.client = None
self.model = None self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None self.model_uid = None
async def initialize(self): async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return return
if self.model_uid: if self.model_uid:
self.model = await self.client.get_model(self.model_uid) self.model = cast(
AsyncRESTfulRerankModelHandle,
await self.client.get_model(self.model_uid),
)
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}") logger.error(f"Failed to initialize Xinference model: {e}")
+2 -2
View File
@@ -285,7 +285,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。""" """获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None: def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args: Args:
@@ -296,7 +296,7 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION, provider_type=ProviderType.CHAT_COMPLETION,
umo=umo, umo=umo,
) )
if prov and not isinstance(prov, Provider): if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型") raise ValueError("返回的 Provider 不是 Provider 类型")
return prov return prov
+22 -5
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import docstring_parser import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名""" """获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}" return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create( def get_handler_or_create(
handler: Callable[..., Awaitable[Any]], handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType, event_type: EventType,
dont_add=False, dont_add=False,
**kwargs, **kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for ( for (
sub_handle sub_handle
) in parent_register_commandable.parent_group.sub_command_filters: ) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。 # 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md() sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else: else:
# 裸指令 # 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create( handler_md = get_handler_or_create(
awaitable, awaitable,
EventType.AdapterMessageEvent, EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter): def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"): if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"] registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__ llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or "" func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc) docstring = docstring_parser.parse(func_doc)
+85 -4
View File
@@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter from .filter import HandlerFilter
from .star import star_map from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers: for handler in self._handlers:
print(handler.handler_full_name) print(handler.handler_full_name)
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAstrBotLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPlatformLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.AdapterMessageEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMRequestEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMResponseEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnDecoratingResultEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnCallingFuncToolEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAfterMessageSentEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
def get_handlers_by_event_type( def get_handlers_by_event_type(
self, self,
event_type: EventType, event_type: EventType,
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后 OnAfterMessageSentEvent = enum.auto() # 发送消息后
H = TypeVar("H", bound=Callable[..., Any])
@dataclass @dataclass
class StarHandlerMetadata: class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。""" """描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType event_type: EventType
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
handler_module_path: str handler_module_path: str
"""Handler 所在的模块路径。""" """Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]] handler: H
"""Handler 的函数对象,应当是一个异步函数""" """Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter] event_filters: list[HandlerFilter]
+3 -3
View File
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
async def check_update( async def check_update(
self, self,
url: str, url: str | None,
current_version: str, current_version: str | None,
consider_prerelease: bool = True, consider_prerelease: bool = True,
) -> ReleaseInfo: ) -> ReleaseInfo | None:
"""检查更新""" """检查更新"""
return await super().check_update( return await super().check_update(
self.ASTRBOT_RELEASE_API, self.ASTRBOT_RELEASE_API,
+1 -1
View File
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
return False return False
def save_temp_img(img: Image.Image | str) -> str: def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的 # 获得文件创建时间,清除超过 12 小时的
try: try:
+17 -10
View File
@@ -20,16 +20,16 @@ class SessionController:
def __init__(self): def __init__(self):
self.future = asyncio.Future() self.future = asyncio.Future()
self.current_event: asyncio.Event = None self.current_event: asyncio.Event | None = None
"""当前正在等待的所用的异步事件""" """当前正在等待的所用的异步事件"""
self.ts: float = None self.ts: float | None = None
"""上次保持(keep)开始时的时间""" """上次保持(keep)开始时的时间"""
self.timeout: float | int = None self.timeout: float | int | None = None
"""上次保持(keep)开始时的超时时间""" """上次保持(keep)开始时的超时时间"""
self.history_chains: list[list[Comp.BaseMessageComponent]] = [] self.history_chains: list[list[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception = None): def stop(self, error: Exception | None = None):
"""立即结束这个会话""" """立即结束这个会话"""
if not self.future.done(): if not self.future.done():
if error: if error:
@@ -53,6 +53,8 @@ class SessionController:
self.stop() self.stop()
return return
else: else:
assert self.timeout is not None
assert self.ts is not None
left_timeout = self.timeout - (new_ts - self.ts) left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout timeout = left_timeout + timeout
if timeout <= 0: if timeout <= 0:
@@ -69,7 +71,7 @@ class SessionController:
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: int): async def _holding(self, event: asyncio.Event, timeout: float):
"""等待事件结束或超时""" """等待事件结束或超时"""
try: try:
await asyncio.wait_for(event.wait(), timeout) await asyncio.wait_for(event.wait(), timeout)
@@ -108,7 +110,9 @@ class SessionWaiter:
): ):
self.session_id = session_id self.session_id = session_id
self.session_filter = session_filter self.session_filter = session_filter
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 self.handler: (
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
) = None # 处理函数
self.session_controller = SessionController() self.session_controller = SessionController()
self.record_history_chains = record_history_chains self.record_history_chains = record_history_chains
@@ -119,7 +123,7 @@ class SessionWaiter:
async def register_wait( async def register_wait(
self, self,
handler: Callable[[str], Awaitable[Any]], handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout: int = 30, timeout: int = 30,
) -> Any: ) -> Any:
"""等待外部输入并处理""" """等待外部输入并处理"""
@@ -137,7 +141,7 @@ class SessionWaiter:
finally: finally:
self._cleanup() self._cleanup()
def _cleanup(self, error: Exception = None): def _cleanup(self, error: Exception | None = None):
"""清理会话""" """清理会话"""
USER_SESSIONS.pop(self.session_id, None) USER_SESSIONS.pop(self.session_id, None)
try: try:
@@ -161,6 +165,7 @@ class SessionWaiter:
) )
try: try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
assert session.handler is not None
await session.handler(session.session_controller, event) await session.handler(session.session_controller, event)
except Exception as e: except Exception as e:
session.session_controller.stop(e) session.session_controller.stop(e)
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
:param record_history_chain: 是否自动记录历史消息链可以通过 controller.get_history_chains() 获取深拷贝 :param record_history_chain: 是否自动记录历史消息链可以通过 controller.get_history_chains() 获取深拷贝
""" """
def decorator(func: Callable[[str], Awaitable[Any]]): def decorator(
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
):
@functools.wraps(func) @functools.wraps(func)
async def wrapper( async def wrapper(
event: AstrMessageEvent, event: AstrMessageEvent,
session_filter: SessionFilter = None, session_filter: SessionFilter | None = None,
*args, *args,
**kwargs, **kwargs,
): ):
+32
View File
@@ -53,6 +53,38 @@ class SharedPreferences:
ret = await self.db_helper.get_preferences(scope, scope_id, key) ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret return ret
@overload
async def session_get(
self,
umo: str,
key: str,
default: _VT = None,
) -> _VT: ...
@overload
async def session_get(
self,
umo: None,
key: str,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: str,
key: None,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: None,
key: None,
default: Any = None,
) -> list[Preference]: ...
async def session_get( async def session_get(
self, self,
umo: str | None, umo: str | None,
+2 -2
View File
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
class RenderStrategy(ABC): class RenderStrategy(ABC):
@abstractmethod @abstractmethod
def render(self, text: str, return_url: bool) -> str: async def render(self, text: str, return_url: bool) -> str:
pass pass
@abstractmethod @abstractmethod
def render_custom_template( async def render_custom_template(
self, self,
tmpl_str: str, tmpl_str: str,
tmpl_data: dict, tmpl_data: dict,
+25 -31
View File
@@ -20,7 +20,7 @@ class FontManager:
_font_cache = {} _font_cache = {}
@classmethod @classmethod
def get_font(cls, size: int) -> ImageFont.FreeTypeFont: def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
"""获取指定大小的字体,优先从缓存获取""" """获取指定大小的字体,优先从缓存获取"""
if size in cls._font_cache: if size in cls._font_cache:
return cls._font_cache[size] return cls._font_cache[size]
@@ -66,23 +66,17 @@ class TextMeasurer:
"""测量文本尺寸的工具类""" """测量文本尺寸的工具类"""
@staticmethod @staticmethod
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
"""获取文本的尺寸""" """获取文本的尺寸"""
try:
# PIL 9.0.0 以上版本 # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
return ( left, top, right, bottom = font.getbbox("Hello world")
font.getbbox(text)[2:] return int(right - left), int(bottom - top)
if hasattr(font, "getbbox")
else font.getsize(text)
)
except Exception:
# 兼容旧版本
return font.getsize(text)
@staticmethod @staticmethod
def split_text_to_fit_width( def split_text_to_fit_width(
text: str, font: ImageFont.FreeTypeFont, max_width: int text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
) -> List[str]: ) -> list[str]:
"""将文本拆分为多行,确保每行不超过指定宽度""" """将文本拆分为多行,确保每行不超过指定宽度"""
lines = [] lines = []
if not text: if not text:
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
# 倾斜变换,使用仿射变换实现斜体效果 # 倾斜变换,使用仿射变换实现斜体效果
# 变换矩阵: [1, 0.2, 0, 0, 1, 0] # 变换矩阵: [1, 0.2, 0, 0, 1, 0]
italic_img = text_img.transform( italic_img = text_img.transform(
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
) )
# 粘贴到原图像 # 粘贴到原图像
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
class CodeBlockElement(MarkdownElement): class CodeBlockElement(MarkdownElement):
"""代码块元素""" """代码块元素"""
def __init__(self, content: List[str]): def __init__(self, content: list[str]):
super().__init__("\n".join(content)) super().__init__("\n".join(content))
def calculate_height(self, image_width: int, font_size: int) -> int: def calculate_height(self, image_width: int, font_size: int) -> int:
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
if pasted_image.width > max_width: if pasted_image.width > max_width:
ratio = max_width / pasted_image.width ratio = max_width / pasted_image.width
new_size = (int(max_width), int(pasted_image.height * ratio)) new_size = (int(max_width), int(pasted_image.height * ratio))
pasted_image = pasted_image.resize(new_size, Image.LANCZOS) pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
# 计算居中位置 # 计算居中位置
paste_x = x + (image_width - pasted_image.width) // 2 - 10 paste_x = x + (image_width - pasted_image.width) // 2 - 10
@@ -705,7 +699,7 @@ class MarkdownParser:
"""Markdown解析器,将文本解析为元素""" """Markdown解析器,将文本解析为元素"""
@staticmethod @staticmethod
async def parse(text: str) -> List[MarkdownElement]: async def parse(text: str) -> list[MarkdownElement]:
elements = [] elements = []
lines = text.split("\n") lines = text.split("\n")
@@ -847,7 +841,7 @@ class MarkdownRenderer:
self, self,
font_size: int = 26, font_size: int = 26,
width: int = 800, width: int = 800,
bg_color: Tuple[int, int, int] = (255, 255, 255), bg_color: tuple[int, int, int] = (255, 255, 255),
): ):
self.font_size = font_size self.font_size = font_size
self.width = width self.width = width
+1 -1
View File
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
from pyffmpeg import FFmpeg from pyffmpeg import FFmpeg
ff = FFmpeg() ff = FFmpeg()
ff.convert(input=input_path, output=output_path) ff.convert(input_file=input_path, output_file=output_path)
except Exception as e: except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
+6 -3
View File
@@ -60,9 +60,12 @@ class VersionComparator:
return -1 return -1
if isinstance(p1, str) and isinstance(p2, int): if isinstance(p1, str) and isinstance(p2, int):
return 1 return 1
if (isinstance(p1, int) and isinstance(p2, int)) or ( if isinstance(p1, int) and isinstance(p2, int):
isinstance(p1, str) and isinstance(p2, str) if p1 > p2:
): return 1
if p1 < p2:
return -1
if isinstance(p1, str) and isinstance(p2, str):
if p1 > p2: if p1 > p2:
return 1 return 1
if p1 < p2: if p1 < p2:
+19
View File
@@ -1,4 +1,7 @@
import uuid
from astrbot.core import astrbot_config, logger from astrbot.core import astrbot_config, logger
from astrbot.core.config.default import WEBHOOK_SUPPORTED_PLATFORMS
def _get_callback_api_base() -> str: def _get_callback_api_base() -> str:
@@ -45,3 +48,19 @@ def log_webhook_info(platform_name: str, webhook_uuid: str):
"====================\n" "====================\n"
) )
logger.info(display_log) logger.info(display_log)
def ensure_platform_webhook_config(platform_cfg: dict) -> bool:
"""为支持统一 webhook 的平台自动生成 webhook_uuid
Args:
platform_cfg (dict): 平台配置字典
Returns:
bool: 如果生成了 webhook_uuid 则返回 True否则返回 False
"""
pt = platform_cfg.get("type", "")
if pt in WEBHOOK_SUPPORTED_PLATFORMS and not platform_cfg.get("webhook_uuid"):
platform_cfg["webhook_uuid"] = uuid.uuid4().hex[:16]
return True
return False
+7 -2
View File
@@ -4,7 +4,9 @@ import mimetypes
import os import os
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import cast
from quart import Response as QuartResponse
from quart import g, make_response, request, send_file from quart import g, make_response, request, send_file
from astrbot.core import logger from astrbot.core import logger
@@ -424,7 +426,9 @@ class ChatRoute(Route):
sender_name=username, sender_name=username,
) )
response = await make_response( response = cast(
QuartResponse,
await make_response(
stream(), stream(),
{ {
"Content-Type": "text/event-stream", "Content-Type": "text/event-stream",
@@ -432,8 +436,9 @@ class ChatRoute(Route):
"Transfer-Encoding": "chunked", "Transfer-Encoding": "chunked",
"Connection": "keep-alive", "Connection": "keep-alive",
}, },
),
) )
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue] response.timeout = None # fix SSE auto disconnect issue
return response return response
async def delete_webchat_session(self): async def delete_webchat_session(self):
+10 -20
View File
@@ -2,7 +2,7 @@ import asyncio
import inspect import inspect
import os import os
import traceback import traceback
import uuid from typing import Any
from quart import request from quart import request
@@ -14,7 +14,6 @@ from astrbot.core.config.default import (
CONFIG_METADATA_3_SYSTEM, CONFIG_METADATA_3_SYSTEM,
DEFAULT_CONFIG, DEFAULT_CONFIG,
DEFAULT_VALUE_MAP, DEFAULT_VALUE_MAP,
WEBHOOK_SUPPORTED_PLATFORMS,
) )
from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.config.i18n_utils import ConfigMetadataI18n
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -22,11 +21,12 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.register import provider_registry from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry from astrbot.core.star.star import star_registry
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
def try_cast(value: str, type_: str): def try_cast(value: Any, type_: str):
if type_ == "int": if type_ == "int":
try: try:
return int(value) return int(value)
@@ -505,9 +505,9 @@ class ConfigRoute(Route):
if not isinstance(inst, EmbeddingProvider): if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
# 初始化 init_fn = getattr(inst, "initialize", None)
if getattr(inst, "initialize", None): if inspect.iscoroutinefunction(init_fn):
await inst.initialize() await init_fn()
# 获取嵌入向量维度 # 获取嵌入向量维度
vec = await inst.get_embedding("echo") vec = await inst.get_embedding("echo")
@@ -558,13 +558,8 @@ class ConfigRoute(Route):
async def post_new_platform(self): async def post_new_platform(self):
new_platform_config = await request.json new_platform_config = await request.json
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid
platform_type = new_platform_config.get("type", "") ensure_platform_webhook_config(new_platform_config)
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
if new_platform_config.get("unified_webhook_mode", False):
# 如果没有 webhook_uuid 或为空,自动生成
if not new_platform_config.get("webhook_uuid"):
new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16]
self.config["platform"].append(new_platform_config) self.config["platform"].append(new_platform_config)
try: try:
@@ -596,12 +591,7 @@ class ConfigRoute(Route):
return Response().error("参数错误").__dict__ return Response().error("参数错误").__dict__
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
platform_type = new_config.get("type", "") ensure_platform_webhook_config(new_config)
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
if new_config.get("unified_webhook_mode", False):
# 如果没有 webhook_uuid 或为空,自动生成
if not new_config.get("webhook_uuid"):
new_config["webhook_uuid"] = uuid.uuid4().hex
for i, platform in enumerate(self.config["platform"]): for i, platform in enumerate(self.config["platform"]):
if platform["id"] == platform_id: if platform["id"] == platform_id:
@@ -777,7 +767,7 @@ class ConfigRoute(Route):
return {"metadata": CONFIG_METADATA_2, "config": config} return {"metadata": CONFIG_METADATA_2, "config": config}
async def _get_plugin_config(self, plugin_name: str): async def _get_plugin_config(self, plugin_name: str):
ret = {"metadata": None, "config": None} ret: dict = {"metadata": None, "config": None}
for plugin_md in star_registry: for plugin_md in star_registry:
if plugin_md.name == plugin_name: if plugin_md.name == plugin_name:
+6 -1
View File
@@ -1,6 +1,8 @@
import asyncio import asyncio
import json import json
from typing import cast
from quart import Response as QuartResponse
from quart import make_response from quart import make_response
from astrbot.core import LogBroker, logger from astrbot.core import LogBroker, logger
@@ -39,7 +41,9 @@ class LogRoute(Route):
if queue: if queue:
self.log_broker.unregister(queue) self.log_broker.unregister(queue)
response = await make_response( response = cast(
QuartResponse,
await make_response(
stream(), stream(),
{ {
"Content-Type": "text/event-stream", "Content-Type": "text/event-stream",
@@ -47,6 +51,7 @@ class LogRoute(Route):
"Connection": "keep-alive", "Connection": "keep-alive",
"Transfer-Encoding": "chunked", "Transfer-Encoding": "chunked",
}, },
),
) )
response.timeout = None response.timeout = None
return response return response
+1 -1
View File
@@ -82,7 +82,7 @@ class PlatformRoute(Route):
""" """
for platform in self.platform_manager.platform_insts: for platform in self.platform_manager.platform_insts:
if platform.config.get("webhook_uuid") == webhook_uuid: if platform.config.get("webhook_uuid") == webhook_uuid:
if platform.config.get("unified_webhook_mode", False): if platform.unified_webhook():
return platform return platform
return None return None
+108 -51
View File
@@ -1,14 +1,17 @@
import asyncio import asyncio
import hashlib
import json import json
import os import os
import ssl import ssl
import traceback import traceback
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import aiohttp import aiohttp
import certifi import certifi
from quart import request from quart import request
from astrbot.api import sp
from astrbot.core import DEMO_MODE, file_token_service, logger from astrbot.core import DEMO_MODE, file_token_service, logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command import CommandFilter
@@ -25,6 +28,13 @@ PLUGIN_UPDATE_CONCURRENCY = (
) )
@dataclass
class RegistrySource:
urls: list[str]
cache_file: str
md5_url: str | None # None means "no remote MD5, always treat cache as stale"
class PluginRoute(Route): class PluginRoute(Route):
def __init__( def __init__(
self, self,
@@ -45,6 +55,8 @@ class PluginRoute(Route):
"/plugin/on": ("POST", self.on_plugin), "/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins), "/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme), "/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/source/get": ("GET", self.get_custom_source),
"/plugin/source/save": ("POST", self.save_custom_source),
} }
self.core_lifecycle = core_lifecycle self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager self.plugin_manager = plugin_manager
@@ -84,22 +96,15 @@ class PluginRoute(Route):
custom = request.args.get("custom_registry") custom = request.args.get("custom_registry")
force_refresh = request.args.get("force_refresh", "false").lower() == "true" force_refresh = request.args.get("force_refresh", "false").lower() == "true"
cache_file = "data/plugins.json" # 构建注册表源信息
source = self._build_registry_source(custom)
if custom:
urls = [custom]
else:
urls = [
"https://api.soulter.top/astrbot/plugins",
"https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
]
# 如果不是强制刷新,先检查缓存是否有效 # 如果不是强制刷新,先检查缓存是否有效
cached_data = None cached_data = None
if not force_refresh: if not force_refresh:
# 先检查MD5是否匹配,如果匹配则使用缓存 # 先检查MD5是否匹配,如果匹配则使用缓存
if await self._is_cache_valid(cache_file): if await self._is_cache_valid(source):
cached_data = self._load_plugin_cache(cache_file) cached_data = self._load_plugin_cache(source.cache_file)
if cached_data: if cached_data:
logger.debug("缓存MD5匹配,使用缓存的插件市场数据") logger.debug("缓存MD5匹配,使用缓存的插件市场数据")
return Response().ok(cached_data).__dict__ return Response().ok(cached_data).__dict__
@@ -109,7 +114,7 @@ class PluginRoute(Route):
ssl_context = ssl.create_default_context(cafile=certifi.where()) ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context) connector = aiohttp.TCPConnector(ssl=ssl_context)
for url in urls: for url in source.urls:
try: try:
async with ( async with (
aiohttp.ClientSession( aiohttp.ClientSession(
@@ -128,11 +133,13 @@ class PluginRoute(Route):
logger.warning(f"远程插件市场数据为空: {url}") logger.warning(f"远程插件市场数据为空: {url}")
continue # 继续尝试其他URL或使用缓存 continue # 继续尝试其他URL或使用缓存
logger.info("成功获取远程插件市场数据") logger.info(
f"成功获取远程插件市场数据,包含 {len(remote_data)} 个插件"
)
# 获取最新的MD5并保存到缓存 # 获取最新的MD5并保存到缓存
current_md5 = await self._get_remote_md5() current_md5 = await self._fetch_remote_md5(source.md5_url)
self._save_plugin_cache( self._save_plugin_cache(
cache_file, source.cache_file,
remote_data, remote_data,
current_md5, current_md5,
) )
@@ -143,7 +150,7 @@ class PluginRoute(Route):
# 如果远程获取失败,尝试使用缓存数据 # 如果远程获取失败,尝试使用缓存数据
if not cached_data: if not cached_data:
cached_data = self._load_plugin_cache(cache_file) cached_data = self._load_plugin_cache(source.cache_file)
if cached_data: if cached_data:
logger.warning("远程插件市场数据获取失败,使用缓存数据") logger.warning("远程插件市场数据获取失败,使用缓存数据")
@@ -151,24 +158,75 @@ class PluginRoute(Route):
return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__ return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__
async def _is_cache_valid(self, cache_file: str) -> bool: def _build_registry_source(self, custom_url: str | None) -> RegistrySource:
"""检查缓存是否有效(基于MD5""" """构建注册表源信息"""
try: if custom_url:
if not os.path.exists(cache_file): # 对自定义URL生成一个安全的文件名
return False url_hash = hashlib.md5(custom_url.encode()).hexdigest()[:8]
cache_file = f"data/plugins_custom_{url_hash}.json"
# 加载缓存文件 # 更安全的后缀处理方式
if custom_url.endswith(".json"):
md5_url = custom_url[:-5] + "-md5.json"
else:
md5_url = custom_url + "-md5.json"
urls = [custom_url]
else:
cache_file = "data/plugins.json"
md5_url = "https://api.soulter.top/astrbot/plugins-md5"
urls = [
"https://api.soulter.top/astrbot/plugins",
"https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json",
]
return RegistrySource(urls=urls, cache_file=cache_file, md5_url=md5_url)
def _load_cached_md5(self, cache_file: str) -> str | None:
"""从缓存文件中加载MD5"""
if not os.path.exists(cache_file):
return None
try:
with open(cache_file, encoding="utf-8") as f: with open(cache_file, encoding="utf-8") as f:
cache_data = json.load(f) cache_data = json.load(f)
return cache_data.get("md5")
except Exception as e:
logger.warning(f"加载缓存MD5失败: {e}")
return None
cached_md5 = cache_data.get("md5") async def _fetch_remote_md5(self, md5_url: str | None) -> str | None:
"""获取远程MD5"""
if not md5_url:
return None
try:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with (
aiohttp.ClientSession(
trust_env=True,
connector=connector,
) as session,
session.get(md5_url) as response,
):
if response.status == 200:
data = await response.json()
return data.get("md5", "")
except Exception as e:
logger.debug(f"获取远程MD5失败: {e}")
return None
async def _is_cache_valid(self, source: RegistrySource) -> bool:
"""检查缓存是否有效(基于MD5"""
try:
cached_md5 = self._load_cached_md5(source.cache_file)
if not cached_md5: if not cached_md5:
logger.debug("缓存文件中没有MD5信息") logger.debug("缓存文件中没有MD5信息")
return False return False
# 获取远程MD5 remote_md5 = await self._fetch_remote_md5(source.md5_url)
remote_md5 = await self._get_remote_md5() if remote_md5 is None:
if not remote_md5:
logger.warning("无法获取远程MD5,将使用缓存") logger.warning("无法获取远程MD5,将使用缓存")
return True # 如果无法获取远程MD5,认为缓存有效 return True # 如果无法获取远程MD5,认为缓存有效
@@ -182,30 +240,6 @@ class PluginRoute(Route):
logger.warning(f"检查缓存有效性失败: {e}") logger.warning(f"检查缓存有效性失败: {e}")
return False return False
async def _get_remote_md5(self) -> str:
"""获取远程插件数据的MD5"""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with (
aiohttp.ClientSession(
trust_env=True,
connector=connector,
) as session,
session.get(
"https://api.soulter.top/astrbot/plugins-md5",
) as response,
):
if response.status == 200:
data = await response.json()
return data.get("md5", "")
logger.error(f"获取MD5失败,状态码:{response.status}")
return ""
except Exception as e:
logger.error(f"获取远程MD5失败: {e}")
return ""
def _load_plugin_cache(self, cache_file: str): def _load_plugin_cache(self, cache_file: str):
"""加载本地缓存的插件市场数据""" """加载本地缓存的插件市场数据"""
try: try:
@@ -545,9 +579,13 @@ class PluginRoute(Route):
logger.warning(f"插件 {plugin_name} 不存在") logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__ return Response().error(f"插件 {plugin_name} 不存在").__dict__
if not plugin_obj.root_dir_name:
logger.warning(f"插件 {plugin_name} 目录不存在")
return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
plugin_dir = os.path.join( plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path, self.plugin_manager.plugin_store_path,
plugin_obj.root_dir_name, plugin_obj.root_dir_name or "",
) )
if not os.path.isdir(plugin_dir): if not os.path.isdir(plugin_dir):
@@ -572,3 +610,22 @@ class PluginRoute(Route):
except Exception as e: except Exception as e:
logger.error(f"/api/plugin/readme: {traceback.format_exc()}") logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
return Response().error(f"读取README文件失败: {e!s}").__dict__ return Response().error(f"读取README文件失败: {e!s}").__dict__
async def get_custom_source(self):
"""获取自定义插件源"""
sources = await sp.global_get("custom_plugin_sources", [])
return Response().ok(sources).__dict__
async def save_custom_source(self):
"""保存自定义插件源"""
try:
data = await request.get_json()
sources = data.get("sources", [])
if not isinstance(sources, list):
return Response().error("sources fields must be a list").__dict__
await sp.global_put("custom_plugin_sources", sources)
return Response().ok(None, "保存成功").__dict__
except Exception as e:
logger.error(f"/api/plugin/source/save: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+2
View File
@@ -12,6 +12,8 @@ class RouteContext:
class Route: class Route:
routes: list | dict
def __init__(self, context: RouteContext): def __init__(self, context: RouteContext):
self.app = context.app self.app = context.app
self.config = context.config self.config = context.config
+6 -3
View File
@@ -2,9 +2,12 @@ import asyncio
import logging import logging
import os import os
import socket import socket
from typing import cast
import jwt import jwt
import psutil import psutil
from flask.json.provider import DefaultJSONProvider
from psutil._common import addr as psutil_addr
from quart import Quart, g, jsonify, request from quart import Quart, g, jsonify, request
from quart.logging import default_handler from quart.logging import default_handler
@@ -21,7 +24,7 @@ from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute from .routes.session_management import SessionManagementRoute
from .routes.t2i import T2iRoute from .routes.t2i import T2iRoute
APP: Quart = None APP: Quart
class AstrBotDashboard: class AstrBotDashboard:
@@ -48,7 +51,7 @@ class AstrBotDashboard:
self.app.config["MAX_CONTENT_LENGTH"] = ( self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024 128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
self.app.json.sort_keys = False cast(DefaultJSONProvider, self.app.json).sort_keys = False
self.app.before_request(self.auth_middleware) self.app.before_request(self.auth_middleware)
# token 用于验证请求 # token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler) logging.getLogger(self.app.name).removeHandler(default_handler)
@@ -147,7 +150,7 @@ class AstrBotDashboard:
"""获取占用端口的进程详细信息""" """获取占用端口的进程详细信息"""
try: try:
for conn in psutil.net_connections(kind="inet"): for conn in psutil.net_connections(kind="inet"):
if conn.laddr.port == port: if cast(psutil_addr, conn.laddr).port == port:
try: try:
process = psutil.Process(conn.pid) process = psutil.Process(conn.pid)
# 获取详细信息 # 获取详细信息

Some files were not shown because too many files have changed in this diff Show More