Compare commits

...

34 Commits

Author SHA1 Message Date
Soulter 3c28001a74 v3.4.16 2025-02-01 19:31:59 +08:00
Soulter 76a6218be6 fix: 修复webui无法从本地上传插件的问题 2025-02-01 19:31:29 +08:00
Soulter 6c1de1bbd6 Update README.md 2025-02-01 16:19:01 +08:00
Soulter d7678081da perf: Provider 重复时不直接报错闪退 #265 2025-02-01 14:36:41 +08:00
Soulter 5e4ba563cb perf: 弱化更新报错 #267 2025-02-01 14:29:39 +08:00
Soulter 8afbe77b0a Update README.md 2025-02-01 12:11:58 +08:00
Soulter 2ef139b59a fix: 修复每次启动astrbot都需要微信扫码的问题 2025-01-31 01:28:49 +08:00
Soulter 1f0d2d9b89 fix: QQ官方机器人开启 reply with metion 和 reply with quote 后,无法正常回复消息 #244 2025-01-30 01:36:25 +08:00
Soulter 37a1f144ab chore: update changelog of 3.4.15 2025-01-30 00:32:50 +08:00
Soulter 9a7a654596 perf: 插件处于禁用状态时其所属的函数调用工具不可被启用 #254 2025-01-30 00:27:10 +08:00
Soulter 9abccd63cf chore: remove stt.py 2025-01-29 23:47:50 +08:00
Soulter 93fea77182 chore: bump to v3.4.15 2025-01-29 23:43:09 +08:00
Soulter 19797243f6 perf: 增加插件链接 2025-01-29 19:56:09 +08:00
Soulter c9c733d925 Merge branch 'dev' 2025-01-29 19:43:52 +08:00
Soulter a7d7678c78 fix: 修复白名单为空时依然终止事件 #259 2025-01-29 17:17:27 +08:00
Soulter c0911921c7 feat: 配置Schema以及插件支持配置 2025-01-29 16:54:57 +08:00
Soulter 4a4241d57a Update README.md 2025-01-29 13:26:51 +08:00
Soulter c9426bb6eb config 2025-01-29 12:25:54 +08:00
Soulter db4abd169a fix: 优化分段回复 2025-01-28 14:42:15 +08:00
Soulter 80b6958599 fix: 修复 config validator 不起效的问题 2025-01-28 14:18:21 +08:00
Soulter 80058c781a fix: 修复r1思考标签问题和分段回复间隔时间问题 2025-01-28 14:03:10 +08:00
Soulter 44bd2e36f3 Update README.md 2025-01-28 02:15:11 +08:00
Soulter 3589a5e5be perf: 强化ltm异常处理 2025-01-27 21:47:35 +08:00
Soulter 13ef033f0e fix: 群聊增强的参数类型转换 2025-01-27 21:40:20 +08:00
Soulter 3f8c68bbca fix: f-string expression part cannot include a backslash
long_term_memory.py, line 69
2025-01-27 21:01:50 +08:00
Soulter 4275cea82b chore: v3.4.14 2025-01-27 20:09:03 +08:00
Soulter a0bcb5339a perf: 自动删除 deepseek-r1 模型自带的 think 标签 2025-01-27 20:04:39 +08:00
Soulter 43deec4a4b Merge pull request #255 from Soulter/feat-ltm
支持记录非唤醒状态下群聊历史记录
2025-01-27 20:02:43 +08:00
Soulter 2bc433a30b feat: 支持记录非唤醒状态下群聊历史记录 2025-01-27 20:00:32 +08:00
Soulter eb2b395932 perf: /t2i 即时生效 2025-01-27 19:33:38 +08:00
Soulter 2bfd1c0bf2 perf: 自动移除 ollama 不支持 tool 的模型的 tool 请求 2025-01-27 19:25:28 +08:00
Soulter 7228c4b13f fix: 修复 TTS 部分变量名错误导致请求失败 2025-01-27 18:45:34 +08:00
Soulter 9351d7471f perf: 优化 gewechat 消息下发异常处理 2025-01-27 18:11:31 +08:00
Soulter 1cf49998bc Update README.md 2025-01-27 11:34:27 +08:00
33 changed files with 707 additions and 383 deletions
+34 -9
View File
@@ -1,6 +1,6 @@
<p align="center">
![](https://github.com/user-attachments/assets/04f22f63-4dcf-4a6c-981e-b33f45c23f3e)
![logo](https://github.com/user-attachments/assets/07649e07-3b8e-4feb-9aa9-bf13af4f3476)
</p>
@@ -14,9 +14,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
@@ -36,7 +35,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM因此无法在聊天页使用大模型。
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
## ✨ 使用方式
@@ -65,19 +64,31 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
| QQ 官方API | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字、图片、语音 |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| 微信(企业微信) | 🚧 | 计划内 | - |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
@@ -129,8 +140,21 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Sponsors
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
@@ -142,5 +166,6 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
4. TTS
-->
_アトリは、高性能ですから!_
_私は、高性能ですから!_
+2 -1
View File
@@ -2,4 +2,5 @@ from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
+1 -1
View File
@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
+53 -10
View File
@@ -2,7 +2,7 @@ import os
import json
import logging
import enum
from .default import DEFAULT_CONFIG
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
@@ -13,29 +13,72 @@ class RateLimitStrategy(enum.Enum):
DISCARD = "discard"
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项
def __init__(self):
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None
):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
'''不存在时载入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
has_new = self.check_config_integrity(default_config, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
@@ -61,7 +104,7 @@ class AstrBotConfig(dict):
'''
if replace_config:
self.update(replace_config)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
with open(self.config_path, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
@@ -81,4 +124,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
return os.path.exists(self.config_path)
+43 -7
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.13"
VERSION = "3.4.16"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -50,6 +50,12 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_prompt": "Please describe the image using Chinese.",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -230,8 +236,8 @@ CONFIG_METADATA_2 = {
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
"items": {"type": "int"},
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
"items": {"type": "string"},
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -259,6 +265,7 @@ CONFIG_METADATA_2 = {
"path_mapping": {
"description": "路径映射",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
}
@@ -583,14 +590,14 @@ CONFIG_METADATA_2 = {
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {},
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {},
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
"obvious_hint": True,
},
@@ -630,6 +637,34 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_ltm_settings": {
"description": "聊天记忆增强(Beta)",
"type": "object",
"items": {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious-hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious-hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious-hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
},
},
},
},
},
"misc_config_group": {
@@ -639,7 +674,8 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀",
"type": "list",
"items": {"type": "string"},
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
},
"t2i": {
"description": "文本转图像",
@@ -649,7 +685,7 @@ CONFIG_METADATA_2 = {
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "int"},
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
+1 -1
View File
@@ -306,7 +306,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: int
id: T.Union[str, int]
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
time: T.Optional[int] = 0
@@ -140,5 +140,10 @@ class MessageEventResult(MessageChain):
'''
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult
@@ -39,8 +39,11 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()
+5 -2
View File
@@ -15,6 +15,7 @@ class RespondStage(Stage):
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
@@ -22,6 +23,7 @@ class RespondStage(Stage):
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
@@ -31,11 +33,12 @@ class RespondStage(Stage):
if len(result.chain) > 0:
await event._pre_send()
if self.enable_seg:
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
@@ -19,7 +19,6 @@ class ResultDecorateStage:
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
self.t2i = ctx.astrbot_config['t2i']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
@@ -68,8 +67,8 @@ class ResultDecorateStage:
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + plain_str)
audio_path = await tts_provider.get_audio(plain_str)
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
@@ -85,7 +84,7 @@ class ResultDecorateStage:
result.chain = new_chain
# 文本转图片
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -105,7 +104,7 @@ class ResultDecorateStage:
result.chain = [Image.fromURL(url)]
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id()))
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))
@@ -18,6 +18,11 @@ class WhitelistCheckStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == 'webchat':
@@ -66,6 +66,7 @@ class SimpleGewechatClient():
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
abm = AstrBotMessage()
d = data['Data']
@@ -102,7 +103,7 @@ class SimpleGewechatClient():
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d['PushContent'].split(' : ')[0] \
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') # 真实昵称
abm.sender = MessageMember(user_id, user_real_name)
@@ -153,13 +154,17 @@ class SimpleGewechatClient():
if data.get('testMsg', None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = await self._convert(data)
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}")
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
@@ -289,7 +294,7 @@ class SimpleGewechatClient():
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{nickname}", appid)
sp.put(f"gewechat-appid-{self.nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
@@ -5,7 +5,7 @@ import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Reply
from botpy import Client
from botpy.http import Route
@@ -29,6 +29,19 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
ref = None
for i in self.send_buffer.chain:
if isinstance(i, Reply):
try:
ref = self.message_obj.raw_message.message_reference
ref = botpy.types.message.Reference(
message_id=ref.message_id,
ignore_get_message_error=False
)
except BaseException as _:
pass
break
payload = {
'content': plain_text,
'msg_id': self.message_obj.message_id,
@@ -36,22 +49,30 @@ class QQOfficialMessageEvent(AstrMessageEvent):
match type(source):
case botpy.message.GroupMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if ref:
payload['message_reference'] = ref
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
case botpy.message.DirectMessage:
if ref:
payload['message_reference'] = ref
if image_path:
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
+15 -3
View File
@@ -1,4 +1,5 @@
import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
@@ -13,6 +14,7 @@ class ProviderManager():
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
@@ -82,12 +84,16 @@ class ProviderManager():
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
self.loaded_ids[provider_cfg['id']] = True
try:
@@ -115,7 +121,13 @@ class ProviderManager():
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
@@ -123,7 +135,7 @@ class ProviderManager():
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
for provider_config in self.providers_config:
if not provider_config['enable']:
+24 -5
View File
@@ -1,5 +1,6 @@
import base64
import json
import re
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
@@ -103,10 +104,20 @@ class ProviderOpenAIOfficial(Provider):
if tool_list:
payloads['tools'] = tool_list
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
try:
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
except BaseException as e:
if 'does not support Function Calling' \
or 'does not support tools' in e: # ollama
del payloads['tools']
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
@@ -118,6 +129,13 @@ class ProviderOpenAIOfficial(Provider):
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
# 适配 deepseek-r1 模型
if r'<think>' in completion_text:
completion_text = re.sub(r'<think>.*?</think>', '', completion_text, flags=re.DOTALL).strip()
# 可能有单标签情况
completion_text = completion_text.replace(r'<think>', '').replace(r'</think>', '').strip()
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
# tools call (function calling)
@@ -163,7 +181,8 @@ class ProviderOpenAIOfficial(Provider):
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
+4
View File
@@ -1,3 +1,7 @@
'''
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
'''
from typing import Union
import os
import json
+104 -102
View File
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
@@ -54,46 +54,19 @@ class Context:
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
for star in star_registry:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
'''获取当前载入的所有插件 Metadata 的列表'''
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
'''
获取 LLM Tool Manager
'''
'''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools'''
return self.provider_manager.llm_tools
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
@@ -102,6 +75,11 @@ class Context:
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
@@ -129,6 +107,101 @@ class Context:
return True
return False
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''获取 AstrBot 的配置。'''
return self._config
def get_db(self) -> BaseDatabase:
'''获取 AstrBot 数据库。'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
@@ -162,77 +235,6 @@ class Context:
))
star_handlers_registry.append(md)
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider(Chat_Completion 类型)。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
def register_task(self, task: Awaitable, desc: str):
'''
注册一个异步任务。
+11 -7
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
@@ -11,7 +12,7 @@ star_map: Dict[str, StarMetadata] = {}
@dataclass
class StarMetadata:
'''
Star 的元数据。
插件的元数据。
'''
name: str
author: str # 插件作者
@@ -20,21 +21,24 @@ class StarMetadata:
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''Star 的类对象的类型'''
'''插件的类对象的类型'''
module_path: str = None
'''Star 的模块路径'''
'''插件的模块路径'''
star_cls: object = None
'''Star 的类对象'''
'''插件的类对象'''
module: ModuleType = None
'''Star 的模块对象'''
'''插件的模块对象'''
root_dir_name: str = None
'''Star 的根目录名'''
'''插件的目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
'''是否是 AstrBot 的保留插件'''
activated: bool = True
'''是否被激活'''
config: AstrBotConfig = None
'''插件配置'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+54 -19
View File
@@ -2,12 +2,14 @@ import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
@@ -26,13 +28,20 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self # 就这样吧,不想改了
self.context._star_manager = self
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
'''存储插件的路径。即 data/plugins'''
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
'''存储插件配置的路径。data/config'''
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
@@ -128,7 +137,7 @@ class PluginManager:
return metadata
async def reload(self):
'''扫描并加载所有的 Star'''
'''扫描并加载所有的插件'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
@@ -150,13 +159,13 @@ class PluginManager:
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入 Star 模块,并尝试实例化 Star
# 导入插件模块,并尝试实例化插件
for plugin_module in plugin_modules:
try:
module_str = plugin_module['module']
# module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
logger.info(f"正在载入插件 {root_dir_name} ...")
@@ -173,11 +182,33 @@ class PluginManager:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
schema=json.loads(f.read())
)
if path in star_map:
# 通过装饰器的方式注册插件
metadata = star_map[path]
metadata.star_cls = metadata.star_cls_type(context=self.context)
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(context=self.context)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -199,16 +230,20 @@ class PluginManager:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
if plugin_config:
try:
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
else:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -221,7 +256,7 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
# 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
@@ -292,13 +327,14 @@ class PluginManager:
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
inactivated_llm_tools.append(func_tool.name)
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
@@ -323,8 +359,9 @@ class PluginManager:
plugin.activated = True
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
desti_dir = os.path.join(self.plugin_store_path, dir_name)
self.updator.unzip_file(zip_file_path, desti_dir)
# remove the zip
@@ -332,6 +369,4 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
await self.reload()
-1
View File
@@ -39,7 +39,6 @@ class RepoZipUpdator():
else:
ret = self.github_api_release_parser(result)
except BaseException:
logger.error("解析版本信息失败")
raise Exception("解析版本信息失败")
return ret
+70 -53
View File
@@ -1,14 +1,12 @@
import os
import json
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -19,9 +17,9 @@ def try_cast(value: str, type_: str):
elif type_ == "float" and isinstance(value, int):
return float(value)
def validate_config(data, config: AstrBotConfig):
def validate_config(data, schema: dict, is_core: bool):
errors = []
def validate(data, metadata=CONFIG_METADATA_2, path=""):
def validate(data, metadata=schema, path=""):
for key, meta in metadata.items():
if key not in data:
continue
@@ -56,35 +54,33 @@ def validate_config(data, config: AstrBotConfig):
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
if is_core:
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
return errors
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
'''验证并保存配置'''
errors = validate_config(post_config, config)
errors = None
try:
if is_core:
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
else:
errors = validate_config(post_config, config.schema, is_core)
except BaseException as e:
logger.warning(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
def save_extension_config(post_config: dict):
if 'namespace' not in post_config:
raise ValueError("Missing key: namespace")
if 'config' not in post_config:
raise ValueError("Missing key: config")
namespace = post_config['namespace']
config: list = post_config['config'][0]['body']
for item in config:
key = item['path']
value = item['value']
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
class ConfigRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
@@ -92,17 +88,17 @@ class ConfigRoute(Route):
self.routes = {
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_extension_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
}
self.register_routes()
async def get_configs(self):
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
if not namespace:
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
plugin_name = request.args.get("plugin_name", None)
if not plugin_name:
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_extension_config(namespace)).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
@@ -110,14 +106,15 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
traceback.print_exc()
logger.error(e)
return Response().error(str(e)).__dict__
async def post_extension_configs(self):
async def post_plugin_configs(self):
post_configs = await request.json
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_extension_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
await self._save_plugin_configs(post_configs, plugin_name)
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
@@ -141,28 +138,48 @@ class ConfigRoute(Route):
"config": config
}
async def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
async def _get_plugin_config(self, plugin_name: str):
ret = {
"metadata": None,
"config": None
}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
if not plugin_md.config:
break
ret['config'] = plugin_md.config # 这是自定义的 Dict 类(AstrBotConfig
ret['metadata'] = {
plugin_name: {
"description": f"{plugin_name} 配置",
"type": "object",
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
}
}
break
return ret
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_astrbot_config(post_configs, self.config)
save_config(post_configs, self.config, is_core=True)
self.core_lifecycle.restart()
except Exception as e:
raise e
async def _save_extension_configs(self, post_configs: dict):
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
md = None
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
md = plugin_md
if not md:
raise ValueError(f"插件 {plugin_name} 不存在")
if not md.config:
raise ValueError(f"插件 {plugin_name} 没有注册配置")
try:
save_extension_config(post_configs)
save_config(post_configs, md.config)
self.core_lifecycle.restart()
except Exception as e:
raise e
+2 -2
View File
@@ -67,9 +67,9 @@ class PluginRoute(Route):
file = await request.files
file = file['file']
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{uuid.uuid4()}.zip"
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
+2 -1
View File
@@ -1,4 +1,5 @@
import traceback
import aiohttp
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -43,7 +44,7 @@ class UpdateRoute(Route):
}
).__dict__
except Exception as e:
logger.error(traceback.format_exc())
logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)")
return Response().error(e.__str__()).__dict__
async def update_project(self):
+8
View File
@@ -0,0 +1,8 @@
# What's Changed
- 修复: TTS 问题
- 新增: **支持记录非唤醒状态下群聊历史记录(beta)**
- 优化: 自动删除 deepseek-r1 模型自带的 think 标签
- 优化: 自动移除 ollama 不支持 tool 的模型的 tool 请求
- 优化: /t2i 即时生效
- 优化: gewechat 消息下发异常处理
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 修复: 配置 Validator 不起效的问题
- 修复: DeepSeek-R1 思考标签问题
- 修复: 分段回复间隔时间不生效
- 修复: 修复白名单为空时依然终止事件 #259
- 修复: 群聊增强某些参数的类型转换问题
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
- [gewechat] [修复每次启动astrbot都需要扫码的问题](https://github.com/Soulter/AstrBot/commit/fd5d7dd37a6d74f81a148bbebef8516aa0cb5540)
- [core] [Provider 重复时不直接报错闪退](https://github.com/Soulter/AstrBot/commit/b61f9be18db9a6b8b3c5b6b36553f66dd2b79375) https://github.com/Soulter/AstrBot/issues/265
- [core] [弱化更新报错](https://github.com/Soulter/AstrBot/commit/0ba0150fd8ff2062dbe83889163888ba3e33bd49) https://github.com/Soulter/AstrBot/issues/267
- 修复 webui 无法从本地上传插件的问题
@@ -1,41 +0,0 @@
<script setup>
import UiParentCard from '@/components/shared/UiParentCard.vue';
const props = defineProps({
config: Array
});
</script>
<template>
<a v-show="config.length === 0">该插件没有配置</a>
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
<template v-for="item in group.body">
<template v-if="item.config_type === 'item'">
<template v-if="item.val_type === 'bool'">
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
</template>
<template v-else-if="item.val_type === 'str'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'int'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'list'">
<span>{{ item.name }}</span>
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
</template>
</template>
<template v-else-if="item.config_type === 'divider'">
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
</template>
</template>
</UiParentCard>
</template>
+18 -11
View File
@@ -1,7 +1,7 @@
<script setup>
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
@@ -52,11 +52,17 @@ import axios from 'axios';
<v-btn v-else variant="plain" disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
</v-row>
<v-dialog v-model="configDialog" width="750">
<v-dialog v-model="configDialog" width="1000">
<template v-slot:activator="{ props }">
</template>
<v-card>
@@ -65,7 +71,8 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
</v-container>
</v-card-text>
<v-card-actions>
@@ -172,9 +179,9 @@ export default {
name: 'ExtensionPage',
components: {
ExtensionCard,
ConfigDetailCard,
WaitingForRestart,
ConsoleDisplayer
ConsoleDisplayer,
AstrBotConfig
},
data() {
return {
@@ -189,7 +196,10 @@ export default {
snack_success: "success",
loading_: false,
configDialog: false,
extension_config: {},
extension_config: {
"metadata": {},
"config": {}
},
upload_file: null,
pluginMarketData: {},
loadingDialog: {
@@ -364,7 +374,7 @@ export default {
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
this.extension_config = res.data.data;
console.log(this.extension_config);
}).catch((err) => {
@@ -372,10 +382,7 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update', {
"config": this.extension_config,
"namespace": this.curr_namespace
}).then((res) => {
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();
+88
View File
@@ -0,0 +1,88 @@
import datetime
import uuid
import astrbot.api.star as star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.platform import MessageType
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Plain, Image
from astrbot import logger
from collections import defaultdict
class LongTermMemory:
def __init__(self, config: dict, context: star.Context):
self.config = config
self.context = context
self.session_chats = defaultdict(list)
"""记录群成员的群聊记录"""
try:
self.max_cnt = int(self.config["group_message_max_cnt"])
except BaseException as e:
logger.error(e)
self.max_cnt = 300
self.image_caption = self.config["image_caption"]
self.image_caption_prompt = self.config["image_caption_prompt"]
async def remove_session(self, event: AstrMessageEvent) -> int:
cnt = 0
if event.unified_msg_origin in self.session_chats:
cnt = len(self.session_chats[event.unified_msg_origin])
del self.session_chats[event.unified_msg_origin]
return cnt
async def get_image_caption(self, image_url: str) -> str:
provider = self.context.get_using_provider()
response = await provider.text_chat(
prompt=self.image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text
async def handle_message(self, event: AstrMessageEvent):
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
for comp in event.get_messages():
if isinstance(comp, Plain):
final_message += f" {comp.text}"
elif isinstance(comp, Image):
# image_urls.append(comp.url if comp.url else comp.file)
if self.image_caption:
try:
caption = await self.get_image_caption(
comp.url if comp.url else comp.file
)
final_message += f" [Image: {caption}]"
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
if event.unified_msg_origin not in self.session_chats:
return
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
req.system_prompt += chats_str
if self.image_caption:
req.system_prompt += (
"The images sent by the members are displayed in text form above."
)
async def after_req_llm(self, event: AstrMessageEvent):
if event.unified_msg_origin not in self.session_chats:
return
if event.get_result() and event.get_result().is_llm_result():
final_message = f"[AstrBot/{datetime.datetime.now().strftime('%H:%M:%S')}]: {event.get_result().get_plain_text()}"
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
self.session_chats[event.unified_msg_origin].pop(0)
+99 -58
View File
@@ -6,12 +6,16 @@ import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import sp
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.api.platform import MessageType
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
from collections import defaultdict
from .long_term_memory import LongTermMemory
from astrbot.core import logger
from typing import Union
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0")
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
@@ -20,7 +24,12 @@ class Main(star.Star):
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.kdb_enabled = False
self.ltm = None
if self.context.get_config()['provider_ltm_settings']['group_icl_enable']:
try:
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
except BaseException as e:
logger.error(f"聊天增强 err: {e}")
async def _query_astrbot_notice(self):
try:
@@ -219,7 +228,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
await self.context.get_using_provider().forget(message.session_id)
message.set_result(MessageEventResult().message("重置成功"))
ret = "清除会话 LLM 聊天历史成功"
if self.ltm:
cnt = await self.ltm.remove_session(event=message)
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
message.set_result(MessageEventResult().message(ret))
@filter.command("model")
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
@@ -355,9 +369,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
self.context.provider_manager.personas
), None):
self.context.get_using_provider().curr_personality = persona
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
else:
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
@@ -366,31 +380,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await download_dashboard()
yield event.plain_result("管理面板更新完成。")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
# if provider.curr_personality['prompt']:
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
@@ -428,32 +417,84 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await platform.logout()
yield event.plain_result("已登出 gewechat")
return
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''长期记忆'''
if self.ltm:
try:
await self.ltm.handle_message(event)
except BaseException as e:
logger.error(e)
@filter.on_llm_request()
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if self.ltm:
try:
await self.ltm.on_req_llm(event, req)
except BaseException as e:
logger.error(f"ltm: {e}")
@filter.after_message_sent()
async def after_llm_req(self, event: AstrMessageEvent):
'''在 LLM 请求后记录对话'''
if self.ltm:
try:
await self.ltm.after_req_llm(event)
except BaseException as e:
logger.error(f"ltm: {e}")
# @filter.command_group("kdb")
# def kdb(self):
# pass
# @kdb.command("on")
# async def on_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = True
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if not curr_kdb_name:
# yield event.plain_result("未载入任何知识库")
# else:
# yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
# @kdb.command("off")
# async def off_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = False
# yield event.plain_result("知识库已关闭")
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
# curr_kdb_name = self.context.provider_manager.curr_kdb_name
# if self.kdb_enabled and curr_kdb_name:
# mgr = self.context.knowledge_db_manager
# results = await mgr.retrive_records(curr_kdb_name, req.prompt)
# if results:
# req.system_prompt += "\nHere are documents that related to user's query: \n"
# for result in results:
# req.system_prompt += f"- {result}\n"
-1
View File
@@ -1,5 +1,4 @@
pydantic~=2.10.3
vchat
aiohttp
openai
qq-botpy
-33
View File
@@ -1,33 +0,0 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
+1
View File
@@ -5,6 +5,7 @@ import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import CONFIG_METADATA_2
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType