Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e2f928a7e5 | |||
| b8e4068c75 | |||
| 0916177a57 | |||
| 02cd5e396b | |||
| 56673ad78f | |||
| 9a4d05e2b6 | |||
| c3f45449e8 | |||
| 65da469deb | |||
| 16df64c405 | |||
| 6b73b19e54 | |||
| e7e97730af | |||
| 467ca1eb5c | |||
| 46528391c2 | |||
| 8a0b7717cc | |||
| 3b81fb4985 | |||
| c09d57a820 | |||
| ec408a2aff | |||
| 417179a6b9 | |||
| fcd29445c7 | |||
| 5f535001db | |||
| 750d245b16 | |||
| f624971613 | |||
| aa6d07afcc | |||
| 2c36649874 | |||
| c95735dcc0 | |||
| 03bb278f50 | |||
| a5e0974da3 | |||
| f0fb447fbc | |||
| 37566182b0 | |||
| e460b411da |
@@ -0,0 +1,58 @@
|
||||
name: Smoke Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README*.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
smoke-test:
|
||||
name: Run smoke tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install UV package manager
|
||||
run: |
|
||||
pip install uv
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync
|
||||
timeout-minutes: 15
|
||||
|
||||
- name: Run smoke tests
|
||||
run: |
|
||||
uv run main.py &
|
||||
APP_PID=$!
|
||||
|
||||
echo "Waiting for application to start..."
|
||||
for i in {1..60}; do
|
||||
if curl -f http://localhost:6185 > /dev/null 2>&1; then
|
||||
echo "Application started successfully!"
|
||||
kill $APP_PID
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Application failed to start within 30 seconds"
|
||||
kill $APP_PID 2>/dev/null || true
|
||||
exit 1
|
||||
timeout-minutes: 2
|
||||
@@ -20,6 +20,7 @@
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
|
||||
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
|
||||
@@ -206,6 +207,7 @@ pre-commit install
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
@@ -241,4 +243,10 @@ pre-commit install
|
||||
|
||||
</details>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.8.0"
|
||||
__version__ = "4.9.2"
|
||||
|
||||
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
@@ -7,6 +7,8 @@ from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
@@ -38,7 +40,10 @@ class ToolSchema:
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
handler: (
|
||||
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||
| None
|
||||
) = None
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
|
||||
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[
|
||||
...,
|
||||
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||
],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
"""
|
||||
|
||||
config_path: str
|
||||
default_config: dict
|
||||
schema: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.8.0"
|
||||
VERSION = "4.9.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -13,6 +13,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
"wecom",
|
||||
"wecom_ai_bot",
|
||||
"slack",
|
||||
"lark",
|
||||
]
|
||||
|
||||
# 默认配置
|
||||
@@ -42,7 +43,15 @@ DEFAULT_CONFIG = {
|
||||
"interval": "1.5,3.5",
|
||||
"log_base": 2.6,
|
||||
"words_count_threshold": 150,
|
||||
"split_mode": "regex", # regex 或 words
|
||||
"regex": ".*?[。?!~…]+|.+$",
|
||||
"split_words": [
|
||||
"。",
|
||||
"?",
|
||||
"!",
|
||||
"~",
|
||||
"…",
|
||||
], # 当 split_mode 为 words 时使用
|
||||
"content_cleanup_rule": "",
|
||||
},
|
||||
"no_permission_reply": True,
|
||||
@@ -99,6 +108,7 @@ DEFAULT_CONFIG = {
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
"trigger_probability": 1.0,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -157,6 +167,7 @@ DEFAULT_CONFIG = {
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
"kb_agentic_mode": False,
|
||||
"disable_builtin_commands": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -268,6 +279,10 @@ CONFIG_METADATA_2 = {
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"domain": "https://open.feishu.cn",
|
||||
"lark_connection_mode": "socket", # webhook, socket
|
||||
"webhook_uuid": "",
|
||||
"lark_encrypt_key": "",
|
||||
"lark_verification_token": "",
|
||||
},
|
||||
"钉钉(DingTalk)": {
|
||||
"id": "dingtalk",
|
||||
@@ -361,6 +376,28 @@ CONFIG_METADATA_2 = {
|
||||
# "type": "string",
|
||||
# "options": ["fullscreen", "embedded"],
|
||||
# },
|
||||
"lark_connection_mode": {
|
||||
"description": "订阅方式",
|
||||
"type": "string",
|
||||
"options": ["socket", "webhook"],
|
||||
"labels": ["长连接模式", "推送至服务器模式"],
|
||||
},
|
||||
"lark_encrypt_key": {
|
||||
"description": "Encrypt Key",
|
||||
"type": "string",
|
||||
"hint": "用于解密飞书回调数据的加密密钥",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"lark_verification_token": {
|
||||
"description": "Verification Token",
|
||||
"type": "string",
|
||||
"hint": "用于验证飞书回调请求的令牌",
|
||||
"condition": {
|
||||
"lark_connection_mode": "webhook",
|
||||
},
|
||||
},
|
||||
"is_sandbox": {
|
||||
"description": "沙箱模式",
|
||||
"type": "bool",
|
||||
@@ -2173,6 +2210,9 @@ CONFIG_METADATA_2 = {
|
||||
"use_file_service": {
|
||||
"type": "bool",
|
||||
},
|
||||
"trigger_probability": {
|
||||
"type": "float",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -2383,6 +2423,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.trigger_probability": {
|
||||
"description": "TTS 触发概率",
|
||||
"type": "float",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
"type": "text",
|
||||
@@ -2661,6 +2709,11 @@ CONFIG_METADATA_3 = {
|
||||
"description": "只 @ 机器人是否触发等待",
|
||||
"type": "bool",
|
||||
},
|
||||
"disable_builtin_commands": {
|
||||
"description": "禁用自带指令",
|
||||
"type": "bool",
|
||||
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"whitelist": {
|
||||
@@ -2875,9 +2928,26 @@ CONFIG_METADATA_3 = {
|
||||
"description": "分段回复字数阈值",
|
||||
"type": "int",
|
||||
},
|
||||
"platform_settings.segmented_reply.split_mode": {
|
||||
"description": "分段模式",
|
||||
"type": "string",
|
||||
"options": ["regex", "words"],
|
||||
"labels": ["正则表达式", "分段词列表"],
|
||||
},
|
||||
"platform_settings.segmented_reply.regex": {
|
||||
"description": "分段正则表达式",
|
||||
"type": "string",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "regex",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.split_words": {
|
||||
"description": "分段词列表",
|
||||
"type": "list",
|
||||
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
|
||||
"condition": {
|
||||
"platform_settings.segmented_reply.split_mode": "words",
|
||||
},
|
||||
},
|
||||
"platform_settings.segmented_reply.content_cleanup_rule": {
|
||||
"description": "内容过滤正则表达式",
|
||||
@@ -2928,6 +2998,7 @@ CONFIG_METADATA_3 = {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"hint": "0.0-1.0 之间的数值",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_ltm_settings.active_reply.enable": True,
|
||||
},
|
||||
|
||||
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||
|
||||
tasks_ = [event_bus_task, *extra_tasks]
|
||||
for task in tasks_:
|
||||
|
||||
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
|
||||
@@ -70,6 +70,7 @@ async def migration_conversation_table(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
@@ -207,6 +208,7 @@ async def migration_webchat_data(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple = None):
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -224,9 +224,11 @@ class SQLiteDatabase:
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
return Stats(platform)
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
def get_conversation_by_user_id(
|
||||
self, user_id: str, cid: str
|
||||
) -> Conversation | None:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -258,7 +260,7 @@ class SQLiteDatabase:
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> tuple:
|
||||
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
|
||||
+16
-15
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
__tablename__: str = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
__tablename__: str = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
inner_conversation_id: int | None = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
)
|
||||
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas" # type: ignore
|
||||
__tablename__: str = "personas"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
__tablename__: str = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
__tablename__: str = "platform_message_history"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
__tablename__: str = "platform_sessions"
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
__tablename__: str = "attachments"
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -261,17 +262,17 @@ class Personality(TypedDict):
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
prompt: str
|
||||
name: str
|
||||
begin_dialogs: list[str]
|
||||
mood_imitation_dialogs: list[str]
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
tools: list[str] | None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
_begin_dialogs_processed: list[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
|
||||
|
||||
# ====
|
||||
|
||||
@@ -3,6 +3,7 @@ import threading
|
||||
import typing as T
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
Attachment.attachment_id.in_(attachment_ids)
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
|
||||
@@ -90,4 +90,6 @@ class EmbeddingStorage:
|
||||
path (str): 保存索引的路径
|
||||
|
||||
"""
|
||||
if self.index is None:
|
||||
return
|
||||
faiss.write_index(self.index, self.path)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EventBus:
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
@@ -40,6 +40,11 @@ class EventBus:
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
if not scheduler:
|
||||
logger.error(
|
||||
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
|
||||
@@ -166,7 +166,11 @@ class RetrievalManager:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
vec_db = kb_options[kb_id]["vec_db"]
|
||||
if not isinstance(vec_db, FaissVecDB):
|
||||
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||
continue
|
||||
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
|
||||
+2
-1
@@ -24,6 +24,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
|
||||
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
|
||||
self.log_broker.publish(
|
||||
{
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"time": time.time(),
|
||||
"data": log_entry,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
|
||||
id: int | None = 0 # 忽略
|
||||
name: str | None = "" # qq昵称
|
||||
uin: str | None = "0" # qq号
|
||||
content: list[BaseMessageComponent] | None = []
|
||||
content: list[BaseMessageComponent] = []
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
async def to_dict(self) -> dict:
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
if self.name:
|
||||
|
||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
check_text: str | None = None,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
|
||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -91,6 +91,7 @@ async def call_event_hook(
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
assert inspect.iscoroutinefunction(handler.handler)
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@ from ..stage import Stage
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
|
||||
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
|
||||
self.ctx = ctx
|
||||
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
|
||||
):
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
event.get_result() and not event.is_stopped()
|
||||
) or not event.get_result():
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -117,7 +117,9 @@ class RespondStage(Stage):
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -156,7 +158,11 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if event.get_extra("_streaming_finished", False):
|
||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@@ -185,7 +191,7 @@ class RespondStage(Stage):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
result.chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
|
||||
"forward_threshold"
|
||||
]
|
||||
|
||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||
"trigger_probability",
|
||||
1,
|
||||
)
|
||||
try:
|
||||
self.tts_trigger_probability = max(
|
||||
0.0,
|
||||
min(float(trigger_probability), 1.0),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
self.tts_trigger_probability = 1.0
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(
|
||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
|
||||
self.only_llm_result = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["only_llm_result"]
|
||||
self.split_mode = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_mode", "regex")
|
||||
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
|
||||
self.split_words = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
].get("split_words", ["。", "?", "!", "~", "…"])
|
||||
if self.split_words:
|
||||
escaped_words = sorted(
|
||||
[re.escape(word) for word in self.split_words], key=len, reverse=True
|
||||
)
|
||||
self.split_words_pattern = re.compile(
|
||||
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
|
||||
)
|
||||
else:
|
||||
self.split_words_pattern = None
|
||||
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
|
||||
"segmented_reply"
|
||||
]["content_cleanup_rule"]
|
||||
@@ -69,6 +98,28 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
return [text]
|
||||
|
||||
segments = self.split_words_pattern.findall(text)
|
||||
result = []
|
||||
for seg in segments:
|
||||
if isinstance(seg, tuple):
|
||||
content = seg[0]
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
for word in self.split_words:
|
||||
if content.endswith(word):
|
||||
content = content[: -len(word)]
|
||||
break
|
||||
if content.strip():
|
||||
result.append(content)
|
||||
elif seg and seg.strip():
|
||||
result.append(seg)
|
||||
return result if result else [text]
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -93,11 +144,13 @@ class ResultDecorateStage(Stage):
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
):
|
||||
yield
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
@@ -114,7 +167,8 @@ class ResultDecorateStage(Stage):
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
|
||||
if (result := event.get_result()) is None or not result.chain:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||
)
|
||||
@@ -161,21 +215,27 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
# 根据 split_mode 选择分段方式
|
||||
if self.split_mode == "words":
|
||||
split_response = self._split_text_by_words(comp.text)
|
||||
else: # regex 模式
|
||||
try:
|
||||
split_response = re.findall(
|
||||
self.regex,
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
except re.error:
|
||||
logger.error(
|
||||
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
||||
)
|
||||
split_response = re.findall(
|
||||
r".*?[。?!~…]+|.+$",
|
||||
comp.text,
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -199,7 +259,14 @@ class ResultDecorateStage(Stage):
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
if not tts_provider:
|
||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||
self.tts_trigger_probability > 0.0
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
)
|
||||
|
||||
if not should_tts:
|
||||
logger.debug("跳过 TTS:触发概率未命中。")
|
||||
elif not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
WecomAIBotMessageEvent,
|
||||
)
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .context import PipelineContext
|
||||
@@ -78,7 +82,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
|
||||
"ignore_at_all",
|
||||
False,
|
||||
)
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
|
||||
EventType.AdapterMessageEvent,
|
||||
plugins_name=event.plugins_name,
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
|
||||
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
return self.message_obj.sender.nickname
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
|
||||
def set_extra(self, key, value):
|
||||
"""设置额外的信息。"""
|
||||
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self.call_llm = call_llm
|
||||
|
||||
def get_result(self) -> MessageEventResult:
|
||||
def get_result(self) -> MessageEventResult | None:
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
session_id: str = "",
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str = "",
|
||||
|
||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group: Group # 群组
|
||||
group: Group | None # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
def group_id(self, value: str | None):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -5,6 +5,7 @@ from asyncio import Queue
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .platform import Platform, PlatformStatus
|
||||
from .register import platform_cls_map
|
||||
@@ -18,6 +19,7 @@ class PlatformManager:
|
||||
|
||||
self._inst_map: dict[str, dict] = {}
|
||||
|
||||
self.astrbot_config = config
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
@@ -29,6 +31,8 @@ class PlatformManager:
|
||||
"""初始化所有平台适配器"""
|
||||
for platform in self.platforms_config:
|
||||
try:
|
||||
if ensure_platform_webhook_config(platform):
|
||||
self.astrbot_config.save_config()
|
||||
await self.load_platform(platform)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -80,6 +80,13 @@ class Platform(abc.ABC):
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
self._status = PlatformStatus.RUNNING
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
"""是否正在使用统一 Webhook 模式"""
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取平台统计信息"""
|
||||
meta = self.meta()
|
||||
@@ -97,10 +104,11 @@ class Platform(abc.ABC):
|
||||
}
|
||||
if self.last_error
|
||||
else None,
|
||||
"unified_webhook": self.unified_webhook(),
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -116,7 +124,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
|
||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str | None = None
|
||||
id: str
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
|
||||
@@ -40,6 +40,7 @@ def register_platform_adapter(
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
id=adapter_name,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
|
||||
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
session_id: str | None,
|
||||
messages: list[dict],
|
||||
):
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id = int(session_id) if session_id.isdigit() else None
|
||||
session_id_int = (
|
||||
int(session_id) if session_id and session_id.isdigit() else None
|
||||
)
|
||||
|
||||
if is_group and isinstance(session_id, int):
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
elif not is_group and isinstance(session_id, int):
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
if is_group and isinstance(session_id_int, int):
|
||||
await bot.send_group_msg(group_id=session_id_int, message=messages)
|
||||
elif not is_group and isinstance(session_id_int, int):
|
||||
await bot.send_private_msg(user_id=session_id_int, message=messages)
|
||||
elif isinstance(event, Event): # 最后兜底
|
||||
await bot.send(event=event, message=messages)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
|
||||
@param event: 事件对象
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
assert event.sender is not None
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group = Group(str(event.group_id))
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return None
|
||||
raise ValueError(err)
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
@@ -417,7 +421,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
await self.shutdown_event.wait()
|
||||
logger.info("aiocqhttp 适配器已被优雅地关闭")
|
||||
logger.info("aiocqhttp 适配器已被关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
outer_self = self
|
||||
|
||||
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||
async def process(self, message: dingtalk_stream.CallbackMessage):
|
||||
logger.debug(f"dingtalk: {message.data}")
|
||||
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||
abm = await self.convert_msg(im)
|
||||
await self.handle_msg(abm)
|
||||
abm = await outer_self.convert_msg(im)
|
||||
await outer_self.handle_msg(abm)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
self._shutdown_event: threading.Event | None = None
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||
if not dingtalk_id:
|
||||
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = []
|
||||
abm.message_str = ""
|
||||
abm.timestamp = int(message.create_at / 1000)
|
||||
abm.timestamp = int(cast(int, message.create_at) / 1000)
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message.conversation_type == "2"
|
||||
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.message_id = message.message_id
|
||||
abm.message_id = cast(str, message.message_id)
|
||||
abm.raw_message = message
|
||||
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
message_type: str = message.message_type
|
||||
message_type: str = cast(str, message.message_type)
|
||||
match message_type:
|
||||
case "text":
|
||||
abm.message_str = message.text.content.strip()
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
case "richText":
|
||||
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||
contents: list[dict] = rtc.rich_text_list
|
||||
rtc: dingtalk_stream.RichTextContent = cast(
|
||||
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||
)
|
||||
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||
for content in contents:
|
||||
plains = ""
|
||||
if "text" in content:
|
||||
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
elif "type" in content and content["type"] == "picture":
|
||||
f_path = await self.download_ding_file(
|
||||
content["downloadCode"],
|
||||
message.robot_code,
|
||||
cast(str, message.robot_code),
|
||||
"jpg",
|
||||
)
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
resp_data = await resp.json()
|
||||
download_url = resp_data["data"]["downloadUrl"]
|
||||
await download_file(download_url, f_path)
|
||||
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
return (await resp.json())["data"]["accessToken"]
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
@@ -239,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
task.result()
|
||||
except Exception as e:
|
||||
if "Graceful shutdown" in str(e):
|
||||
logger.info("钉钉适配器已被优雅地关闭")
|
||||
logger.info("钉钉适配器已被关闭")
|
||||
return
|
||||
logger.error(f"钉钉机器人启动失败: {e}")
|
||||
|
||||
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
|
||||
def monkey_patch_close():
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
self._shutdown_event.set()
|
||||
if self.client_.websocket is not None:
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import dingtalk_stream
|
||||
|
||||
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
cast(
|
||||
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
|
||||
),
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import discord
|
||||
|
||||
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
if self.user is None:
|
||||
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||
return
|
||||
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
if interaction.user is None:
|
||||
raise ValueError("Interaction received without a valid user")
|
||||
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
|
||||
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: list[BaseMessageComponent] = None,
|
||||
timeout: float = None,
|
||||
components: list[BaseMessageComponent] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import discord
|
||||
from discord.abc import Messageable
|
||||
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||
from discord.channel import DMChannel
|
||||
|
||||
from astrbot import logger
|
||||
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.client_self_id: str | None = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
user_id=str(self.client_self_id),
|
||||
nickname=self.client.user.display_name,
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.self_id = cast(str, self.client_self_id)
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain.chain
|
||||
|
||||
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
|
||||
def _get_message_type(
|
||||
self,
|
||||
channel: Messageable,
|
||||
channel: Messageable | GuildChannel | PrivateChannel,
|
||||
guild_id: int | None = None,
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
def _get_channel_id(
|
||||
self, channel: Messageable | GuildChannel | PrivateChannel
|
||||
) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
message = data["message"]
|
||||
|
||||
content = message.content
|
||||
|
||||
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 1. 优先处理斜杠指令
|
||||
if is_slash_command:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
self.commit_event(message_event)
|
||||
return
|
||||
|
||||
# 2. 处理普通消息(提及检测)
|
||||
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
|
||||
raw_message = message.raw_message
|
||||
if not isinstance(raw_message, discord.Message):
|
||||
logger.warning(
|
||||
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
|
||||
# User Mention
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
is_mention = True
|
||||
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
|
||||
if self.client.user in raw_message.mentions:
|
||||
is_mention = True
|
||||
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
if not is_mention and raw_message.role_mentions:
|
||||
bot_member = None
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
if raw_message.guild:
|
||||
try:
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
bot_member = raw_message.guild.get_member(
|
||||
self.client.user.id,
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
mentioned_roles = set(raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
# 如果是被@的消息,设置为唤醒状态
|
||||
if is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
def _extract_command_info(
|
||||
event_filter: Any,
|
||||
handler_metadata: StarHandlerMetadata,
|
||||
) -> tuple[str, str, CommandFilter] | None:
|
||||
) -> tuple[str, str, CommandFilter | None] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
|
||||
@@ -4,8 +4,10 @@ import binascii
|
||||
from collections.abc import AsyncGenerator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import discord
|
||||
from discord.types.interactions import ComponentInteractionData
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
if not isinstance(channel, discord.abc.Messageable):
|
||||
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
|
||||
return
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
async def _get_channel(
|
||||
self,
|
||||
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
|
||||
) -> tuple[
|
||||
str,
|
||||
list[discord.File],
|
||||
discord.ui.View | None,
|
||||
list[discord.Embed],
|
||||
str | int | None,
|
||||
]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content_parts = []
|
||||
files = []
|
||||
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"add_reaction",
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
|
||||
emoji
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
return cast(
|
||||
ComponentInteractionData,
|
||||
cast(discord.Interaction, self.message_obj.raw_message).data,
|
||||
).get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
for mention in cast(
|
||||
discord.Message, self.message_obj.raw_message
|
||||
).mentions
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"clean_content",
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return cast(discord.Message, self.message_obj.raw_message).clean_content
|
||||
return self.message_str
|
||||
|
||||
@@ -2,10 +2,17 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot import logger
|
||||
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .lark_event import LarkMessageEvent
|
||||
from .server import LarkWebhookServer
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -42,9 +51,13 @@ class LarkPlatformAdapter(Platform):
|
||||
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
|
||||
|
||||
# socket or webhook
|
||||
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
|
||||
|
||||
if not self.bot_name:
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
# 初始化 WebSocket 长连接相关配置
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
await self.convert_msg(event)
|
||||
|
||||
@@ -57,6 +70,8 @@ class LarkPlatformAdapter(Platform):
|
||||
.build()
|
||||
)
|
||||
|
||||
self.do_v2_msg_event = do_v2_msg_event
|
||||
|
||||
self.client = lark.ws.Client(
|
||||
app_id=self.appid,
|
||||
app_secret=self.appsecret,
|
||||
@@ -66,14 +81,56 @@ class LarkPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.log_level(lark.LogLevel.ERROR)
|
||||
.domain(self.domain)
|
||||
.build()
|
||||
)
|
||||
|
||||
self.webhook_server = None
|
||||
if self.connection_mode == "webhook":
|
||||
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
|
||||
self.webhook_server.set_callback(self.handle_webhook_event)
|
||||
|
||||
self.event_id_timestamps: dict[str, float] = {}
|
||||
|
||||
def _clean_expired_events(self):
|
||||
"""清理超过 30 分钟的事件记录"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
event_id
|
||||
for event_id, timestamp in self.event_id_timestamps.items()
|
||||
if current_time - timestamp > 1800
|
||||
]
|
||||
for event_id in expired_keys:
|
||||
del self.event_id_timestamps[event_id]
|
||||
|
||||
def _is_duplicate_event(self, event_id: str) -> bool:
|
||||
"""检查事件是否重复
|
||||
|
||||
Args:
|
||||
event_id: 事件ID
|
||||
|
||||
Returns:
|
||||
True 表示重复事件,False 表示新事件
|
||||
"""
|
||||
self._clean_expired_events()
|
||||
if event_id in self.event_id_timestamps:
|
||||
return True
|
||||
self.event_id_timestamps[event_id] = time.time()
|
||||
return False
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
@@ -114,14 +171,25 @@ class LarkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
if event.event is None:
|
||||
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||
return
|
||||
message = event.event.message
|
||||
if message is None:
|
||||
logger.debug("[Lark] 事件中没有消息体(message is None)")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
|
||||
if message.create_time:
|
||||
abm.timestamp = int(message.create_time) // 1000
|
||||
else:
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message = []
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
@@ -136,14 +204,28 @@ class LarkPlatformAdapter(Platform):
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||
if m.name == self.bot_name:
|
||||
abm.self_id = m.id.open_id
|
||||
if m.id is None:
|
||||
continue
|
||||
# 飞书 open_id 可能是 None,这里做个防护
|
||||
open_id = m.id.open_id if m.id.open_id else ""
|
||||
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
|
||||
|
||||
content_json_b = json.loads(message.content)
|
||||
if m.name == self.bot_name:
|
||||
if m.id.open_id is not None:
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
if message.content is None:
|
||||
logger.warning("[Lark] 消息内容为空")
|
||||
return
|
||||
|
||||
try:
|
||||
content_json_b = json.loads(message.content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
|
||||
return
|
||||
|
||||
if message.message_type == "text":
|
||||
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
||||
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
# at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
@@ -168,27 +250,47 @@ class LarkPlatformAdapter(Platform):
|
||||
content_json_b = _ls
|
||||
elif message.message_type == "image":
|
||||
content_json_b = [
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
|
||||
{
|
||||
"tag": "img",
|
||||
"image_key": content_json_b.get("image_key"),
|
||||
"style": [],
|
||||
},
|
||||
]
|
||||
|
||||
if message.message_type in ("post", "image"):
|
||||
for comp in content_json_b:
|
||||
if comp["tag"] == "at":
|
||||
abm.message.append(at_list[comp["user_id"]])
|
||||
elif comp["tag"] == "text" and comp["text"].strip():
|
||||
if comp.get("tag") == "at":
|
||||
user_id = comp.get("user_id")
|
||||
if user_id in at_list:
|
||||
abm.message.append(at_list[user_id])
|
||||
elif comp.get("tag") == "text" and comp.get("text", "").strip():
|
||||
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||
elif comp["tag"] == "img":
|
||||
image_key = comp["image_key"]
|
||||
elif comp.get("tag") == "img":
|
||||
image_key = comp.get("image_key")
|
||||
if not image_key:
|
||||
continue
|
||||
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message.message_id)
|
||||
.message_id(cast(str, message.message_id))
|
||||
.file_key(image_key)
|
||||
.type("image")
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
continue
|
||||
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
continue
|
||||
|
||||
if response.file is None:
|
||||
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
|
||||
continue
|
||||
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||
@@ -196,6 +298,19 @@ class LarkPlatformAdapter(Platform):
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Comp.Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
if message.message_id is None:
|
||||
logger.error("[Lark] 消息缺少 message_id")
|
||||
return
|
||||
|
||||
if (
|
||||
event.event.sender is None
|
||||
or event.event.sender.sender_id is None
|
||||
or event.event.sender.sender_id.open_id is None
|
||||
):
|
||||
logger.error("[Lark] 消息发送者信息不完整")
|
||||
return
|
||||
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
@@ -227,13 +342,61 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def handle_webhook_event(self, event_data: dict):
|
||||
"""处理 Webhook 事件
|
||||
|
||||
Args:
|
||||
event_data: Webhook 事件数据
|
||||
"""
|
||||
try:
|
||||
header = event_data.get("header", {})
|
||||
event_id = header.get("event_id", "")
|
||||
if event_id and self._is_duplicate_event(event_id):
|
||||
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
|
||||
return
|
||||
event_type = header.get("event_type", "")
|
||||
if event_type == "im.message.receive_v1":
|
||||
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
|
||||
data = (processor.type())(event_data)
|
||||
processor.do(data)
|
||||
else:
|
||||
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
if self.connection_mode == "webhook":
|
||||
# Webhook 模式
|
||||
if self.webhook_server is None:
|
||||
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
|
||||
return
|
||||
|
||||
webhook_uuid = self.config.get("webhook_uuid")
|
||||
if webhook_uuid:
|
||||
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
|
||||
else:
|
||||
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
|
||||
else:
|
||||
# 长连接模式
|
||||
await self.client._connect()
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
"""统一 Webhook 回调入口"""
|
||||
if not self.webhook_server:
|
||||
return {"error": "Webhook server not initialized"}, 500
|
||||
|
||||
return await self.webhook_server.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||
if self.connection_mode == "socket":
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已关闭")
|
||||
|
||||
def get_client(self) -> lark.Client:
|
||||
def get_client(self) -> lark.ws.Client:
|
||||
return self.client
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
return bool(
|
||||
self.config.get("lark_connection_mode", "") == "webhook"
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
@@ -5,7 +5,15 @@ import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
ReplyMessageRequest,
|
||||
ReplyMessageRequestBody,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
file_path = image_file_path if image_file_path else ""
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file
|
||||
file_path = comp.file if comp.file else ""
|
||||
|
||||
if image_file is None:
|
||||
image_file = open(file_path, "rb")
|
||||
if not file_path:
|
||||
logger.error("[Lark] 图片路径为空,无法上传")
|
||||
continue
|
||||
try:
|
||||
image_file = open(file_path, "rb")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开图片文件: {e}")
|
||||
continue
|
||||
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
|
||||
continue
|
||||
|
||||
response = await lark_client.im.v1.image.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
continue
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
|
||||
continue
|
||||
|
||||
image_key = response.data.image_key
|
||||
logger.debug(image_key)
|
||||
ret.append(_stage)
|
||||
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
return
|
||||
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""飞书(Lark) Webhook 服务器实现
|
||||
|
||||
实现飞书事件订阅的 Webhook 模式,支持:
|
||||
1. 请求 URL 验证 (challenge 验证)
|
||||
2. 事件加密/解密 (AES-256-CBC)
|
||||
3. 签名校验 (SHA256)
|
||||
4. 事件接收和处理
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class AESCipher:
|
||||
"""AES 加密/解密工具类"""
|
||||
|
||||
def __init__(self, key: str):
|
||||
self.bs = AES.block_size
|
||||
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
|
||||
|
||||
@staticmethod
|
||||
def str_to_bytes(data):
|
||||
u_type = type(b"".decode("utf8"))
|
||||
if isinstance(data, u_type):
|
||||
return data.encode("utf8")
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unpad(s):
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
def decrypt(self, enc):
|
||||
iv = enc[: AES.block_size]
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
|
||||
|
||||
def decrypt_string(self, enc):
|
||||
enc = base64.b64decode(enc)
|
||||
return self.decrypt(enc).decode("utf8")
|
||||
|
||||
|
||||
class LarkWebhookServer:
|
||||
"""飞书 Webhook 服务器
|
||||
|
||||
仅支持统一 Webhook 模式
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict, event_queue: asyncio.Queue):
|
||||
"""初始化 Webhook 服务器
|
||||
|
||||
Args:
|
||||
config: 飞书配置
|
||||
event_queue: 事件队列
|
||||
"""
|
||||
self.app_id = config["app_id"]
|
||||
self.app_secret = config["app_secret"]
|
||||
self.encrypt_key = config.get("lark_encrypt_key", "")
|
||||
self.verification_token = config.get("lark_verification_token", "")
|
||||
|
||||
self.event_queue = event_queue
|
||||
self.callback: Callable[[dict], Awaitable[None]] | None = None
|
||||
|
||||
# 初始化加密工具
|
||||
self.cipher = None
|
||||
if self.encrypt_key:
|
||||
self.cipher = AESCipher(self.encrypt_key)
|
||||
|
||||
def verify_signature(
|
||||
self,
|
||||
timestamp: str,
|
||||
nonce: str,
|
||||
encrypt_key: str,
|
||||
body: bytes,
|
||||
signature: str,
|
||||
) -> bool:
|
||||
"""验证签名
|
||||
|
||||
Args:
|
||||
timestamp: 请求时间戳
|
||||
nonce: 随机数
|
||||
encrypt_key: 加密密钥
|
||||
body: 请求体
|
||||
signature: 签名
|
||||
|
||||
Returns:
|
||||
签名是否有效
|
||||
"""
|
||||
# 拼接字符串: timestamp + nonce + encrypt_key + body
|
||||
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
|
||||
bytes_b = bytes_b1 + body
|
||||
h = hashlib.sha256(bytes_b)
|
||||
calculated_signature = h.hexdigest()
|
||||
return calculated_signature == signature
|
||||
|
||||
def decrypt_event(self, encrypted_data: str) -> dict:
|
||||
"""解密事件数据
|
||||
|
||||
Args:
|
||||
encrypted_data: 加密的事件数据
|
||||
|
||||
Returns:
|
||||
解密后的事件字典
|
||||
"""
|
||||
if not self.cipher:
|
||||
raise ValueError("未配置 encrypt_key,无法解密事件")
|
||||
|
||||
decrypted_str = self.cipher.decrypt_string(encrypted_data)
|
||||
return json.loads(decrypted_str)
|
||||
|
||||
async def handle_challenge(self, event_data: dict) -> dict:
|
||||
"""处理 challenge 验证请求
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
|
||||
Returns:
|
||||
包含 challenge 的响应
|
||||
"""
|
||||
challenge = event_data.get("challenge", "")
|
||||
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
|
||||
|
||||
return {"challenge": challenge}
|
||||
|
||||
async def handle_callback(self, request) -> tuple[dict, int] | dict:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
# 获取原始请求体
|
||||
body = await request.get_data()
|
||||
|
||||
try:
|
||||
event_data = await request.json
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
|
||||
if not event_data:
|
||||
logger.error("[Lark Webhook] 请求体为空")
|
||||
return {"error": "Empty request body"}, 400
|
||||
|
||||
# 如果配置了 encrypt_key,进行签名验证
|
||||
if self.encrypt_key:
|
||||
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
|
||||
nonce = request.headers.get("X-Lark-Request-Nonce", "")
|
||||
signature = request.headers.get("X-Lark-Signature", "")
|
||||
|
||||
if timestamp and nonce and signature:
|
||||
if not self.verify_signature(
|
||||
timestamp, nonce, self.encrypt_key, body, signature
|
||||
):
|
||||
logger.error("[Lark Webhook] 签名验证失败")
|
||||
return {"error": "Invalid signature"}, 401
|
||||
|
||||
# 检查是否是加密事件
|
||||
if "encrypt" in event_data:
|
||||
try:
|
||||
event_data = self.decrypt_event(event_data["encrypt"])
|
||||
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
|
||||
return {"error": "Decryption failed"}, 400
|
||||
|
||||
# 验证 token
|
||||
if self.verification_token:
|
||||
header = event_data.get("header", {})
|
||||
if header:
|
||||
token = header.get("token", "")
|
||||
else:
|
||||
token = event_data.get("token", "")
|
||||
if token != self.verification_token:
|
||||
logger.error("[Lark Webhook] Verification Token 不匹配。")
|
||||
return {"error": "Invalid verification token"}, 401
|
||||
|
||||
# 处理 URL 验证 (challenge)
|
||||
if event_data.get("type") == "url_verification":
|
||||
return await self.handle_challenge(event_data)
|
||||
|
||||
# 调用回调函数处理事件
|
||||
if self.callback:
|
||||
try:
|
||||
await self.callback(event_data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
|
||||
return {"error": "Event processing failed"}, 500
|
||||
|
||||
return {}
|
||||
|
||||
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
|
||||
"""设置事件回调函数
|
||||
|
||||
Args:
|
||||
callback: 处理事件的异步函数
|
||||
"""
|
||||
self.callback = callback
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
if not isinstance(message.raw_message, dict):
|
||||
message.raw_message = {}
|
||||
message.raw_message["poll"] = poll
|
||||
message.poll = poll
|
||||
message.__setattr__("poll", poll)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSession,
|
||||
message_chain: MessageChain,
|
||||
) -> Awaitable[Any]:
|
||||
) -> None:
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiofiles
|
||||
import botpy
|
||||
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
ret = cast(
|
||||
message.Message,
|
||||
await self._post_send(stream=stream_payload),
|
||||
)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
return None
|
||||
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(
|
||||
|
||||
if not isinstance(
|
||||
source,
|
||||
(
|
||||
botpy.message.Message,
|
||||
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
botpy.message.DirectMessage,
|
||||
botpy.message.C2CMessage,
|
||||
),
|
||||
)
|
||||
):
|
||||
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
|
||||
return None
|
||||
|
||||
(
|
||||
plain_text,
|
||||
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
):
|
||||
return None
|
||||
|
||||
payload = {
|
||||
payload: dict = {
|
||||
"content": plain_text,
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
|
||||
ret = None
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
match source:
|
||||
case botpy.message.GroupMessage():
|
||||
if not source.group_openid:
|
||||
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
|
||||
return None
|
||||
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
group_openid=source.group_openid,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.C2CMessage:
|
||||
|
||||
case botpy.message.C2CMessage():
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
**payload,
|
||||
)
|
||||
logger.debug(f"Message sent to C2C: {ret}")
|
||||
case botpy.message.Message:
|
||||
|
||||
case botpy.message.Message():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_message(
|
||||
channel_id=source.channel_id,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.DirectMessage:
|
||||
|
||||
case botpy.message.DirectMessage():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"file_type": file_type,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
result = None
|
||||
if "openid" in kwargs:
|
||||
payload["openid"] = kwargs["openid"]
|
||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
if "group_openid" in kwargs:
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
elif "group_openid" in kwargs:
|
||||
payload["group_openid"] = kwargs["group_openid"]
|
||||
route = Route(
|
||||
"POST",
|
||||
"/v2/groups/{group_openid}/files",
|
||||
group_openid=kwargs["group_openid"],
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
else:
|
||||
raise ValueError("Invalid upload parameters")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to upload image, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return Media(
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
)
|
||||
|
||||
async def upload_group_and_c2c_record(
|
||||
self,
|
||||
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if result:
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"上传文件响应格式错误: {result}")
|
||||
return None
|
||||
|
||||
return Media(
|
||||
file_uuid=result.get("file_uuid"),
|
||||
file_info=result.get("file_info"),
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
file_id=result.get("id", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传请求错误: {e}")
|
||||
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
message_reference: message.Reference | None = None,
|
||||
media: message.Media | None = None,
|
||||
msg_id: str | None = None,
|
||||
msg_seq: str = 1,
|
||||
msg_seq: int | None = 1,
|
||||
event_id: str | None = None,
|
||||
markdown: message.MarkdownPayload | None = None,
|
||||
keyboard: message.Keyboard | None = None,
|
||||
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload = locals()
|
||||
payload.pop("self", None)
|
||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to post c2c message, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return message.Message(**result)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_to_qqofficial(message: MessageChain):
|
||||
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
elif i.file and i.file.startswith("base64://"):
|
||||
image_base64 = i.file
|
||||
else:
|
||||
elif i.file:
|
||||
image_base64 = file_to_base64(i.file)
|
||||
else:
|
||||
raise ValueError("Unsupported image file format")
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
elif isinstance(i, Record):
|
||||
if i.file:
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -44,7 +45,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unique_session: bool = platform_settings["unique_session"]
|
||||
qq_group = platform_config["enable_group_c2c"]
|
||||
guild_dm = platform_config["enable_guild_direct_message"]
|
||||
|
||||
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_from_qqofficial(
|
||||
message: botpy.message.Message | botpy.message.GroupMessage,
|
||||
message: botpy.message.Message
|
||||
| botpy.message.GroupMessage
|
||||
| botpy.message.DirectMessage
|
||||
| botpy.message.C2CMessage,
|
||||
message_type: MessageType,
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = message
|
||||
abm.message_id = message.id
|
||||
abm.tag = "qq_official"
|
||||
# abm.tag = "qq_official"
|
||||
msg: list[BaseMessageComponent] = []
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
||||
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
message,
|
||||
botpy.message.DirectMessage,
|
||||
):
|
||||
try:
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.self_id = str(message.mentions[0].id)
|
||||
except BaseException as _:
|
||||
else:
|
||||
abm.self_id = ""
|
||||
|
||||
plain_content = message.content.replace(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -36,7 +36,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official_webhook",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
|
||||
|
||||
if opcode == 13:
|
||||
# validation
|
||||
signed = await self.webhook_validation(data)
|
||||
signed = await self.webhook_validation(cast(dict, data))
|
||||
print(signed)
|
||||
return signed
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from quart import Quart, Response, request
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
@@ -66,7 +68,7 @@ class SlackWebhookClient:
|
||||
"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await req.get_data()
|
||||
body = cast(bytes, await req.get_data())
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
@@ -139,9 +141,14 @@ class SlackSocketClient:
|
||||
self.event_handler = event_handler
|
||||
self.socket_client = None
|
||||
|
||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
||||
async def _handle_events(
|
||||
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
|
||||
):
|
||||
"""处理 Socket Mode 事件"""
|
||||
try:
|
||||
if self.socket_client is None:
|
||||
raise RuntimeError("Socket client is not initialized")
|
||||
|
||||
# 确认收到事件
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
await self.socket_client.send_socket_mode_response(response)
|
||||
|
||||
@@ -3,8 +3,7 @@ import base64
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
|
||||
logger.debug(f"[slack] RawMessage {event}")
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_self_id
|
||||
abm.self_id = cast(str, self.bot_self_id)
|
||||
|
||||
# 获取用户信息
|
||||
user_id = event.get("user", "")
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=user_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||
except Exception:
|
||||
user_name = user_id
|
||||
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
|
||||
channel_id = event.get("channel", "")
|
||||
try:
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
is_im = channel_info["channel"]["is_im"]
|
||||
is_im = cast(dict, channel_info["channel"])["is_im"]
|
||||
|
||||
if is_im:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
|
||||
for mention in mentions:
|
||||
try:
|
||||
mentioned_user = await self.web_client.users_info(user=mention)
|
||||
user_data = mentioned_user["user"]
|
||||
user_data = cast(dict, mentioned_user["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get(
|
||||
"name",
|
||||
mention,
|
||||
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
async def run(self) -> None:
|
||||
self.bot_self_id = await self.get_bot_user_id()
|
||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||
|
||||
@@ -410,7 +409,7 @@ class SlackAdapter(Platform):
|
||||
await self.socket_client.stop()
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.stop()
|
||||
logger.info("Slack 适配器已被优雅地关闭")
|
||||
logger.info("Slack 适配器已被关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
@@ -428,3 +427,10 @@ class SlackAdapter(Platform):
|
||||
|
||||
def get_client(self):
|
||||
return self.web_client
|
||||
|
||||
def unified_webhook(self) -> bool:
|
||||
return bool(
|
||||
self.config.get("unified_webhook_mode", False)
|
||||
and self.config.get("slack_connection_mode", "") == "webhook"
|
||||
and self.config.get("webhook_uuid")
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Image):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
if url.startswith("http"):
|
||||
if url and url.startswith("http"):
|
||||
return {
|
||||
"type": "image",
|
||||
"image_url": url,
|
||||
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||
}
|
||||
image_url = response["files"][0]["url_private"]
|
||||
image_url = cast(list, response["files"])[0]["url_private"]
|
||||
logger.debug(f"Slack file upload response: {response}")
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
file_url = cast(list, response["files"])[0]["permalink"]
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
members = []
|
||||
for member_id in members_response["members"]:
|
||||
for member_id in cast(Iterable, members_response["members"]):
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=member_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
members.append(
|
||||
MessageMember(
|
||||
user_id=member_id,
|
||||
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
# 如果获取用户信息失败,使用默认信息
|
||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||
|
||||
channel_data = channel_info["channel"]
|
||||
channel_data = cast(dict, channel_info["channel"])
|
||||
return Group(
|
||||
group_id=channel_id,
|
||||
group_name=channel_data.get("name", ""),
|
||||
|
||||
@@ -424,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
|
||||
if self.application.updater is not None:
|
||||
await self.application.updater.stop()
|
||||
|
||||
logger.info("Telegram 适配器已被优雅地关闭")
|
||||
logger.info("Telegram 适配器已被关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
import telegramify_markdown
|
||||
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
|
||||
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
|
||||
Reply,
|
||||
)
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
"chat_id": user_name,
|
||||
}
|
||||
if has_reply:
|
||||
payload["reply_to_message_id"] = reply_message_id
|
||||
payload["reply_to_message_id"] = str(reply_message_id)
|
||||
if message_thread_id:
|
||||
payload["message_thread_id"] = message_thread_id
|
||||
|
||||
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text,
|
||||
parse_mode="MarkdownV2",
|
||||
**payload,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead.",
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
await client.send_message(text=chunk, **cast(Any, payload))
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
await client.send_photo(photo=image_path, **cast(Any, payload))
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
await client.send_document(
|
||||
document=path, filename=name, **cast(Any, payload)
|
||||
)
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await client.send_voice(voice=path, **payload)
|
||||
await client.send_voice(voice=path, **cast(Any, payload))
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await self.client.send_photo(photo=image_path, **payload)
|
||||
await self.client.send_photo(
|
||||
photo=image_path, **cast(Any, payload)
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
|
||||
await self.client.send_document(
|
||||
document=i.file,
|
||||
filename=i.name,
|
||||
**payload,
|
||||
document=path,
|
||||
filename=name,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await self.client.send_voice(voice=path, **payload)
|
||||
await self.client.send_voice(voice=path, **cast(Any, payload))
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
msg = await self.client.send_message(
|
||||
text=delta, **cast(Any, payload)
|
||||
)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
markdown_text = telegramify_markdown.markdownify(
|
||||
delta,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await self.client.edit_message_text(
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.wxid: str | None = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(),
|
||||
"wechatpadpro_credentials.json",
|
||||
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
async def handle_websocket_message(self, message: str | bytes):
|
||||
"""处理从 WebSocket 接收到的消息。"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
|
||||
if self.wxid is None:
|
||||
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
|
||||
return None
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.timestamp = cast(int, raw_message.get("create_time"))
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
msg_type = cast(int, raw_message.get("msg_type"))
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
msg_id: int,
|
||||
):
|
||||
) -> dict | None:
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
msg_id = cast(int, raw_message.get("msg_id"))
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name,
|
||||
to_user_name,
|
||||
msg_id,
|
||||
)
|
||||
if image_resp is None:
|
||||
logger.error(f"下载图片失败: msg_id={msg_id}")
|
||||
return
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id is None:
|
||||
logger.error("语音消息缺少 new_msg_id")
|
||||
return
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
if voicemsg is None:
|
||||
logger.error("无法从 XML 解析 voicemsg 节点")
|
||||
return
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
if voice_resp is None:
|
||||
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
|
||||
return
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
self._shutdown_event.set()
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def get_contact_details_list(
|
||||
self,
|
||||
room_wx_id_list: list[str] = None,
|
||||
user_names: list[str] = None,
|
||||
room_wx_id_list: list[str] | None = None,
|
||||
user_names: list[str] | None = None,
|
||||
) -> dict | None:
|
||||
"""获取联系人详情列表。"""
|
||||
if room_wx_id_list is None:
|
||||
|
||||
@@ -2,7 +2,8 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -40,7 +41,7 @@ else:
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.server.add_url_rule(
|
||||
"/callback/command",
|
||||
@@ -60,7 +61,7 @@ class WecomServer:
|
||||
config["corpid"].strip(),
|
||||
)
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
@@ -114,7 +115,7 @@ class WecomServer:
|
||||
logger.error("解密失败,签名异常,请检查配置。")
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
msg = cast(BaseMessage, parse_message(xml))
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
self.client.__setattr__("kf", self.wechat_kf_api)
|
||||
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
elif isinstance(msg, ImageMessage):
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
elif isinstance(msg, VoiceMessage):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype")
|
||||
external_userid = msg.get("external_userid")
|
||||
external_userid = cast(str, msg.get("external_userid"))
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
@@ -425,4 +422,4 @@ class WecomPlatformAdapter(Platform):
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("企业微信 适配器已被优雅地关闭")
|
||||
logger.info("企业微信 适配器已被关闭")
|
||||
|
||||
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
if not isinstance(kf_message_api, WeChatKFMessage):
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
|
||||
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
message_chain: MessageChain | None,
|
||||
stream_id: str,
|
||||
queue_mgr: WecomAIQueueMgr,
|
||||
streaming: bool = False,
|
||||
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
"""发送消息"""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -36,7 +37,7 @@ else:
|
||||
class WeixinOfficialAccountServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(int | str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
if not msg:
|
||||
logger.error("解析失败。msg为None。")
|
||||
raise
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(
|
||||
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
60,
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
async def convert_message(
|
||||
self,
|
||||
msg,
|
||||
future: asyncio.Future = None,
|
||||
future: asyncio.Future | None = None,
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.message_str = cast(str, msg.content)
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.message = [Plain(cast(str, msg.content))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
if future:
|
||||
future.set_result(None)
|
||||
return
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
@@ -344,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("微信公众平台 适配器已被优雅地关闭")
|
||||
logger.info("微信公众平台 适配器已被关闭")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
||||
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
active_send_mode = cast(dict, message_obj.raw_message).get(
|
||||
"active_send_mode", False
|
||||
)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -118,7 +118,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list[dict],
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -140,7 +140,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> None:
|
||||
"""添加函数调用工具
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from astrbot.core import astrbot_config, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -10,6 +11,7 @@ from .entities import ProviderType
|
||||
from .provider import (
|
||||
EmbeddingProvider,
|
||||
Provider,
|
||||
Providers,
|
||||
RerankProvider,
|
||||
STTProvider,
|
||||
TTSProvider,
|
||||
@@ -17,6 +19,11 @@ from .provider import (
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasInitialize(Protocol):
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -48,7 +55,7 @@ class ProviderManager:
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||
Providers,
|
||||
] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
@@ -123,15 +130,13 @@ class ProviderManager:
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(
|
||||
self,
|
||||
provider_type: ProviderType,
|
||||
umo=None,
|
||||
) -> Provider | STTProvider | TTSProvider | None:
|
||||
self, provider_type: ProviderType, umo=None
|
||||
) -> Providers | None:
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
@@ -191,7 +196,6 @@ class ProviderManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
@@ -210,15 +214,37 @@ class ProviderManager:
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
|
||||
temp_provider = (
|
||||
self.inst_map.get(selected_provider_id)
|
||||
if isinstance(selected_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_provider_inst = (
|
||||
temp_provider if isinstance(temp_provider, Provider) else None
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
||||
temp_stt = (
|
||||
self.inst_map.get(selected_stt_provider_id)
|
||||
if isinstance(selected_stt_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_stt_provider_inst = (
|
||||
temp_stt if isinstance(temp_stt, STTProvider) else None
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
||||
temp_tts = (
|
||||
self.inst_map.get(selected_tts_provider_id)
|
||||
if isinstance(selected_tts_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_tts_provider_inst = (
|
||||
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
@@ -358,73 +384,103 @@ class ProviderManager:
|
||||
|
||||
provider_metadata.id = provider_config["id"]
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
match provider_metadata.provider_type:
|
||||
case ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
if not issubclass(cls_type, STTProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of STTProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_stt_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||||
)
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
case ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
if not issubclass(cls_type, TTSProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||
)
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
case ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
if not issubclass(cls_type, Provider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of Provider"
|
||||
)
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
)
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||||
)
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||
case ProviderType.EMBEDDING:
|
||||
if not issubclass(cls_type, EmbeddingProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
case ProviderType.RERANK:
|
||||
if not issubclass(cls_type, RerankProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
case _:
|
||||
# 未知供应商抛出异常,确保inst初始化
|
||||
# Should be unreachable
|
||||
raise Exception(
|
||||
f"未知的提供商类型:{provider_metadata.provider_type}"
|
||||
)
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.provider_settings.get("default_provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||||
)
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
Providers: TypeAlias = Union[
|
||||
"Provider",
|
||||
"STTProvider",
|
||||
"TTSProvider",
|
||||
"EmbeddingProvider",
|
||||
"RerankProvider",
|
||||
]
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
"""Provider Abstract Class"""
|
||||
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
|
||||
"""
|
||||
...
|
||||
if False: # pragma: no cover - make this an async generator for typing
|
||||
yield None # type: ignore
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: list):
|
||||
"""弹出 context 第一条非系统提示词对话记录"""
|
||||
|
||||
@@ -29,15 +29,24 @@ class OTTSProvider:
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
self._client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
@@ -90,6 +99,7 @@ class OTTSProvider:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
raise RuntimeError("OTTS未返回音频文件")
|
||||
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
self.endpoint = (
|
||||
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(
|
||||
self._client = AsyncClient(
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = (
|
||||
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
def _parse_provider(
|
||||
self, key_value: str, config: dict
|
||||
) -> OTTSProvider | AzureNativeProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
json_str = ""
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
|
||||
Returns:
|
||||
重排序结果列表
|
||||
"""
|
||||
if not self.client:
|
||||
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
logger.warning("文档列表为空,返回空结果")
|
||||
return []
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"messages": None,
|
||||
"api_key": self.chosen_api_key,
|
||||
"voice": self.voice or "Cherry",
|
||||
"text": text,
|
||||
}
|
||||
if not self.voice:
|
||||
logging.warning(
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=mp3_path, output=wav_path)
|
||||
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
|
||||
Args:
|
||||
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
pattern = r"^[a-fA-F0-9]{32}$"
|
||||
return bool(re.match(pattern, reference_id.strip()))
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
||||
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
||||
if self.reference_id and self.reference_id.strip():
|
||||
# 验证reference_id格式
|
||||
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
text = await response.aread()
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base")
|
||||
api_key: str = provider_config["embedding_api_key"]
|
||||
api_base: str = provider_config["embedding_api_base"]
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
model=self.model,
|
||||
contents=text,
|
||||
)
|
||||
assert result.embeddings is not None
|
||||
assert result.embeddings[0].values is not None
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model,
|
||||
contents=texts,
|
||||
contents=cast(types.ContentListUnion, text),
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
assert result.embeddings is not None
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
for embedding in result.embeddings:
|
||||
assert embedding.values is not None
|
||||
embeddings.append(embedding.values)
|
||||
return embeddings
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider):
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
|
||||
# 流式输出不支持图片模态
|
||||
if (
|
||||
self.provider_settings.get("streaming_response", False)
|
||||
and "Image" in modalities
|
||||
and "IMAGE" in modalities
|
||||
):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
|
||||
tool_list = []
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = self.get_model()
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
tools=cast(types.ToolListUnion | None, tool_list),
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
assert contents[-1].parts is not None
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
@@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
None,
|
||||
)
|
||||
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("Image")
|
||||
modalities.append("IMAGE")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature = payloads.get("temperature", 0.7)
|
||||
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
logger.debug(f"genai result: {result}")
|
||||
@@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
||||
)
|
||||
modalities = ["Text"]
|
||||
modalities = ["TEXT"]
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
break
|
||||
|
||||
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with (
|
||||
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
audio: str | None = data.get("data", {}).get(
|
||||
"audio"
|
||||
)
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
|
||||
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
if tools is None:
|
||||
# 工具集未提供
|
||||
# Should be unreachable
|
||||
raise Exception("工具集未提供")
|
||||
for tool in tools.func_list:
|
||||
if (
|
||||
tool_call.type == "function"
|
||||
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from funasr_onnx import SenseVoiceSmall
|
||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("stt_model"))
|
||||
self.set_model(provider_config["stt_model"])
|
||||
self.model = None
|
||||
self.is_emotion = provider_config.get("is_emotion", False)
|
||||
|
||||
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
loop = asyncio.get_event_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None, # 使用默认的线程池
|
||||
lambda: self.model(audio_url, language="auto", use_itn=True),
|
||||
lambda: cast(SenseVoiceSmall, self.model)(
|
||||
audio_url, language="auto", use_itn=True
|
||||
),
|
||||
)
|
||||
|
||||
# res = self.model(audio_url, language="auto", use_itn=True)
|
||||
|
||||
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
|
||||
}
|
||||
if top_n is not None:
|
||||
payload["top_n"] = top_n
|
||||
assert self.client is not None
|
||||
async with self.client.post(
|
||||
f"{self.base_url}/v1/rerank",
|
||||
json=payload,
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
|
||||
async def _get_audio_format(self, file_path):
|
||||
# 定义要检测的头部字节
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import whisper
|
||||
|
||||
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
self.model = None
|
||||
|
||||
async def initialize(self):
|
||||
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
if not self.model:
|
||||
raise RuntimeError("Whisper 模型未初始化")
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result["text"]
|
||||
return cast(str, result["text"])
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import cast
|
||||
|
||||
from xinference_client.client.restful.async_restful_client import (
|
||||
AsyncClient as Client,
|
||||
)
|
||||
from xinference_client.client.restful.async_restful_client import (
|
||||
AsyncRESTfulRerankModelHandle,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
|
||||
False,
|
||||
)
|
||||
self.client = None
|
||||
self.model = None
|
||||
self.model: AsyncRESTfulRerankModelHandle | None = None
|
||||
self.model_uid = None
|
||||
|
||||
async def initialize(self):
|
||||
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
|
||||
return
|
||||
|
||||
if self.model_uid:
|
||||
self.model = await self.client.get_model(self.model_uid)
|
||||
self.model = cast(
|
||||
AsyncRESTfulRerankModelHandle,
|
||||
await self.client.get_model(self.model_uid),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Xinference model: {e}")
|
||||
|
||||
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.star.star_tools import StarTools
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .context import Context
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
from .star_manager import PluginManager
|
||||
|
||||
|
||||
class Star(CommandParserMixin):
|
||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
author: str
|
||||
name: str
|
||||
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
@@ -285,7 +285,7 @@ class Context:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider:
|
||||
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
Args:
|
||||
@@ -296,7 +296,7 @@ class Context:
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
if prov and not isinstance(prov, Provider):
|
||||
if not isinstance(prov, Provider):
|
||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||
return prov
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import docstring_parser
|
||||
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
|
||||
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
|
||||
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
||||
def get_handler_full_name(
|
||||
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> str:
|
||||
"""获取 Handler 的全名"""
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
|
||||
def get_handler_or_create(
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[
|
||||
...,
|
||||
Awaitable[MessageEventResult | str | None]
|
||||
| AsyncGenerator[MessageEventResult | str | None],
|
||||
],
|
||||
event_type: EventType,
|
||||
dont_add=False,
|
||||
**kwargs,
|
||||
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
for (
|
||||
sub_handle
|
||||
) in parent_register_commandable.parent_group.sub_command_filters:
|
||||
if isinstance(sub_handle, CommandGroupFilter):
|
||||
continue
|
||||
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
||||
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
||||
sub_handle_md = sub_handle.get_handler_md()
|
||||
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
|
||||
else:
|
||||
# 裸指令
|
||||
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
|
||||
assert isinstance(awaitable, Callable)
|
||||
handler_md = get_handler_or_create(
|
||||
awaitable,
|
||||
EventType.AdapterMessageEvent,
|
||||
@@ -237,7 +248,7 @@ class RegisteringCommandable:
|
||||
|
||||
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
||||
command: Callable[..., Callable[..., None]] = register_command
|
||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
||||
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
||||
if kwargs.get("registering_agent"):
|
||||
registering_agent = kwargs["registering_agent"]
|
||||
|
||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||
def decorator(
|
||||
awaitable: Callable[
|
||||
...,
|
||||
AsyncGenerator[MessageEventResult | str | None]
|
||||
| Awaitable[MessageEventResult | str | None],
|
||||
],
|
||||
):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
func_doc = awaitable.__doc__ or ""
|
||||
docstring = docstring_parser.parse(func_doc)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnAstrBotLoadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnPlatformLoadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.AdapterMessageEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnLLMRequestEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnLLMResponseEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnDecoratingResultEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnCallingFuncToolEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnAfterMessageSentEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: EventType,
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: EventType,
|
||||
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
|
||||
H = TypeVar("H", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata:
|
||||
class StarHandlerMetadata(Generic[H]):
|
||||
"""描述一个 Star 所注册的某一个 Handler。"""
|
||||
|
||||
event_type: EventType
|
||||
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
|
||||
handler_module_path: str
|
||||
"""Handler 所在的模块路径。"""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]]
|
||||
handler: H
|
||||
"""Handler 的函数对象,应当是一个异步函数"""
|
||||
|
||||
event_filters: list[HandlerFilter]
|
||||
|
||||
@@ -467,6 +467,18 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context,
|
||||
)
|
||||
|
||||
p_name = (metadata.name or "unknown").lower().replace("/", "_")
|
||||
p_author = (
|
||||
(metadata.author or "unknown").lower().replace("/", "_")
|
||||
)
|
||||
setattr(metadata.star_cls, "name", p_name)
|
||||
setattr(metadata.star_cls, "author", p_author)
|
||||
setattr(
|
||||
metadata.star_cls,
|
||||
"plugin_id",
|
||||
f"{p_author}/{p_name}",
|
||||
)
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
|
||||
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
async def check_update(
|
||||
self,
|
||||
url: str,
|
||||
current_version: str,
|
||||
url: str | None,
|
||||
current_version: str | None,
|
||||
consider_prerelease: bool = True,
|
||||
) -> ReleaseInfo:
|
||||
) -> ReleaseInfo | None:
|
||||
"""检查更新"""
|
||||
return await super().check_update(
|
||||
self.ASTRBOT_RELEASE_API,
|
||||
|
||||
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
return False
|
||||
|
||||
|
||||
def save_temp_img(img: Image.Image | str) -> str:
|
||||
def save_temp_img(img: Image.Image | bytes) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import TypeVar
|
||||
|
||||
from astrbot.core import sp
|
||||
|
||||
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class PluginKVStoreMixin:
|
||||
"""为插件提供键值存储功能的 Mixin 类"""
|
||||
|
||||
plugin_id: str
|
||||
|
||||
async def put_kv_data(
|
||||
self,
|
||||
key: str,
|
||||
value: SUPPORTED_VALUE_TYPES,
|
||||
) -> None:
|
||||
"""为指定插件存储一个键值对"""
|
||||
await sp.put_async("plugin", self.plugin_id, key, value)
|
||||
|
||||
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
|
||||
"""获取指定插件存储的键值对"""
|
||||
return await sp.get_async("plugin", self.plugin_id, key, default)
|
||||
|
||||
async def delete_kv_data(self, key: str) -> None:
|
||||
"""删除指定插件存储的键值对"""
|
||||
await sp.remove_async("plugin", self.plugin_id, key)
|
||||
@@ -20,16 +20,16 @@ class SessionController:
|
||||
|
||||
def __init__(self):
|
||||
self.future = asyncio.Future()
|
||||
self.current_event: asyncio.Event = None
|
||||
self.current_event: asyncio.Event | None = None
|
||||
"""当前正在等待的所用的异步事件"""
|
||||
self.ts: float = None
|
||||
self.ts: float | None = None
|
||||
"""上次保持(keep)开始时的时间"""
|
||||
self.timeout: float | int = None
|
||||
self.timeout: float | int | None = None
|
||||
"""上次保持(keep)开始时的超时时间"""
|
||||
|
||||
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
||||
|
||||
def stop(self, error: Exception = None):
|
||||
def stop(self, error: Exception | None = None):
|
||||
"""立即结束这个会话"""
|
||||
if not self.future.done():
|
||||
if error:
|
||||
@@ -53,6 +53,8 @@ class SessionController:
|
||||
self.stop()
|
||||
return
|
||||
else:
|
||||
assert self.timeout is not None
|
||||
assert self.ts is not None
|
||||
left_timeout = self.timeout - (new_ts - self.ts)
|
||||
timeout = left_timeout + timeout
|
||||
if timeout <= 0:
|
||||
@@ -69,7 +71,7 @@ class SessionController:
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
||||
async def _holding(self, event: asyncio.Event, timeout: float):
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
@@ -108,7 +110,9 @@ class SessionWaiter:
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.session_filter = session_filter
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
self.handler: (
|
||||
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
||||
) = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
self.record_history_chains = record_history_chains
|
||||
@@ -119,7 +123,7 @@ class SessionWaiter:
|
||||
|
||||
async def register_wait(
|
||||
self,
|
||||
handler: Callable[[str], Awaitable[Any]],
|
||||
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
timeout: int = 30,
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
@@ -137,7 +141,7 @@ class SessionWaiter:
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self, error: Exception = None):
|
||||
def _cleanup(self, error: Exception | None = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
try:
|
||||
@@ -161,6 +165,7 @@ class SessionWaiter:
|
||||
)
|
||||
try:
|
||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||
assert session.handler is not None
|
||||
await session.handler(session.session_controller, event)
|
||||
except Exception as e:
|
||||
session.session_controller.stop(e)
|
||||
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
def decorator(
|
||||
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(
|
||||
event: AstrMessageEvent,
|
||||
session_filter: SessionFilter = None,
|
||||
session_filter: SessionFilter | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -53,6 +53,38 @@ class SharedPreferences:
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
) -> _VT: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str | None,
|
||||
|
||||
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class RenderStrategy(ABC):
|
||||
@abstractmethod
|
||||
def render(self, text: str, return_url: bool) -> str:
|
||||
async def render(self, text: str, return_url: bool) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_custom_template(
|
||||
async def render_custom_template(
|
||||
self,
|
||||
tmpl_str: str,
|
||||
tmpl_data: dict,
|
||||
|
||||
@@ -20,7 +20,7 @@ class FontManager:
|
||||
_font_cache = {}
|
||||
|
||||
@classmethod
|
||||
def get_font(cls, size: int) -> ImageFont.FreeTypeFont:
|
||||
def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
|
||||
"""获取指定大小的字体,优先从缓存获取"""
|
||||
if size in cls._font_cache:
|
||||
return cls._font_cache[size]
|
||||
@@ -66,23 +66,17 @@ class TextMeasurer:
|
||||
"""测量文本尺寸的工具类"""
|
||||
|
||||
@staticmethod
|
||||
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
|
||||
def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
|
||||
"""获取文本的尺寸"""
|
||||
try:
|
||||
# PIL 9.0.0 以上版本
|
||||
return (
|
||||
font.getbbox(text)[2:]
|
||||
if hasattr(font, "getbbox")
|
||||
else font.getsize(text)
|
||||
)
|
||||
except Exception:
|
||||
# 兼容旧版本
|
||||
return font.getsize(text)
|
||||
|
||||
# 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
|
||||
left, top, right, bottom = font.getbbox("Hello world")
|
||||
return int(right - left), int(bottom - top)
|
||||
|
||||
@staticmethod
|
||||
def split_text_to_fit_width(
|
||||
text: str, font: ImageFont.FreeTypeFont, max_width: int
|
||||
) -> List[str]:
|
||||
text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
|
||||
) -> list[str]:
|
||||
"""将文本拆分为多行,确保每行不超过指定宽度"""
|
||||
lines = []
|
||||
if not text:
|
||||
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
|
||||
# 倾斜变换,使用仿射变换实现斜体效果
|
||||
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
|
||||
italic_img = text_img.transform(
|
||||
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC
|
||||
text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
|
||||
)
|
||||
|
||||
# 粘贴到原图像
|
||||
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
|
||||
class CodeBlockElement(MarkdownElement):
|
||||
"""代码块元素"""
|
||||
|
||||
def __init__(self, content: List[str]):
|
||||
def __init__(self, content: list[str]):
|
||||
super().__init__("\n".join(content))
|
||||
|
||||
def calculate_height(self, image_width: int, font_size: int) -> int:
|
||||
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
|
||||
if pasted_image.width > max_width:
|
||||
ratio = max_width / pasted_image.width
|
||||
new_size = (int(max_width), int(pasted_image.height * ratio))
|
||||
pasted_image = pasted_image.resize(new_size, Image.LANCZOS)
|
||||
pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 计算居中位置
|
||||
paste_x = x + (image_width - pasted_image.width) // 2 - 10
|
||||
@@ -705,7 +699,7 @@ class MarkdownParser:
|
||||
"""Markdown解析器,将文本解析为元素"""
|
||||
|
||||
@staticmethod
|
||||
async def parse(text: str) -> List[MarkdownElement]:
|
||||
async def parse(text: str) -> list[MarkdownElement]:
|
||||
elements = []
|
||||
lines = text.split("\n")
|
||||
|
||||
@@ -847,7 +841,7 @@ class MarkdownRenderer:
|
||||
self,
|
||||
font_size: int = 26,
|
||||
width: int = 800,
|
||||
bg_color: Tuple[int, int, int] = (255, 255, 255),
|
||||
bg_color: tuple[int, int, int] = (255, 255, 255),
|
||||
):
|
||||
self.font_size = font_size
|
||||
self.width = width
|
||||
|
||||
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=input_path, output=output_path)
|
||||
ff.convert(input_file=input_path, output_file=output_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
|
||||
|
||||
@@ -60,9 +60,12 @@ class VersionComparator:
|
||||
return -1
|
||||
if isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
if (isinstance(p1, int) and isinstance(p2, int)) or (
|
||||
isinstance(p1, str) and isinstance(p2, str)
|
||||
):
|
||||
if isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
if p1 < p2:
|
||||
return -1
|
||||
if isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
if p1 < p2:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.config.default import WEBHOOK_SUPPORTED_PLATFORMS
|
||||
|
||||
|
||||
def _get_callback_api_base() -> str:
|
||||
@@ -45,3 +48,19 @@ def log_webhook_info(platform_name: str, webhook_uuid: str):
|
||||
"====================\n"
|
||||
)
|
||||
logger.info(display_log)
|
||||
|
||||
|
||||
def ensure_platform_webhook_config(platform_cfg: dict) -> bool:
|
||||
"""为支持统一 webhook 的平台自动生成 webhook_uuid
|
||||
|
||||
Args:
|
||||
platform_cfg (dict): 平台配置字典
|
||||
|
||||
Returns:
|
||||
bool: 如果生成了 webhook_uuid 则返回 True,否则返回 False
|
||||
"""
|
||||
pt = platform_cfg.get("type", "")
|
||||
if pt in WEBHOOK_SUPPORTED_PLATFORMS and not platform_cfg.get("webhook_uuid"):
|
||||
platform_cfg["webhook_uuid"] = uuid.uuid4().hex[:16]
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -4,7 +4,9 @@ import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import g, make_response, request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -424,16 +426,19 @@ class ChatRoute(Route):
|
||||
sender_name=username,
|
||||
)
|
||||
|
||||
response = await make_response(
|
||||
stream(),
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
response = cast(
|
||||
QuartResponse,
|
||||
await make_response(
|
||||
stream(),
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
return response
|
||||
|
||||
async def delete_webchat_session(self):
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from quart import request
|
||||
|
||||
@@ -14,7 +14,6 @@ from astrbot.core.config.default import (
|
||||
CONFIG_METADATA_3_SYSTEM,
|
||||
DEFAULT_CONFIG,
|
||||
DEFAULT_VALUE_MAP,
|
||||
WEBHOOK_SUPPORTED_PLATFORMS,
|
||||
)
|
||||
from astrbot.core.config.i18n_utils import ConfigMetadataI18n
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -22,11 +21,12 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
def try_cast(value: Any, type_: str):
|
||||
if type_ == "int":
|
||||
try:
|
||||
return int(value)
|
||||
@@ -505,9 +505,9 @@ class ConfigRoute(Route):
|
||||
if not isinstance(inst, EmbeddingProvider):
|
||||
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
||||
|
||||
# 初始化
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
init_fn = getattr(inst, "initialize", None)
|
||||
if inspect.iscoroutinefunction(init_fn):
|
||||
await init_fn()
|
||||
|
||||
# 获取嵌入向量维度
|
||||
vec = await inst.get_embedding("echo")
|
||||
@@ -558,13 +558,8 @@ class ConfigRoute(Route):
|
||||
async def post_new_platform(self):
|
||||
new_platform_config = await request.json
|
||||
|
||||
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid
|
||||
platform_type = new_platform_config.get("type", "")
|
||||
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
|
||||
if new_platform_config.get("unified_webhook_mode", False):
|
||||
# 如果没有 webhook_uuid 或为空,自动生成
|
||||
if not new_platform_config.get("webhook_uuid"):
|
||||
new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16]
|
||||
# 如果是支持统一 webhook 模式的平台,生成 webhook_uuid
|
||||
ensure_platform_webhook_config(new_platform_config)
|
||||
|
||||
self.config["platform"].append(new_platform_config)
|
||||
try:
|
||||
@@ -596,12 +591,7 @@ class ConfigRoute(Route):
|
||||
return Response().error("参数错误").__dict__
|
||||
|
||||
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
|
||||
platform_type = new_config.get("type", "")
|
||||
if platform_type in WEBHOOK_SUPPORTED_PLATFORMS:
|
||||
if new_config.get("unified_webhook_mode", False):
|
||||
# 如果没有 webhook_uuid 或为空,自动生成
|
||||
if not new_config.get("webhook_uuid"):
|
||||
new_config["webhook_uuid"] = uuid.uuid4().hex
|
||||
ensure_platform_webhook_config(new_config)
|
||||
|
||||
for i, platform in enumerate(self.config["platform"]):
|
||||
if platform["id"] == platform_id:
|
||||
@@ -777,7 +767,7 @@ class ConfigRoute(Route):
|
||||
return {"metadata": CONFIG_METADATA_2, "config": config}
|
||||
|
||||
async def _get_plugin_config(self, plugin_name: str):
|
||||
ret = {"metadata": None, "config": None}
|
||||
ret: dict = {"metadata": None, "config": None}
|
||||
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
from quart import request
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -30,6 +32,7 @@ class ConversationRoute(Route):
|
||||
"POST",
|
||||
self.update_history,
|
||||
),
|
||||
"/conversation/export": ("POST", self.export_conversations),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
@@ -283,3 +286,90 @@ class ConversationRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
||||
|
||||
async def export_conversations(self):
|
||||
"""批量导出对话为 JSONL 格式"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
conversations_to_export = data.get("conversations", [])
|
||||
|
||||
if not conversations_to_export:
|
||||
return Response().error("导出列表不能为空").__dict__
|
||||
|
||||
# 收集所有对话的内容
|
||||
jsonl_lines = []
|
||||
exported_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv_info in conversations_to_export:
|
||||
user_id = conv_info.get("user_id")
|
||||
cid = conv_info.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 对话不存在"
|
||||
)
|
||||
continue
|
||||
|
||||
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
|
||||
content = json.loads(conversation.history)
|
||||
|
||||
# 创建导出记录
|
||||
export_record = {
|
||||
"cid": cid,
|
||||
"user_id": user_id,
|
||||
"platform_id": conversation.platform_id,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 将记录转换为 JSON 字符串并添加到 JSONL
|
||||
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
||||
exported_count += 1
|
||||
|
||||
except Exception as e:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
||||
logger.error(
|
||||
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
|
||||
)
|
||||
|
||||
if exported_count == 0:
|
||||
return Response().error("没有成功导出任何对话").__dict__
|
||||
|
||||
# 创建 JSONL 内容
|
||||
jsonl_content = "\n".join(jsonl_lines)
|
||||
|
||||
# 创建一个内存文件对象
|
||||
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
||||
file_obj.seek(0)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
|
||||
|
||||
# 返回文件流
|
||||
return await send_file(
|
||||
file_obj,
|
||||
mimetype="application/jsonl",
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"批量导出对话失败: {e!s}").__dict__
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user