Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6837d4d692 | |||
| 8aba83735b | |||
| aa51187747 | |||
| 5f07a9ae95 | |||
| a2ca767bf4 | |||
| 5806c74e7c | |||
| 0481e1d45e | |||
| 3177b61421 | |||
| 6009cf5dfa | |||
| 0a970e8c31 | |||
| aa276ca6af | |||
| 9f02dd13ff | |||
| 609e723322 | |||
| c564a1d53e | |||
| a7fe31f28b | |||
| a84dc599d6 | |||
| 8da029add9 | |||
| ba45a2d270 | |||
| cb56b22aea | |||
| 23cc5b31ba | |||
| e8d99f0460 | |||
| 6bcd10cd5c | |||
| 619fb20c5f | |||
| 386a312e96 | |||
| 2759d347e6 | |||
| b6ec327b49 | |||
| ee02d622ba | |||
| 5c4a6083f5 | |||
| 49e63a3d3d | |||
| 6bae9dc9ed | |||
| 5fa1979a46 | |||
| b40d4fa315 | |||
| 4d2ff7cd5b | |||
| d8ec0e64d0 | |||
| 82e979cc07 | |||
| 8c132a51f5 | |||
| 40bd372cc1 | |||
| 212e114270 | |||
| b0e9de6951 | |||
| 3489522bbb | |||
| 96237abc03 |
+2
-1
@@ -21,4 +21,5 @@ node_modules/
|
||||
.DS_Store
|
||||
package-lock.json
|
||||
package.json
|
||||
venv/*
|
||||
venv/*
|
||||
packages/python_interpreter/workplace
|
||||
|
||||
+3
-1
@@ -12,7 +12,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install -r requirements.txt
|
||||
RUN python -m pip install -r requirements.txt --no-cache-dir
|
||||
|
||||
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
@@ -28,7 +27,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
## ✨ 主要功能
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat、VChat)、Telegram。后续将支持钉钉、飞书、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
@@ -73,8 +72,8 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 飞书 | ✔ | 群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| 飞书 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.22"
|
||||
VERSION = "3.4.25"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -69,7 +69,9 @@ DEFAULT_CONFIG = {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
},
|
||||
"admins_id": [],
|
||||
"admins_id": [
|
||||
"astrbot"
|
||||
],
|
||||
"t2i": False,
|
||||
"t2i_word_threshold": 150,
|
||||
"http_proxy": "",
|
||||
@@ -123,12 +125,21 @@ CONFIG_METADATA_2 = {
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"lark(飞书)": {
|
||||
"id": "lark",
|
||||
"type": "lark",
|
||||
"enable": False,
|
||||
"lark_bot_name": "",
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"domain": "https://open.feishu.cn"
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
||||
"hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -170,6 +181,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True
|
||||
}
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -276,7 +293,7 @@ CONFIG_METADATA_2 = {
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"content_safety": {
|
||||
@@ -404,7 +421,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
},
|
||||
"硅基流动": {
|
||||
"siliconflow": {
|
||||
"id": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
@@ -415,6 +432,17 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
},
|
||||
},
|
||||
"moonshot(kimi)": {
|
||||
"id": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {
|
||||
"model": "moonshot-v1-8k",
|
||||
},
|
||||
},
|
||||
"llmtuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
@@ -433,6 +461,7 @@ CONFIG_METADATA_2 = {
|
||||
"dify_api_key": "",
|
||||
"dify_api_base": "https://api.dify.ai/v1",
|
||||
"dify_workflow_output_key": "",
|
||||
"timeout": 60,
|
||||
},
|
||||
"whisper(API)": {
|
||||
"id": "whisper",
|
||||
@@ -459,6 +488,15 @@ CONFIG_METADATA_2 = {
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
"fishaudio_tts(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"type": "fishaudio_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.fish-audio.cn/v1",
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"timeout": "20",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"timeout": {
|
||||
@@ -472,6 +510,12 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"fishaudio-tts-character": {
|
||||
"description": "character",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问:https://fish.audio/zh-CN/discovery",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
@@ -725,7 +769,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
},
|
||||
"active_reply": {
|
||||
"description": "主动回复",
|
||||
@@ -756,7 +800,7 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -811,7 +855,8 @@ CONFIG_METADATA_2 = {
|
||||
"plugin_repo_mirror": {
|
||||
"description": "插件仓库镜像",
|
||||
"type": "string",
|
||||
"hint": "插件仓库的镜像地址,用于加速插件的下载。",
|
||||
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
||||
"obvious_hint": True,
|
||||
"options": [
|
||||
"default",
|
||||
"https://ghp.ci/",
|
||||
|
||||
@@ -325,11 +325,13 @@ class RedBag(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: ComponentType = "Poke"
|
||||
qq: int
|
||||
type: str = ""
|
||||
id: T.Optional[int] = 0
|
||||
qq: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
def __init__(self, type: str, **_):
|
||||
type = f"Poke:{type}"
|
||||
super().__init__(type=type, **_)
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
@@ -339,14 +341,14 @@ class Forward(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
|
||||
class Node(BaseMessageComponent):
|
||||
'''群合并转发消息'''
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0
|
||||
name: T.Optional[str] = ""
|
||||
uin: T.Optional[int] = 0
|
||||
content: T.Optional[T.Union[str, list]] = ""
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
|
||||
def __init__(self, content: T.Union[str, list], **_):
|
||||
|
||||
@@ -31,7 +31,7 @@ class StarRequestSubStage(Stage):
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
|
||||
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
|
||||
logger.debug(f"执行插件 handler {handler.handler_full_name}")
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
|
||||
@@ -18,7 +18,6 @@ class ResultDecorateStage:
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
|
||||
try:
|
||||
self.t2i_word_threshold = int(self.t2i_word_threshold)
|
||||
@@ -39,9 +38,8 @@ class ResultDecorateStage:
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
|
||||
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
@@ -69,7 +67,7 @@ class ResultDecorateStage:
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
if self.use_tts and result.is_llm_result():
|
||||
if self.ctx.astrbot_config['provider_tts_settings']['enable'] and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
@@ -84,7 +82,7 @@ class ResultDecorateStage:
|
||||
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,7 @@ from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.message.components import At, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
@@ -86,6 +86,10 @@ class WakingCheckStage(Stage):
|
||||
if len(handler.event_filters) == 0:
|
||||
# 不可能有这种情况, 也不允许有这种情况
|
||||
continue
|
||||
|
||||
if 'sub_command' in handler.extras_configs:
|
||||
# 如果是子指令
|
||||
continue
|
||||
|
||||
for filter in handler.event_filters:
|
||||
try:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import List, Union
|
||||
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
@@ -305,9 +305,10 @@ class AstrMessageEvent(abc.ABC):
|
||||
prompt: str,
|
||||
func_tool_manager = None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = ""
|
||||
image_urls: List[str] = [],
|
||||
contexts: List = [],
|
||||
system_prompt: str = "",
|
||||
conversation: Conversation = None
|
||||
) -> ProviderRequest:
|
||||
'''
|
||||
创建一个 LLM 请求。
|
||||
@@ -316,10 +317,12 @@ class AstrMessageEvent(abc.ABC):
|
||||
```py
|
||||
yield event.request_llm(prompt="hi")
|
||||
```
|
||||
|
||||
prompt: 提示词
|
||||
session_id: 已经过时,留空即可
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
||||
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
@@ -327,5 +330,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
image_urls = image_urls,
|
||||
func_tool = func_tool_manager,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt
|
||||
system_prompt = system_prompt,
|
||||
conversation=conversation
|
||||
)
|
||||
@@ -15,22 +15,23 @@ class PlatformManager():
|
||||
self.settings = config['platform_settings']
|
||||
self.event_queue = event_queue
|
||||
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
match platform['type']:
|
||||
case "aiocqhttp":
|
||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
try:
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
except BaseException:
|
||||
logger.warning("当前 astrbot 已不维护 vchat 的接入,如有需要请 pip 安装 vchat 然后重启")
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
|
||||
try:
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
match platform['type']:
|
||||
case "aiocqhttp":
|
||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。")
|
||||
|
||||
async def initialize(self):
|
||||
for platform in self.platforms_config:
|
||||
@@ -40,7 +41,7 @@ class PlatformManager():
|
||||
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
cls_type = platform_cls_map[platform['type']]
|
||||
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
inst = cls_type(platform, self.settings, self.event_queue)
|
||||
self.platform_insts.append(inst)
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -20,7 +18,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, (Image, Record)):
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
bs64_data = file_to_base64(segment.file[8:])
|
||||
@@ -28,17 +26,35 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
elif segment.file and segment.file.startswith("base64://"):
|
||||
bs64_data = segment.file
|
||||
else:
|
||||
bs64_data = file_to_base64(segment.file)
|
||||
d['data'] = {
|
||||
'file': bs64_data,
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d['data'] = {
|
||||
'qq': str(segment.qq) # 转换为字符串
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Music)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
await super().send(message)
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
import base64
|
||||
import datetime
|
||||
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
@@ -67,6 +68,17 @@ class SimpleGewechatClient():
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
if 'Data' in data and 'CreateTime' in data['Data']:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = data['Data']['CreateTime']
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
d = data['Data']
|
||||
|
||||
@@ -141,12 +153,11 @@ class SimpleGewechatClient():
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
case _:
|
||||
logger.error(f"未实现的消息类型: {d['MsgType']}")
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def callback(self):
|
||||
@@ -300,12 +311,14 @@ class SimpleGewechatClient():
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
async def post_text(self, to_wxid, content: str):
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload['ats'] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
@@ -349,4 +362,21 @@ class SimpleGewechatClient():
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
@@ -6,7 +6,7 @@ from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, File
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
@@ -15,6 +15,8 @@ def get_wav_duration(file_path):
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
@@ -43,9 +45,31 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await self.client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_url = comp.file
|
||||
img_path = ""
|
||||
@@ -81,22 +105,28 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
|
||||
print(f"duration: {duration}, {silk_path}")
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
# temp_directory = os.path.abspath('data/temp')
|
||||
# record_path = os.path.abspath(record_path)
|
||||
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
|
||||
# with open(record_path, "rb") as f:
|
||||
# record_path = f"data/temp/{uuid.uuid4()}.wav"
|
||||
# with open(record_path, "wb") as f2:
|
||||
# f2.write(f.read())
|
||||
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await self.client.post_file(to_wxid, file_url, file_id)
|
||||
else:
|
||||
logger.error(f"gewechat 暂不支持发送消息类型: {comp.type}")
|
||||
|
||||
await super().send(message)
|
||||
@@ -0,0 +1,175 @@
|
||||
import base64
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .lark_event import LarkMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot import logger
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
|
||||
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
|
||||
class LarkPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
|
||||
self.appid = platform_config['app_id']
|
||||
self.appsecret = platform_config['app_secret']
|
||||
self.domain = platform_config.get('domain', lark.FEISHU_DOMAIN)
|
||||
self.bot_name = platform_config.get('lark_bot_name', "astrbot")
|
||||
|
||||
if not self.bot_name:
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
await self.convert_msg(event)
|
||||
|
||||
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
asyncio.create_task(on_msg_event_recv(event))
|
||||
|
||||
self.event_handler = lark.EventDispatcherHandler.builder("", "") \
|
||||
.register_p2_im_message_receive_v1(do_v2_msg_event) \
|
||||
.build()
|
||||
|
||||
self.client = lark.ws.Client(
|
||||
app_id=self.appid,
|
||||
app_secret=self.appsecret,
|
||||
log_level=lark.LogLevel.ERROR,
|
||||
domain=self.domain,
|
||||
event_handler=self.event_handler
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.build()
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"lark",
|
||||
"飞书机器人官方 API 适配器",
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
message = event.event.message
|
||||
abm = AstrBotMessage()
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
abm.message = []
|
||||
abm.type = MessageType.GROUP_MESSAGE if message.chat_type == 'group' else MessageType.FRIEND_MESSAGE
|
||||
if message.chat_type == 'group':
|
||||
abm.group_id = message.chat_id
|
||||
abm.self_id = self.bot_name
|
||||
abm.message_str = ""
|
||||
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
|
||||
if m.name == self.bot_name:
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
content_json_b = json.loads(message.content)
|
||||
|
||||
if message.message_type == 'text':
|
||||
message_str_raw = content_json_b['text'] # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
parts = re.split(at_pattern, message_str_raw)
|
||||
for i in range(len(parts)):
|
||||
s = parts[i].strip()
|
||||
if not s:
|
||||
continue
|
||||
if s in at_list:
|
||||
abm.message.append(at_list[s])
|
||||
else:
|
||||
abm.message.append(Plain(parts[i].strip()))
|
||||
elif message.message_type == 'post':
|
||||
_ls = []
|
||||
|
||||
content_ls = content_json_b.get('content', [])
|
||||
for comp in content_ls:
|
||||
if isinstance(comp, list):
|
||||
_ls.extend(comp)
|
||||
elif isinstance(comp, dict):
|
||||
_ls.append(comp)
|
||||
content_json_b = _ls
|
||||
elif message.message_type == 'image':
|
||||
content_json_b = [
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}
|
||||
]
|
||||
|
||||
if message.message_type in ('post', 'image'):
|
||||
for comp in content_json_b:
|
||||
if comp['tag'] == 'at':
|
||||
abm.message.append(at_list[comp['user_id']])
|
||||
elif comp['tag'] == 'text' and comp['text'].strip():
|
||||
abm.message.append(Plain(comp['text'].strip()))
|
||||
elif comp['tag'] == 'img':
|
||||
image_key = comp['image_key']
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message.message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Image.fromBase64(image_base64))
|
||||
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Plain):
|
||||
abm.message_str += comp.text
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
user_id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.open_id[:8]
|
||||
)
|
||||
# 独立会话
|
||||
if not self.unique_session:
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
logger.debug(abm)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
event = LarkMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
bot=self.lark_api
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
await self.client._connect()
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
import uuid
|
||||
import lark_oapi as lark
|
||||
from typing import List
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, Record, At, Node, Music, Video
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id, bot: lark.Client):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List:
|
||||
ret = []
|
||||
_stage = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
_stage.append({
|
||||
"tag": "md",
|
||||
"text": comp.text
|
||||
})
|
||||
elif isinstance(comp, At):
|
||||
_stage.append({
|
||||
"tag": "at",
|
||||
"user_id": comp.qq,
|
||||
"style": []
|
||||
})
|
||||
elif isinstance(comp, AstrBotImage):
|
||||
file_path = ""
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace('file:///', '')
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
pass
|
||||
else:
|
||||
file_path = comp.file
|
||||
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body( \
|
||||
CreateImageRequestBody.builder() \
|
||||
.image_type("message") \
|
||||
.image(open(file_path, 'rb')) \
|
||||
.build() \
|
||||
) \
|
||||
.build()
|
||||
response = await lark_client.im.v1.image.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
image_key = response.data.image_key
|
||||
print(image_key)
|
||||
ret.append(_stage)
|
||||
ret.append([{
|
||||
"tag": "img",
|
||||
"image_key": image_key
|
||||
}])
|
||||
_stage.clear()
|
||||
else:
|
||||
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||
|
||||
if _stage:
|
||||
ret.append(_stage)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
}
|
||||
}
|
||||
|
||||
request = ReplyMessageRequest.builder() \
|
||||
.message_id(self.message_obj.message_id) \
|
||||
.request_body( \
|
||||
ReplyMessageRequestBody.builder() \
|
||||
.content(json.dumps(wrapped)) \
|
||||
.msg_type("post") \
|
||||
.uuid(str(uuid.uuid4())) \
|
||||
.reply_in_thread(False) \
|
||||
.build() \
|
||||
) \
|
||||
.build()
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send(message)
|
||||
@@ -8,6 +8,7 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
@@ -114,4 +115,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
||||
image_file_path = i.file
|
||||
else:
|
||||
logger.error(f"qq_official 暂不支持发送消息类型 {i.type}")
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -1,44 +0,0 @@
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from vchat import Core
|
||||
|
||||
class VChatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: Core, message: MessageChain, user_name: str):
|
||||
plain = ""
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
if message.is_split_:
|
||||
await client.send_msg(comp.text, user_name)
|
||||
else:
|
||||
plain += comp.text
|
||||
elif isinstance(comp, Image):
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
with open(file_path, "rb") as f:
|
||||
await client.send_image(user_name, fd=f)
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_path = await download_image_by_url(comp.file)
|
||||
with open(image_path, "rb") as f:
|
||||
await client.send_image(user_name, fd=f)
|
||||
else:
|
||||
logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}")
|
||||
await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓
|
||||
|
||||
if plain:
|
||||
await client.send_msg(plain, user_name)
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
|
||||
await super().send(message)
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .vchat_message_event import VChatPlatformEvent
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
from vchat import Core
|
||||
from vchat import model
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
|
||||
class VChatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
from_username = session.session_id.split('$$')[0]
|
||||
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"vchat",
|
||||
"基于 VChat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = Core()
|
||||
@self.client.msg_register(msg_types=model.ContentTypes.TEXT,
|
||||
contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER)
|
||||
async def _(msg: model.Message):
|
||||
if isinstance(msg.content, model.UselessContent):
|
||||
return
|
||||
if msg.create_time < self.start_time:
|
||||
logger.debug(f"忽略旧消息: {msg}")
|
||||
return
|
||||
logger.debug(f"收到消息: {msg.todict()}")
|
||||
abmsg = self.convert_message(msg)
|
||||
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
|
||||
asyncio.create_task(self.handle_msg(abmsg))
|
||||
|
||||
# TODO: 对齐微信服务器时间
|
||||
self.start_time = int(time.time())
|
||||
return self._run()
|
||||
|
||||
|
||||
async def _run(self):
|
||||
await self.client.init()
|
||||
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
|
||||
await self.client.run()
|
||||
|
||||
def convert_message(self, msg: model.Message) -> AstrBotMessage:
|
||||
# credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49
|
||||
assert isinstance(msg.content, model.TextContent)
|
||||
amsg = AstrBotMessage()
|
||||
amsg.message = [Plain(msg.content.content)]
|
||||
amsg.self_id = self.client_self_id
|
||||
if msg.content.is_at_me:
|
||||
amsg.message.insert(0, At(qq=amsg.self_id))
|
||||
|
||||
sender = msg.chatroom_sender or msg.from_
|
||||
amsg.sender = MessageMember(sender.username, sender.nickname)
|
||||
|
||||
if msg.content.is_at_me:
|
||||
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
|
||||
else:
|
||||
amsg.message_str = msg.content.content
|
||||
amsg.message_id = msg.message_id
|
||||
if isinstance(msg.from_, model.User):
|
||||
amsg.type = MessageType.FRIEND_MESSAGE
|
||||
elif isinstance(msg.from_, model.Chatroom):
|
||||
amsg.type = MessageType.GROUP_MESSAGE
|
||||
amsg.group_id = msg.from_.username
|
||||
else:
|
||||
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
|
||||
|
||||
amsg.raw_message = msg
|
||||
|
||||
if self.settingss['unique_session']:
|
||||
session_id = msg.from_.username + "$$" + msg.to.username
|
||||
if msg.chatroom_sender is not None:
|
||||
session_id += '$$' + msg.chatroom_sender.username
|
||||
else:
|
||||
session_id = msg.from_.username
|
||||
|
||||
amsg.session_id = session_id
|
||||
return amsg
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = VChatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
logger.info(f"处理消息: {message_event}")
|
||||
|
||||
self.commit_event(message_event)
|
||||
@@ -1,8 +1,9 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
@@ -37,5 +38,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
else:
|
||||
logger.error(f"webchat 暂不支持发送消息类型: {comp.type}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -52,7 +52,7 @@ class ProviderRequest():
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
'''角色, assistant, tool, err'''
|
||||
completion_text: str = ""
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
|
||||
@@ -114,22 +114,24 @@ class ProviderManager():
|
||||
try:
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
@@ -160,7 +162,7 @@ class ProviderManager():
|
||||
continue
|
||||
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
try:
|
||||
# 按任务实例化提供商
|
||||
|
||||
|
||||
@@ -31,7 +31,9 @@ class ProviderDify(Provider):
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.conversation_ids = {}
|
||||
|
||||
|
||||
@@ -78,7 +80,8 @@ class ProviderDify(Provider):
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
@@ -96,7 +99,8 @@ class ProviderDify(Provider):
|
||||
**session_var
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload
|
||||
files=files_payload,
|
||||
timeout=self.timeout
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
from httpx import AsyncClient
|
||||
from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
audio: bytes
|
||||
text: str
|
||||
|
||||
|
||||
class ServeTTSRequest(BaseModel):
|
||||
text: str
|
||||
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
||||
# 音频格式
|
||||
format: Literal["wav", "pcm", "mp3"] = "mp3"
|
||||
mp3_bitrate: Literal[64, 128, 192] = 128
|
||||
# 参考音频
|
||||
references: list[ServeReferenceAudio] = []
|
||||
# 参考模型 ID
|
||||
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
||||
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||
reference_id: str | None = None
|
||||
# 对中英文文本进行标准化,这可以提高数字的稳定性
|
||||
normalize: bool = True
|
||||
# 平衡模式将延迟减少到300毫秒,但可能会降低稳定性
|
||||
latency: Literal["normal", "balanced"] = "normal"
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"fishaudio_tts_api", "FishAudio TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.fish-audio.cn/v1"
|
||||
)
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
||||
"""
|
||||
获取角色的reference_id
|
||||
|
||||
Args:
|
||||
character: 角色名称
|
||||
|
||||
Returns:
|
||||
reference_id: 角色的reference_id
|
||||
|
||||
exception:
|
||||
APIException: 获取语音角色列表为空
|
||||
"""
|
||||
sort_options = ["score", "task_count", "created_at"]
|
||||
async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client:
|
||||
for sort_by in sort_options:
|
||||
params = {"title": character, "sort_by": sort_by}
|
||||
response = await client.get(
|
||||
"/model", params=params, headers=self.headers
|
||||
)
|
||||
resp_data = response.json()
|
||||
if resp_data["total"] == 0:
|
||||
continue
|
||||
for item in resp_data["items"]:
|
||||
if character in item["title"]:
|
||||
return item["_id"]
|
||||
return None
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
return ServeTTSRequest(
|
||||
text=text,
|
||||
format="wav",
|
||||
reference_id=await self._get_reference_id_by_character(self.character),
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
) as response:
|
||||
if response.headers["content-type"] == "audio/wav":
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
text = await response.aread()
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
@@ -48,8 +48,18 @@ class SimpleGoogleGenAIClient():
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
return response
|
||||
if "application/json" in resp.headers.get("Content-Type"):
|
||||
try:
|
||||
response = await resp.json()
|
||||
except Exception as e:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise e
|
||||
return response
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
|
||||
|
||||
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -100,18 +100,23 @@ class ProviderOpenAIOfficial(Provider):
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
|
||||
if choice.finish_reason == 'content_filter':
|
||||
raise Exception("API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。")
|
||||
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str=None,
|
||||
image_urls: List[str]=None,
|
||||
image_urls: List[str]=[],
|
||||
func_tool: FuncCall=None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
@@ -173,7 +178,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
or 'Function call is not supported' in str(e) \
|
||||
or 'Function calling is not enabled' in str(e) \
|
||||
or 'Tool calling is not supported' in str(e) \
|
||||
or 'No endpoints found that support tool use' in str(e): # siliconcloud
|
||||
or 'No endpoints found that support tool use' in str(e) \
|
||||
or 'model does not support function calling' in str(e) \
|
||||
or ('tool' in str(e) and 'support' in str(e).lower()) \
|
||||
or ('function' in str(e) and 'support' in str(e).lower()):
|
||||
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
|
||||
if 'tools' in payloads:
|
||||
del payloads['tools']
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, TypedDict, Union
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -127,6 +127,14 @@ class Context:
|
||||
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_all_tts_providers(self) -> List[TTSProvider]:
|
||||
'''获取所有用于 TTS 任务的 Provider。'''
|
||||
return self.provider_manager.tts_provider_insts
|
||||
|
||||
def get_all_stt_providers(self) -> List[STTProvider]:
|
||||
'''获取所有用于 STT 任务的 Provider。'''
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
@@ -135,6 +143,18 @@ class Context:
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_using_tts_provider(self) -> TTSProvider:
|
||||
'''
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
'''
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
|
||||
def get_using_stt_provider(self) -> STTProvider:
|
||||
'''
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
'''
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''获取 AstrBot 的配置。'''
|
||||
return self._config
|
||||
|
||||
@@ -43,7 +43,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
return self.handler_md
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_wake_up():
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
|
||||
if event.get_extra("parsing_command"):
|
||||
|
||||
@@ -37,7 +37,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
return result
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
|
||||
if not event.is_wake_up():
|
||||
if not event.is_at_or_wake_command:
|
||||
return False, None
|
||||
|
||||
if event.get_extra("parsing_command"):
|
||||
|
||||
@@ -68,12 +68,14 @@ def register_command(command_name: str = None, *args, **kwargs):
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
if not add_to_event_filters:
|
||||
kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
handler_md.event_filters.append(new_command)
|
||||
|
||||
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -116,7 +118,7 @@ class RegisteringCommandable():
|
||||
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
|
||||
'''注册一个 EventMessageType'''
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, kwargs)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
|
||||
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
|
||||
return awaitable
|
||||
|
||||
|
||||
@@ -340,8 +340,8 @@ class PluginManager:
|
||||
self.failed_plugin_info = fail_rec
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
async def install_plugin(self, repo_url: str, proxy=""):
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
await self.reload()
|
||||
return plugin_path
|
||||
@@ -376,14 +376,14 @@ class PluginManager:
|
||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
async def update_plugin(self, plugin_name: str, proxy = ""):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin)
|
||||
await self.updator.update(plugin, proxy=proxy)
|
||||
await self.reload()
|
||||
|
||||
async def turn_off_plugin(self, plugin_name: str):
|
||||
@@ -428,6 +428,7 @@ class PluginManager:
|
||||
|
||||
async def install_plugin_from_file(self, zip_file_path: str):
|
||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
|
||||
desti_dir = os.path.join(self.plugin_store_path, dir_name)
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
|
||||
@@ -15,20 +15,24 @@ class PluginUpdator(RepoZipUpdator):
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
async def install(self, repo_url: str) -> str:
|
||||
async def install(self, repo_url: str, proxy="") -> str:
|
||||
repo_name = self.format_repo_name(repo_url)
|
||||
plugin_path = os.path.join(self.plugin_store_path, repo_name)
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
await self.download_from_repo_url(plugin_path, repo_url, proxy)
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
|
||||
return plugin_path
|
||||
|
||||
async def update(self, plugin: StarMetadata) -> str:
|
||||
async def update(self, plugin: StarMetadata, proxy="") -> str:
|
||||
repo_url = plugin.repo
|
||||
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
repo_url = f"{proxy}/{repo_url}"
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
|
||||
@@ -22,21 +22,27 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
'''返回 duration'''
|
||||
import pysilk
|
||||
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
wav_data = wav.readframes(wav.getnframes())
|
||||
wav_data = BytesIO(wav_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.encode(wav_data, output_io, 24000, 24000)
|
||||
output_io.seek(0)
|
||||
try:
|
||||
import pilk
|
||||
except (ImportError, ModuleNotFoundError) as _:
|
||||
raise Exception("pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库")
|
||||
# with wave.open(wav_path, 'rb') as wav:
|
||||
# wav_data = wav.readframes(wav.getnframes())
|
||||
# wav_data = BytesIO(wav_data)
|
||||
# output_io = BytesIO()
|
||||
# pysilk.encode(wav_data, output_io, 24000, 24000)
|
||||
# output_io.seek(0)
|
||||
|
||||
# 在首字节添加 \x02,去除结尾的\xff\xff
|
||||
silk_data = output_io.read()
|
||||
silk_data_with_prefix = b'\x02' + silk_data[:-2]
|
||||
# # 在首字节添加 \x02,去除结尾的\xff\xff
|
||||
# silk_data = output_io.read()
|
||||
# silk_data_with_prefix = b'\x02' + silk_data[:-2]
|
||||
|
||||
# return BytesIO(silk_data_with_prefix)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(silk_data_with_prefix)
|
||||
# # return BytesIO(silk_data_with_prefix)
|
||||
# with open(output_path, "wb") as f:
|
||||
# f.write(silk_data_with_prefix)
|
||||
|
||||
return 0
|
||||
# return 0
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
rate = wav.getframerate()
|
||||
duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True)
|
||||
return duration
|
||||
@@ -100,7 +100,7 @@ class RepoZipUpdator():
|
||||
body=update_data[0]['body']
|
||||
)
|
||||
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str):
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""):
|
||||
repo_namespace = repo_url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
@@ -110,19 +110,23 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.info(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
logger.info(f"正在从默认分支下载 {author}/{repo} ")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
# 镜像站点
|
||||
match self.repo_mirror:
|
||||
case 'https://github-mirror.us.kg/':
|
||||
release_url = self.repo_mirror + release_url
|
||||
case "https://ghp.ci/":
|
||||
release_url = self.repo_mirror + release_url
|
||||
case _:
|
||||
pass
|
||||
# match self.repo_mirror:
|
||||
# case 'https://github-mirror.us.kg/':
|
||||
# release_url = self.repo_mirror + release_url
|
||||
# case "https://ghp.ci/":
|
||||
# release_url = self.repo_mirror + release_url
|
||||
# case _:
|
||||
# pass
|
||||
|
||||
if proxy:
|
||||
release_url = f"{proxy}/{release_url}"
|
||||
logger.info(f"使用代理下载: {release_url}")
|
||||
|
||||
await download_file(release_url, target_path + ".zip")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
import uuid
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
@@ -49,7 +48,7 @@ class PluginRoute(Route):
|
||||
return Response().error(message).__dict__
|
||||
return Response().ok(None, "重载成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/reload: {traceback.format_exc()}")
|
||||
logger.error(f"/api/plugin/reload: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_online_plugins(self):
|
||||
@@ -59,7 +58,6 @@ class PluginRoute(Route):
|
||||
urls = [custom]
|
||||
else:
|
||||
urls = [
|
||||
"https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json",
|
||||
"https://api.soulter.top/astrbot/plugins"
|
||||
]
|
||||
|
||||
@@ -88,6 +86,7 @@ class PluginRoute(Route):
|
||||
"version": plugin.version,
|
||||
"reserved": plugin.reserved,
|
||||
"activated": plugin.activated,
|
||||
"online_vesion": "",
|
||||
"handlers": await self.get_plugin_handlers_info(plugin.star_handler_full_names),
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
@@ -143,9 +142,14 @@ class PluginRoute(Route):
|
||||
async def install_plugin(self):
|
||||
post_data = await request.json
|
||||
repo_url = post_data["url"]
|
||||
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url)
|
||||
await self.plugin_manager.install_plugin(repo_url, proxy)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
@@ -183,14 +187,15 @@ class PluginRoute(Route):
|
||||
async def update_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name)
|
||||
await self.plugin_manager.update_plugin(plugin_name, proxy)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"更新插件 {plugin_name} 成功。")
|
||||
return Response().ok(None, "更新成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
logger.error(f"/api/plugin/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
@@ -201,7 +206,7 @@ class PluginRoute(Route):
|
||||
logger.info(f"停用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "停用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/off: {traceback.format_exc()}")
|
||||
logger.error(f"/api/plugin/off: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def on_plugin(self):
|
||||
@@ -212,5 +217,5 @@ class PluginRoute(Route):
|
||||
logger.info(f"启用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "启用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/on: {traceback.format_exc()}")
|
||||
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -0,0 +1,11 @@
|
||||
# What's Changed
|
||||
|
||||
0. ✨ 新增: 支持 海豚 AI(FishAudio) TTS API #433 by @Cvandia
|
||||
1. 🐛 修复: 当群聊主动回复时,不会带上人格的Prompt #419
|
||||
2. ✨ 新增: 支持展示插件是否有更新
|
||||
3. 👌 优化: 增加DIFY超时时间 #422
|
||||
4. 🐛 修复: 自部署文转图不生效 #352
|
||||
5. 🐛 修复: 修复 qq 回复别人的时候也会触发机器人, Onebot at 使用 string #330
|
||||
6. 👌 优化: 增加DIFY超时时间 #422
|
||||
7. 🐛 修复: 重启gewe的时候机器人会疯狂发消息 #421
|
||||
8. 🐛 修复: 修复子指令设置permission之后会导致其一定会被执行 #427
|
||||
@@ -0,0 +1,11 @@
|
||||
# What's Changed
|
||||
|
||||
0. ✨ 新增: 支持正则表达式匹配触发机器人,机器人在某一段时间内持续唤醒(不用输唤醒词)。(安装 astrbot_plugin_wake_enhance 插件)
|
||||
2. ✨ 新增: 可以通过 /tts 开关TTS,通过 /provider 更换 TTS #436
|
||||
3. ✨ 新增: 管理面板支持设置 GitHub 反向代理地址以优化中国大陆地区下载 AstrBot 插件的速度。(在管理面板-设置页)
|
||||
4. 🐛 修复: 修复指令不经过唤醒前缀也能生效的问题。在引用消息的时候无法使用前缀唤醒机器人 #444
|
||||
5. 🐛 修复: 修复 Napcat 下戳一戳消息报错
|
||||
6. 👌 优化: 从压缩包上传插件时,去除仓库 -branch 尾缀
|
||||
7. 🐛 修复: gemini 报错时显示 apikey
|
||||
8. 🐛 修复: drun 不支持函数调用的报错
|
||||
9. 🐛 修复: raw_completion 没有正确传递导致部分插件无法正常运作 #439
|
||||
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
1. ✨ 新增: 支持接入飞书(Lark)。支持飞书文字、图片。
|
||||
2. ✨ 新增: 添加月之暗面配置模板 #446
|
||||
3. ✨ 新增: Gewechat 支持文件输出
|
||||
4. 🐛 修复: 修复gewechat无法at人和发语音失败的问题 #447 #438
|
||||
5. 🐛 修复: 修复qq在@和回复开启的情况下转发消息异常的问题
|
||||
6. 🐛 修复: GitHub 加速镜像没有正确被应用
|
||||
7. 🐛 优化: 平台将显示不受支持的消息段
|
||||
@@ -2,7 +2,8 @@
|
||||
const props = defineProps({
|
||||
title: String,
|
||||
link: String,
|
||||
logo: String
|
||||
logo: String,
|
||||
has_update: Boolean,
|
||||
});
|
||||
|
||||
const open = (link: string | undefined) => {
|
||||
@@ -17,6 +18,7 @@ const open = (link: string | undefined) => {
|
||||
<img v-if="logo" :src="logo" alt="logo" style="width: 40px; height: 40px; margin-right: 8px;">
|
||||
<v-card-title style="font-size: 16px;">{{ props.title }}</v-card-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-icon color="success" v-if="has_update">mdi-arrow-up-bold</v-icon>
|
||||
<v-btn size="small" text="Read" variant="flat" border @click="open(props.link)">帮助</v-btn>
|
||||
</div>
|
||||
</v-card-item>
|
||||
|
||||
@@ -16,12 +16,12 @@ export interface menu {
|
||||
|
||||
const sidebarItem: menu[] = [
|
||||
{
|
||||
title: '面板',
|
||||
title: '统计',
|
||||
icon: 'mdi-view-dashboard',
|
||||
to: '/dashboard/default'
|
||||
},
|
||||
{
|
||||
title: '配置',
|
||||
title: '配置文件',
|
||||
icon: 'mdi-cog',
|
||||
to: '/config',
|
||||
},
|
||||
@@ -40,6 +40,11 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
{
|
||||
title: '设置',
|
||||
icon: 'mdi-wrench',
|
||||
to: '/settings'
|
||||
},
|
||||
{
|
||||
title: '关于',
|
||||
icon: 'mdi-information',
|
||||
|
||||
@@ -42,6 +42,11 @@ const MainRoutes = {
|
||||
path: '/chat',
|
||||
component: () => import('@/views/ChatPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Settings',
|
||||
path: '/settings',
|
||||
component: () => import('@/views/Settings.vue')
|
||||
},
|
||||
{
|
||||
name: 'About',
|
||||
path: '/about',
|
||||
|
||||
@@ -10,8 +10,8 @@ import { max } from 'date-fns';
|
||||
|
||||
<template>
|
||||
<v-row>
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README" title="💡提示"
|
||||
type="info" variant="tonal">
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,点击设置页选择 GitHub 加速地址。或前往仓库下载压缩包然后本地上传。" title="💡提示"
|
||||
type="info" color="primary" variant="tonal">
|
||||
</v-alert>
|
||||
<v-col cols="12" md="12">
|
||||
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
|
||||
@@ -44,13 +44,22 @@ import { max } from 'date-fns';
|
||||
</v-dialog>
|
||||
</div>
|
||||
</div>
|
||||
</v-col>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6" lg="3" v-for="extension in extension_data.data">
|
||||
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" :logo="extension?.logo"
|
||||
style="margin-bottom: 4px;">
|
||||
<div style="min-height: 135px; max-height: 135px; overflow: none;">
|
||||
<span style="font-weight: bold;">By @{{ extension.author }}</span>
|
||||
<span> | 插件有 {{ extension.handlers.length }} 个行为</span>
|
||||
:has_update="extension.has_update" style="margin-bottom: 4px;">
|
||||
<div style="min-height: 140px; max-height: 140px; overflow: auto;">
|
||||
<div>
|
||||
<span style="font-weight: bold ;">By @{{ extension.author }}</span>
|
||||
<span> | 插件有 {{ extension.handlers.length }} 个行为</span>
|
||||
</div>
|
||||
<span> 当前: <v-chip size="small" color="primary">{{ extension.version }}</v-chip>
|
||||
<span v-if="extension.online_version">
|
||||
| 最新: <v-chip size="small" color="primary">{{ extension.online_version }}</v-chip>
|
||||
</span>
|
||||
<span v-if="extension.has_update" style="font-weight: bold;">有更新
|
||||
</span>
|
||||
</span>
|
||||
<p style="margin-top: 8px;">{{ extension.desc }}</p>
|
||||
<a style="font-size: 12px; cursor: pointer; text-decoration: underline; color: #555;"
|
||||
@click="reloadPlugin(extension.name)">重载插件</a>
|
||||
@@ -329,6 +338,7 @@ export default {
|
||||
{ title: '作者', value: 'author' },
|
||||
{ title: '操作', value: 'actions', sortable: false }
|
||||
],
|
||||
alreadyCheckUpdate: false
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
@@ -367,10 +377,29 @@ export default {
|
||||
getExtensions() {
|
||||
axios.get('/api/plugin/get').then((res) => {
|
||||
this.extension_data = res.data;
|
||||
|
||||
this.checkAlreadyInstalled();
|
||||
this.checkUpdate()
|
||||
});
|
||||
},
|
||||
|
||||
checkUpdate() {
|
||||
// 遍历 extension_data 和 pluginMarketData,检查是否有更新\
|
||||
for (let i = 0; i < this.extension_data.data.length; i++) {
|
||||
for (let j = 0; j < this.pluginMarketData.length; j++) {
|
||||
console.log(this.extension_data.data[i].repo, this.pluginMarketData[j].repo);
|
||||
if (this.extension_data.data[i].repo === this.pluginMarketData[j].repo ||
|
||||
this.extension_data.data[i].name === this.pluginMarketData[j].name) {
|
||||
this.extension_data.data[i].online_version = this.pluginMarketData[j].version;
|
||||
if (this.extension_data.data[i].version !== this.pluginMarketData[j].version && this.pluginMarketData[j].version !== "未知") {
|
||||
this.extension_data.data[i].has_update = true;
|
||||
} else {
|
||||
this.extension_data.data[i].has_update = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
newExtension() {
|
||||
if (this.extension_url === "" && this.upload_file === null) {
|
||||
this.toast("请填写插件链接或上传插件文件", "error");
|
||||
@@ -411,7 +440,8 @@ export default {
|
||||
this.toast("正在从链接 " + this.extension_url + " 安装插件...", "primary");
|
||||
axios.post('/api/plugin/install',
|
||||
{
|
||||
url: this.extension_url
|
||||
url: this.extension_url,
|
||||
proxy: localStorage.getItem('selectedGitHubProxy') || ""
|
||||
}).then((res) => {
|
||||
this.loading_ = false;
|
||||
if (res.data.status === "error") {
|
||||
@@ -452,7 +482,8 @@ export default {
|
||||
this.loadingDialog.show = true;
|
||||
axios.post('/api/plugin/update',
|
||||
{
|
||||
name: extension_name
|
||||
name: extension_name,
|
||||
proxy: localStorage.getItem('selectedGitHubProxy') || ""
|
||||
}).then((res) => {
|
||||
if (res.data.status === "error") {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
@@ -529,11 +560,13 @@ export default {
|
||||
"desc": res.data.data[key].desc,
|
||||
"author": res.data.data[key].author,
|
||||
"repo": res.data.data[key].repo,
|
||||
"installed": false
|
||||
"installed": false,
|
||||
"version": res.data.data[key]?.version ? res.data.data[key].version : "未知",
|
||||
})
|
||||
}
|
||||
this.pluginMarketData = data;
|
||||
this.checkAlreadyInstalled();
|
||||
this.checkUpdate();
|
||||
}).catch((err) => {
|
||||
this.toast("获取插件市场数据失败: " + err, "error");
|
||||
});
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
<template>
|
||||
|
||||
<div style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px;">
|
||||
|
||||
<v-list lines="two">
|
||||
<v-list-subheader>网络</v-list-subheader>
|
||||
|
||||
<v-list-item subtitle="设置下载插件时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效" title="GitHub 加速地址">
|
||||
|
||||
<v-combobox variant="outlined" style="width: 100%; margin-top: 16px;" v-model="selectedGitHubProxy" :items="githubProxies"
|
||||
label="选择 GitHub 加速地址">
|
||||
</v-combobox>
|
||||
</v-list-item>
|
||||
|
||||
|
||||
</v-list>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
</template>
|
||||
|
||||
<script>
|
||||
export default {
|
||||
data() {
|
||||
return {
|
||||
githubProxies: [
|
||||
"https://ghproxy.cn",
|
||||
"https://gh.llkk.cc",
|
||||
"https://ghproxy.net",
|
||||
"https://gitproxy.click",
|
||||
"https://github.tbedu.top"
|
||||
],
|
||||
selectedGitHubProxy: "",
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
|
||||
},
|
||||
mounted() {
|
||||
this.selectedGitHubProxy = localStorage.getItem('selectedGitHubProxy') || "";
|
||||
},
|
||||
watch: {
|
||||
selectedGitHubProxy: function (newVal, oldVal) {
|
||||
if (!newVal) {
|
||||
newVal = ""
|
||||
}
|
||||
localStorage.setItem('selectedGitHubProxy', newVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
+94
-24
@@ -6,6 +6,7 @@ import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
@@ -59,6 +60,7 @@ AstrBot 指令:
|
||||
[System]
|
||||
/plugin: 查看插件、插件帮助
|
||||
/t2i: 开关文本转图片
|
||||
/tts: 开关文本转语音
|
||||
/sid: 获取会话 ID
|
||||
/op <admin_id>: 授权管理员(op)
|
||||
/deop <admin_id>: 取消管理员(op)
|
||||
@@ -72,8 +74,8 @@ AstrBot 指令:
|
||||
/model: 模型列表
|
||||
/ls: 对话列表
|
||||
/new: 创建新对话
|
||||
/switch: 切换对话
|
||||
/rename: 重命名对话
|
||||
/switch 序号: 切换对话
|
||||
/rename 新名字: 重命名当前对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话(op)
|
||||
/history: 当前对话的对话记录
|
||||
@@ -83,10 +85,7 @@ AstrBot 指令:
|
||||
/websearch: 网页搜索
|
||||
|
||||
[其他]
|
||||
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除会话的变量。
|
||||
|
||||
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
/set 变量名 值: 为会话定义变量(Dify 工作流输入)
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
@@ -125,7 +124,7 @@ AstrBot 指令:
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
for tool in tm.func_list:
|
||||
self.context.deactivate_llm_tool(tool.name)
|
||||
event.set_result(MessageEventResult().message(f"停用所有工具成功。"))
|
||||
event.set_result(MessageEventResult().message("停用所有工具成功。"))
|
||||
|
||||
@filter.command("plugin")
|
||||
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
|
||||
@@ -200,6 +199,18 @@ AstrBot 指令:
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
|
||||
@filter.command("tts")
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
config = self.context.get_config()
|
||||
if config['provider_tts_settings']['enable']:
|
||||
config['provider_tts_settings']['enable'] = False
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已关闭文本转语音。"))
|
||||
return
|
||||
config['provider_tts_settings']['enable'] = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转语音。"))
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
sid = event.unified_msg_origin
|
||||
@@ -245,34 +256,89 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
@filter.command("provider")
|
||||
async def provider(self, event: AstrMessageEvent, idx: int = None):
|
||||
async def provider(self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None):
|
||||
'''查看或者切换 LLM Provider'''
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
if idx is None:
|
||||
ret = "## 当前载入的 LLM 提供商\n"
|
||||
if idx is None:
|
||||
ret = "## 载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
id_ = llm.meta().id
|
||||
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
if self.context.get_using_provider().meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
tts_providers = self.context.get_all_tts_providers()
|
||||
if tts_providers:
|
||||
ret += "\n## 载入的 TTS 提供商\n"
|
||||
for idx, tts in enumerate(tts_providers):
|
||||
id_ = tts.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
tts_using = self.context.get_using_tts_provider()
|
||||
if tts_using and tts_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
stt_providers = self.context.get_all_stt_providers()
|
||||
if stt_providers:
|
||||
ret += "\n## 载入的 STT 提供商\n"
|
||||
for idx, stt in enumerate(stt_providers):
|
||||
id_ = stt.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
stt_using = self.context.get_using_stt_provider()
|
||||
if stt_using and stt_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 /provider <序号> 切换提供商。"
|
||||
ret += "\n使用 /provider <序号> 切换 LLM 提供商。"
|
||||
|
||||
if tts_providers:
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
if stt_providers:
|
||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
else:
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
if idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_tts_provider_inst = provider
|
||||
sp.put("curr_provider_tts", id_)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_stt_provider_inst = provider
|
||||
sp.put("curr_provider_stt", id_)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_provider_inst = provider
|
||||
sp.put("curr_provider", id_)
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_provider_inst = provider
|
||||
sp.put("curr_provider", id_)
|
||||
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("reset")
|
||||
@@ -429,8 +495,11 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"))
|
||||
|
||||
@filter.command("switch")
|
||||
async def switch_conv(self, message: AstrMessageEvent, index: int):
|
||||
async def switch_conv(self, message: AstrMessageEvent, index: int = None):
|
||||
'''通过 /ls 前面的序号切换对话'''
|
||||
if index is None:
|
||||
message.set_result(MessageEventResult().message("请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话"))
|
||||
return
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看"))
|
||||
@@ -581,7 +650,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
sp.put("session_variables", session_vars)
|
||||
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。使用 /unset 移除。")
|
||||
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
@@ -591,7 +660,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
if key not in session_var:
|
||||
yield event.plain_result("没有那个变量名。")
|
||||
yield event.plain_result("没有那个变量名。格式 /unset 变量名。")
|
||||
else:
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
@@ -632,7 +701,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error("当前未处于对话状态,无法主动回复,请使用 /switch 切换或者 /new 创建。")
|
||||
logger.error("当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 切换或者 /new 创建一个会话。")
|
||||
return
|
||||
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
@@ -649,7 +718,8 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=history if history else []
|
||||
contexts=history if history else [],
|
||||
conversation=conv,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
@@ -793,4 +863,4 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
# if results:
|
||||
# req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
# for result in results:
|
||||
# req.system_prompt += f"- {result}\n"
|
||||
# req.system_prompt += f"- {result}\n"7
|
||||
+4
-1
@@ -16,4 +16,7 @@ pyjwt
|
||||
apscheduler
|
||||
docstring_parser
|
||||
aiodocker
|
||||
silk-python
|
||||
silk-python
|
||||
|
||||
lark-oapi
|
||||
ormsgpack
|
||||
Reference in New Issue
Block a user