Merge pull request #1990 from AstrBotDevs/fix-stream-multi-tool-use-err
fix: Multi-turn tools use error when using streaming output
This commit is contained in:
@@ -300,7 +300,11 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ..register import register_provider_adapter
|
||||
TEMP_DIR = Path("data/temp/azure_tts")
|
||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class OTTSProvider:
|
||||
def __init__(self, config: Dict):
|
||||
self.skey = config["OTTS_SKEY"]
|
||||
@@ -70,12 +71,12 @@ class OTTSProvider:
|
||||
"style": voice_params["style"],
|
||||
"role": voice_params["role"],
|
||||
"rate": voice_params["rate"],
|
||||
"volume": voice_params["volume"]
|
||||
"volume": voice_params["volume"],
|
||||
},
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"UAK": "AstrBot/AzureTTS"
|
||||
}
|
||||
"UAK": "AstrBot/AzureTTS",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -88,14 +89,19 @@ class OTTSProvider:
|
||||
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
||||
self.subscription_key = provider_config.get(
|
||||
"azure_tts_subscription_key", ""
|
||||
).strip()
|
||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||
raise ValueError("无效的Azure订阅密钥")
|
||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
self.endpoint = (
|
||||
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
self.client = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
@@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider):
|
||||
"style": provider_config.get("azure_tts_style", "cheerful"),
|
||||
"role": provider_config.get("azure_tts_role", "Boy"),
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100")
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
|
||||
})
|
||||
self.client = AsyncClient(
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
|
||||
}
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider):
|
||||
await self.client.aclose()
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||
token_url = (
|
||||
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||
)
|
||||
response = await self.client.post(
|
||||
token_url,
|
||||
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||
token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
self.token = response.text
|
||||
@@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider):
|
||||
content=ssml,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"User-Agent": f"AstrBot/{VERSION}"
|
||||
}
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
|
||||
|
||||
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
||||
class AzureTTSProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
@@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider):
|
||||
error_msg = (
|
||||
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
||||
f"错误详情: {e.msg}\n"
|
||||
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
|
||||
f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except KeyError as e:
|
||||
@@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider):
|
||||
"style": self.provider_config.get("azure_tts_style"),
|
||||
"role": self.provider_config.get("azure_tts_role"),
|
||||
"rate": self.provider_config.get("azure_tts_rate"),
|
||||
"volume": self.provider_config.get("azure_tts_volume")
|
||||
}
|
||||
"volume": self.provider_config.get("azure_tts_volume"),
|
||||
},
|
||||
)
|
||||
else:
|
||||
async with self.provider as provider:
|
||||
|
||||
@@ -18,7 +18,7 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
@@ -65,7 +65,7 @@ class ProviderDify(Provider):
|
||||
if image_urls is None:
|
||||
image_urls = []
|
||||
result = ""
|
||||
session_id = session_id or kwargs.get("user") # 1734
|
||||
session_id = session_id or kwargs.get("user") # 1734
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
@@ -84,13 +84,11 @@ class ProviderDify(Provider):
|
||||
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
|
||||
)
|
||||
continue
|
||||
files_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
}
|
||||
)
|
||||
files_payload.append({
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response["id"],
|
||||
})
|
||||
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
|
||||
@@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -544,13 +544,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: str = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
@@ -566,7 +566,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
@@ -525,9 +525,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}}
|
||||
)
|
||||
user_content["content"].append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
@@ -5,12 +5,12 @@ import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
@@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.api_base = provider_config.get(
|
||||
"api_base", "https://openspeech.bytedance.com/api/v1/tts"
|
||||
)
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
@@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": self.api_key,
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
"cluster": self.cluster,
|
||||
},
|
||||
"user": {"uid": str(uuid.uuid4())},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": "mp3",
|
||||
@@ -48,60 +48,61 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
"with_frontend": 1,
|
||||
"frontend_type": "unitTson"
|
||||
}
|
||||
"frontend_type": "unitTson",
|
||||
},
|
||||
}
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer; {self.api_key}"
|
||||
"Authorization": f"Bearer; {self.api_key}",
|
||||
}
|
||||
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
|
||||
|
||||
logger.debug(f"请求头: {headers}")
|
||||
logger.debug(f"请求 URL: {self.api_base}")
|
||||
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload),
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
) as response:
|
||||
logger.debug(f"响应状态码: {response.status}")
|
||||
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
None, lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
raise Exception(
|
||||
f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user