fix: 修复gewechat无法at人和发语音失败的问题 #447 #438

This commit is contained in:
Soulter
2025-02-10 18:11:22 +08:00
parent 9f02dd13ff
commit aa276ca6af
5 changed files with 54 additions and 20 deletions
+2 -1
View File
@@ -21,4 +21,5 @@ node_modules/
.DS_Store
package-lock.json
package.json
venv/*
venv/*
packages/python_interpreter/workplace
@@ -312,12 +312,14 @@ class SimpleGewechatClient():
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
async def post_text(self, to_wxid, content: str):
async def post_text(self, to_wxid, content: str, ats: str = ""):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
if ats:
payload['ats'] = ats
async with aiohttp.ClientSession() as session:
async with session.post(
@@ -6,7 +6,7 @@ from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image, Record, At, File, Reply
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
@@ -15,6 +15,8 @@ def get_wav_duration(file_path):
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
elif n_frames == 0:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
@@ -43,9 +45,31 @@ class GewechatPlatformEvent(AstrMessageEvent):
logger.error("无法获取到 to_wxid。")
return
# 检查@
ats = []
ats_names = []
for comp in message.chain:
if isinstance(comp, At):
ats.append(comp.qq)
ats_names.append(comp.name)
has_at = False
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
text = comp.text
payload = {
"to_wxid": to_wxid,
"content": text,
}
if not has_at and ats:
ats = f"{','.join(ats)}"
ats_names = f"@{' @'.join(ats_names)}"
text = f"{ats_names} {text}"
payload["content"] = text
payload["ats"] = ats
has_at = True
await self.client.post_text(**payload)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
@@ -80,6 +104,7 @@ class GewechatPlatformEvent(AstrMessageEvent):
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
logger.info("开始转换语音文件: " + record_path)
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
+1 -1
View File
@@ -52,7 +52,7 @@ class ProviderRequest():
@dataclass
class LLMResponse:
role: str
'''角色'''
'''角色, assistant, tool, err'''
completion_text: str = ""
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
+21 -15
View File
@@ -22,21 +22,27 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
import pysilk
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000, 24000)
output_io.seek(0)
try:
import pilk
except (ImportError, ModuleNotFoundError) as _:
raise Exception("pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库")
# with wave.open(wav_path, 'rb') as wav:
# wav_data = wav.readframes(wav.getnframes())
# wav_data = BytesIO(wav_data)
# output_io = BytesIO()
# pysilk.encode(wav_data, output_io, 24000, 24000)
# output_io.seek(0)
# 在首字节添加 \x02,去除结尾的\xff\xff
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data[:-2]
# # 在首字节添加 \x02,去除结尾的\xff\xff
# silk_data = output_io.read()
# silk_data_with_prefix = b'\x02' + silk_data[:-2]
# return BytesIO(silk_data_with_prefix)
with open(output_path, "wb") as f:
f.write(silk_data_with_prefix)
# # return BytesIO(silk_data_with_prefix)
# with open(output_path, "wb") as f:
# f.write(silk_data_with_prefix)
return 0
# return 0
with wave.open(wav_path, 'rb') as wav:
rate = wav.getframerate()
duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True)
return duration