Merge remote-tracking branch 'origin/master' into anka-dev

This commit is contained in:
anka
2025-03-26 13:53:21 +08:00
30 changed files with 1882 additions and 133 deletions
+2
View File
@@ -1,6 +1,8 @@
__pycache__
botpy.log
.vscode
.venv*
.idea
data_v2.db
data_v3.db
configs/session
+1 -1
View File
@@ -7,7 +7,7 @@ ci:
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.0
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format
+6 -24
View File
@@ -16,7 +16,6 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=60&style=for-the-badge&color=3b618e)
[![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
@@ -28,6 +27,9 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg?style=for-the-badge)](https://gitcode.com/Soulter/AstrBot)
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
## ✨ 主要功能
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
@@ -50,7 +52,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### Windows 一键安装器部署
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
@@ -62,22 +64,13 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🚀 路线图
#### Replit 部署
### 垂类功能
1. 更好的上下文管理:限制 token 总数、对话上下文总结
3. AstrBot in Minecraft
### 横功能
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况
@@ -190,16 +183,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_
+9 -3
View File
@@ -49,6 +49,7 @@ DEFAULT_CONFIG = {
"datetime_system_prompt": True,
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
},
"provider_stt_settings": {
"enable": False,
@@ -346,7 +347,7 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "只处理填写的 ID 发来的消息事件为空时不启用白名单过滤。可使用 /sid 指令获取在某个平台上的会话 ID。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
"hint": "只处理填写的 ID 发来的消息事件为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -909,6 +910,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
},
"max_context_length": {
"description": "最多携带对话数量(条)",
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
},
},
"persona": {
@@ -1002,10 +1008,10 @@ CONFIG_METADATA_2 = {
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需模型支持)",
"description": "群聊图像转述(需模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型",
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入",
},
"image_caption_provider_id": {
"description": "图像转述提供商 ID",
+43 -1
View File
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import List
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码,从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
+192 -18
View File
@@ -3,7 +3,7 @@ import os
import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from . import BaseDatabase
from typing import Tuple
from typing import Tuple, List, Dict, Any
class SQLiteDatabase(BaseDatabase):
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
conditions = []
params = []
if session_id:
conditions.append("session_id = ?")
params.append(session_id)
if provider_type:
conditions.append("provider_type = ?")
params.append(provider_type)
sql = "SELECT * FROM llm_history"
if conditions:
sql += " WHERE " + " AND ".join(conditions)
c.execute(sql, params)
c.execute(
"""
SELECT * FROM llm_history
"""
+ where_clause
)
res = c.fetchall()
histories = []
for row in res:
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
if res:
return ATRIVision(*res)
return None
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
+5 -3
View File
@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
);
PRAGMA encoding = 'UTF-8';
+15
View File
@@ -61,6 +61,8 @@ class ComponentType(Enum):
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel):
type: ComponentType
@@ -412,6 +414,8 @@ class Reply(BaseMessageComponent):
"""引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""解析后的纯文本消息字符串"""
sender_str: T.Optional[str] = ""
"""被引用的消息纯文本"""
text: T.Optional[str] = ""
"""deprecated"""
@@ -559,6 +563,16 @@ class File(BaseMessageComponent):
super().__init__(name=name, file=file)
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
ComponentTypes = {
"plain": Plain,
"text": Plain,
@@ -587,4 +601,5 @@ ComponentTypes = {
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
}
@@ -34,6 +34,9 @@ class LLMRequestSubStage(Stage):
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
"wake_prefix"
] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -123,6 +126,14 @@ class LLMRequestSubStage(Stage):
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[-self.max_context_length * 2 :]
try:
need_loop = True
while need_loop:
@@ -10,11 +10,18 @@ import anyio
import quart
from astrbot.api import logger, sp
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
"""针对 Gewechat 的简单实现。
@@ -217,15 +224,10 @@ class SimpleGewechatClient:
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
@@ -236,15 +238,19 @@ class SimpleGewechatClient:
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
logger.info("消息类型(43):视频")
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
logger.info("消息类型(47)emoji")
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
logger.info(
"消息类型(49):公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请"
)
data_parser = GeweDataParser(content, abm.group_id == "")
abm_data = data_parser.parse_mutil_49()
if abm_data:
abm.message.append(abm_data)
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
@@ -508,6 +514,34 @@ class SimpleGewechatClient:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包,若拿不到表情包的md5,就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
@@ -525,6 +559,27 @@ class SimpleGewechatClient:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
@@ -546,7 +601,7 @@ class SimpleGewechatClient:
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
@@ -39,3 +39,17 @@ class GeweDownloader:
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口,暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")
@@ -8,7 +8,15 @@ from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, File, Video
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
@@ -84,55 +92,60 @@ class GewechatPlatformEvent(AstrMessageEvent):
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
# 获取视频第一帧
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
try:
ff = FFmpeg()
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
ff.options(command)
thumb_file_id = os.path.basename(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
file_id = os.path.basename(video_path)
video_url = f"{client.file_server_url}/{file_id}"
await client.post_video(
to_wxid, video_url, thumb_url, video_duration
)
video_url = comp.file
# 根据 url 下载视频
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
# 获取视频第一帧
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
try:
ff = FFmpeg()
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
ff.options(command)
thumb_file_id = os.path.basename(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
file_id = os.path.basename(video_path)
video_url = f"{client.file_server_url}/{file_id}"
await client.post_video(to_wxid, video_url, thumb_url, video_duration)
# 删除临时视频和缩略图文件
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(thumb_path):
os.remove(thumb_path)
# 删除临时视频和缩略图文件
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
@@ -165,6 +178,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
file_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_id)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
@@ -0,0 +1,78 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self):
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> Reply | None:
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = ""
content = ""
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
r = Reply(
id=replied_id,
chain=[Plain(content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
sender_str=replied_content,
message_str=content,
)
return r
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")
@@ -100,7 +100,8 @@ class TelegramPlatformAdapter(Platform):
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
logger.debug(f"Telegram message: {update.message}")
abm = await self.convert_message(update, context)
await self.handle_msg(abm)
if abm:
await self.handle_msg(abm)
async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
@@ -178,7 +179,7 @@ class TelegramPlatformAdapter(Platform):
message.message.append(Comp.Plain(plain_text))
message.message_str = plain_text
if message.message_str == "/start":
if message.message_str.strip() == "/start":
await self.start(update, context)
return
@@ -240,5 +241,13 @@ class TelegramPlatformAdapter(Platform):
return self.client
async def terminate(self):
await self.application.stop()
logger.info("Telegram 适配器已被优雅地关闭")
try:
await self.application.stop()
# 保险起见先判断是否存在updater对象
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
+2
View File
@@ -7,6 +7,7 @@ from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
from .tools import ToolsRoute # 导入新的ToolsRoute
from .conversation import ConversationRoute
__all__ = [
@@ -19,4 +20,5 @@ __all__ = [
"StaticFileRoute",
"ChatRoute",
"ToolsRoute", # 添加新的ToolsRoute
"ConversationRoute",
]
+227
View File
@@ -0,0 +1,227 @@
import traceback
import json
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ConversationRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/conversation/list": ("GET", self.list_conversations),
"/conversation/detail": (
"POST",
self.get_conv_detail,
),
"/conversation/update": ("POST", self.upd_conv),
"/conversation/delete": ("POST", self.del_conv),
"/conversation/update_history": (
"POST",
self.update_history,
),
}
self.db_helper = db_helper
self.register_routes()
async def list_conversations(self):
"""获取对话列表,支持分页、排序和筛选"""
try:
# 获取分页参数
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 获取筛选参数
platforms = request.args.get("platforms", "")
message_types = request.args.get("message_types", "")
search_query = request.args.get("search", "")
exclude_ids = request.args.get("exclude_ids", "")
exclude_platforms = request.args.get("exclude_platforms", "")
# 转换为列表
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
logger.info(
f"获取对话列表: page={page}, page_size={page_size}, "
f"platforms={platform_list}, message_types={message_type_list}, "
f"search={search_query}, exclude_ids={exclude_id_list}, "
f"exclude_platforms={exclude_platform_list}"
)
# 限制页面大小,防止请求过大数据
if page < 1:
page = 1
if page_size < 1:
page_size = 20
if page_size > 100:
page_size = 100
# 使用数据库的分页方法获取会话列表和总数,传入筛选条件
try:
conversations, total_count = self.db_helper.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
logger.info(f"获取到 {len(conversations)} 条对话,总数: {total_count}")
except Exception as e:
logger.error(f"数据库查询出错: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"数据库查询出错: {str(e)}").__dict__
# 计算总页数
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
result = {
"conversations": conversations,
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
logger.info(
f"返回对话列表成功: {json.dumps(result, ensure_ascii=False)[:200]}..."
)
return Response().ok(result).__dict__
except Exception as e:
error_msg = f"获取对话列表失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取对话列表失败: {str(e)}").__dict__
async def get_conv_detail(self):
"""获取指定对话详情(通过POST请求)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
return (
Response()
.ok(
{
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取对话详情失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取对话详情失败: {str(e)}").__dict__
async def upd_conv(self):
"""更新对话信息(标题和角色ID)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
persona_id = data.get("persona_id", "")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
if title is not None:
self.db_helper.update_conversation_title(user_id, cid, title)
if persona_id is not None:
self.db_helper.update_conversation_persona_id(user_id, cid, persona_id)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话信息失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话信息失败: {str(e)}").__dict__
async def del_conv(self):
"""删除对话"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.delete_conversation(user_id, cid)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"删除对话失败: {str(e)}").__dict__
async def update_history(self):
"""更新对话历史内容"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
history = data.get("history")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
if history is None:
return Response().error("缺少必要参数: history").__dict__
# 历史记录必须是合法的 JSON 字符串
try:
if isinstance(history, list):
history = json.dumps(history)
else:
# 验证是否为有效的 JSON 字符串
json.loads(history)
except json.JSONDecodeError:
return (
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.update_conversation(user_id, cid, history)
return Response().ok({"message": "对话历史更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话历史失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {str(e)}").__dict__
+1
View File
@@ -51,6 +51,7 @@ class AstrBotDashboard:
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
self.tools_root = ToolsRoute(self.context, core_lifecycle)
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
self.shutdown_event = shutdown_event
+2 -2
View File
@@ -1,6 +1,6 @@
version: '3.8'
# 当接入 QQ NapCat 时,请使用这个 compose 文件一部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml
# 当接入 QQ NapCat 时,请使用这个 compose 文件一部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml
services:
astrbot:
@@ -18,4 +18,4 @@ services:
volumes:
- ./data:/AstrBot/data
# - /etc/timezone:/etc/timezone:ro
# - /etc/localtime:/etc/localtime:ro
- /etc/localtime:/etc/localtime:ro
@@ -12,7 +12,7 @@
<v-card class="item-card hover-elevation" :color="getItemEnabled(item) ? '' : 'grey-lighten-4'">
<div class="item-status-indicator" :class="{'active': getItemEnabled(item)}"></div>
<v-card-title class="d-flex justify-space-between align-center pb-1 pt-3">
<span class="text-h6 text-truncate" :title="getItemTitle(item)">{{ getItemTitle(item) }}</span>
<span class="text-h4 text-truncate" :title="getItemTitle(item)">{{ getItemTitle(item) }}</span>
<v-tooltip location="top">
<template v-slot:activator="{ props }">
<v-switch
@@ -16,7 +16,7 @@ const props = defineProps({ item: Object, level: Number });
<template v-slot:prepend>
<v-icon v-if="item.icon" :size="item.iconSize" class="hide-menu" :icon="item.icon"></v-icon>
</template>
<v-list-item-title style="font-size: 15px;">{{ item.title }}</v-list-item-title>
<v-list-item-title style="font-size: 14px;">{{ item.title }}</v-list-item-title>
<v-list-item-subtitle v-if="item.subCaption" class="text-caption mt-n1 hide-menu">
{{ item.subCaption }}
</v-list-item-subtitle>
@@ -31,7 +31,12 @@ const sidebarItem: menu[] = [
to: '/providers',
},
{
title: '配置',
title: 'MCP',
icon: 'mdi-function-variant',
to: '/tool-use'
},
{
title: '配置文件',
icon: 'mdi-cog',
to: '/config',
},
@@ -45,16 +50,16 @@ const sidebarItem: menu[] = [
icon: 'mdi-storefront',
to: '/extension-marketplace'
},
{
title: '函数调用',
icon: 'mdi-function-variant',
to: '/tool-use'
},
{
title: '聊天',
icon: 'mdi-chat',
to: '/chat'
},
{
title: '对话数据库',
icon: 'mdi-database',
to: '/conversation'
},
{
title: '控制台',
icon: 'mdi-console',
+5
View File
@@ -46,6 +46,11 @@ const MainRoutes = {
path: '/dashboard/default',
component: () => import('@/views/dashboards/default/DefaultDashboard.vue')
},
{
name: 'Conversation',
path: '/conversation',
component: () => import('@/views/ConversationPage.vue')
},
{
name: 'Console',
path: '/console',
+3
View File
@@ -16,6 +16,9 @@
color: rgb(var(--v-theme-secondary));
}
}
.v-list-item--density-default.v-list-item--one-line {
min-height: 42px;
}
.leftPadding {
margin-left: 4px;
}
File diff suppressed because it is too large Load Diff
+7 -4
View File
@@ -67,7 +67,7 @@ import { useCommonStore } from '@/stores/common';
<v-col cols="12" md="12" style="padding: 0px;">
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
:loading="loading_" v-model:search="marketSearch"
:filter-keys="['name', 'desc', 'author']">
:filter-keys="filterKeys">
<template v-slot:item.name="{ item }">
<div class="d-flex align-center">
<img v-if="item.logo" :src="item.logo"
@@ -221,7 +221,9 @@ export default {
],
marketSearch: "",
commonStore: useCommonStore()
commonStore: useCommonStore(),
filterKeys: ['name', 'desc', 'author']
}
},
computed: {
@@ -231,8 +233,9 @@ export default {
}
const search = this.marketSearch.toLowerCase();
return this.pluginMarketData.filter(plugin =>
plugin.name.toLowerCase().includes(search)
);
this.filterKeys.some(key =>
plugin[key]?.toLowerCase().includes(search)
));
},
pinnedPlugins() {
return this.pluginMarketData.filter(plugin => plugin?.pinned);
+1
View File
@@ -300,5 +300,6 @@ export default {
<style scoped>
.platform-page {
padding: 20px;
padding-top: 8px;
}
</style>
+1
View File
@@ -324,5 +324,6 @@ export default {
<style scoped>
.provider-page {
padding: 20px;
padding-top: 8px;
}
</style>
+1 -4
View File
@@ -5,7 +5,7 @@
<v-list lines="two">
<v-list-subheader>网络</v-list-subheader>
<v-list-item subtitle="设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效" title="GitHub 加速地址">
<v-list-item subtitle="设置下载插件或者更新 AstrBot 时所用的 GitHub 加速地址。这在中国大陆的网络环境有效。可以自定义,输入结果实时生效。所有地址均不保证稳定性,如果在更新插件/项目时出现报错,请首先检查加速地址是否能正常使用。" title="GitHub 加速地址">
<v-combobox variant="outlined" style="width: 100%; margin-top: 16px;" v-model="selectedGitHubProxy" :items="githubProxies"
label="选择 GitHub 加速地址">
@@ -41,11 +41,8 @@ export default {
data() {
return {
githubProxies: [
"https://ghproxy.cn",
"https://gh.llkk.cc",
"https://ghproxy.net",
"https://gitproxy.click",
"https://github.tbedu.top"
],
selectedGitHubProxy: "",
}
+1
View File
@@ -589,6 +589,7 @@ export default {
<style scoped>
.tools-page {
padding: 20px;
padding-top: 8px;
}
.server-card {
+1
View File
@@ -25,6 +25,7 @@ dashscope
python-telegram-bot
wechatpy
dingtalk-stream
defusedxml
mcp
certifi
pip