feat: 初步完成原生流式请求逻辑

This commit is contained in:
Raven95676
2025-04-11 23:45:30 +08:00
parent b493a808fe
commit bd24cf3ea4
+165 -56
View File
@@ -2,10 +2,11 @@ import asyncio
import base64
import json
import random
from typing import Dict, List, Optional
from typing import Dict, List, Optional, AsyncGenerator
from google import genai
from google.genai import types
from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
@@ -87,8 +88,8 @@ class ProviderGoogleGenAI(Provider):
for m in models
if "generateContent" in m.supported_actions
]
except Exception as e:
raise Exception(f"获取模型列表失败: {e}")
except APIError as e:
raise Exception(f"获取模型列表失败: {e.message}")
@staticmethod
def _prepare_conversation(payloads: Dict) -> List[types.Content]:
@@ -168,17 +169,18 @@ class ProviderGoogleGenAI(Provider):
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
) -> LLMResponse:
"""非流式请求 Gemini API"""
tool_list = []
if func_desc := tools.get_func_desc_google_genai_style() if tools else None:
tool_list.append(
types.Tool(function_declarations=func_desc["function_declarations"])
)
tool_list = None
if tools:
func_desc = tools.get_func_desc_google_genai_style()
if func_desc:
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"])
]
system_instruction = ""
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None,
)
conversation = self._prepare_conversation(payloads)
@@ -217,20 +219,19 @@ class ProviderGoogleGenAI(Provider):
break
except Exception as e:
error_msg = str(e)
if "Developer instruction is not enabled" in error_msg:
except APIError as e:
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
)
system_instruction = None
elif "Function calling is not enabled" in error_msg:
elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
tool_list = None
elif (
"Multi-modal output is not supported"
or "Model does not support the requested response modalities"
in error_msg
in e.message
):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
@@ -241,8 +242,95 @@ class ProviderGoogleGenAI(Provider):
continue
llm_response = LLMResponse("assistant")
result_parts: Optional[types.Part] = result.candidates[0].content.parts
llm_response.result_chain = self._process_content_parts(result, llm_response)
return llm_response
async def _query_stream(
    self, payloads: dict, tools: FuncCall, temperature: float = 0.7
) -> AsyncGenerator[LLMResponse, None]:
    """Send a streaming request to the Gemini API and yield responses as they arrive.

    Args:
        payloads: Request payload; ``payloads["messages"]`` is scanned for the
            first ``"system"`` message (used as the system instruction) and is
            converted to the conversation via ``_prepare_conversation``.
        tools: Function-call registry; when it yields a description, its
            ``function_declarations`` are forwarded as a single ``types.Tool``.
        temperature: Sampling temperature passed to the generation config.

    Yields:
        ``LLMResponse`` objects with ``is_chunk=True`` for each chunk that
        carries text, followed by one ``is_chunk=False`` response built by
        ``_process_content_parts`` when a chunk contains a function call or
        reports a finish reason. The stream stops after that response.

    Raises:
        APIError: Any API error other than the two capability errors that
            are handled by dropping the unsupported feature and retrying.
        Exception: If the API returns no stream object, or the finishing
            chunk carries no content.
    """
    tool_list = None
    if tools:
        func_desc = tools.get_func_desc_google_genai_style()
        if func_desc:
            tool_list = [
                types.Tool(function_declarations=func_desc["function_declarations"])
            ]

    system_instruction = next(
        (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
        None,
    )

    conversation = self._prepare_conversation(payloads)

    result = None
    while True:
        try:
            result = await self.client.models.generate_content_stream(
                model=self.get_model(),
                contents=conversation,
                config=types.GenerateContentConfig(
                    system_instruction=system_instruction,
                    temperature=temperature,
                    tools=tool_list,
                    safety_settings=self.safety_settings
                    if self.safety_settings
                    else None,
                    automatic_function_calling=types.AutomaticFunctionCallingConfig(
                        disable=True
                    ),
                ),
            )
            break
        except APIError as e:
            # The message may be absent; treat that as "no known marker".
            message = e.message or ""
            # Each fallback is only attempted while the feature is still
            # enabled. Once it has been removed, the same error is re-raised
            # instead of retrying forever.
            if (
                "Developer instruction is not enabled" in message
                and system_instruction is not None
            ):
                logger.warning(
                    f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
                )
                system_instruction = None
            elif (
                "Function calling is not enabled" in message
                and tool_list is not None
            ):
                logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
                tool_list = None
            else:
                raise

    if not result:
        raise Exception("API 返回异常")

    async for chunk in result:
        candidates = chunk.candidates
        if not candidates:
            # Nothing to forward from a chunk that carries no candidate.
            continue
        candidate = candidates[0]
        content = candidate.content
        parts = content.parts if content else None

        # A function call ends the stream: emit it as the final response.
        if parts and any(part.function_call for part in parts):
            response = LLMResponse("assistant", is_chunk=False)
            response.result_chain = self._process_content_parts(chunk, response)
            yield response
            return

        if chunk.text:
            llm_response = LLMResponse("assistant", is_chunk=True)
            llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
            yield llm_response

        if candidate.finish_reason:
            if content is None:
                # _process_content_parts dereferences content.parts, so report
                # the missing content explicitly instead of an AttributeError.
                if candidate.finish_reason == types.FinishReason.SAFETY:
                    raise Exception("模型生成内容未通过用户定义的内容安全检查")
                raise Exception("API 返回的内容为空。")
            # NOTE(review): the final response is built from the finishing
            # chunk only, not from the text of earlier chunks. Confirm that
            # consumers do not expect the complete text here.
            final_response = LLMResponse("assistant", is_chunk=False)
            final_response.result_chain = self._process_content_parts(
                chunk, final_response
            )
            yield final_response
            return
@staticmethod
def _process_content_parts(
result: types.GenerateContentResponse, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
finish_reason = result.candidates[0].finish_reason
result_parts: Optional[types.Part] = result.candidates[0].content.parts
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过用户定义的内容安全检查")
@@ -259,20 +347,9 @@ class ProviderGoogleGenAI(Provider):
logger.debug(result.candidates)
raise Exception("API 返回的内容为空。")
llm_response.result_chain = self._process_content_parts(
result_parts, llm_response
)
return llm_response
@staticmethod
def _process_content_parts(
parts: types.Part, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
chain = []
part: types.Part
for part in parts:
for part in result_parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif part.function_call:
@@ -322,19 +399,19 @@ class ProviderGoogleGenAI(Provider):
try:
llm_response = await self._query(payloads, func_tool, temp)
break
except Exception as e:
if "429" in str(e) or "API key not valid" in str(e):
except APIError as e:
if e.code == 429 or "API key not valid" in e.message:
keys.remove(self.chosen_api_key)
if len(keys) > 0:
self.set_key(random.choice(keys))
logger.info(
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
)
await asyncio.sleep(1)
continue
else:
logger.error(
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
)
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
else:
@@ -347,30 +424,62 @@ class ProviderGoogleGenAI(Provider):
async def text_chat_stream(
    self,
    prompt: str,
    session_id: str = None,
    image_urls: Optional[List[str]] = None,
    func_tool: FuncCall = None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
) -> AsyncGenerator[LLMResponse, None]:
    """Stream a chat completion, rotating API keys on rate-limit or invalid-key errors.

    Args:
        prompt: The new user prompt, assembled with ``image_urls`` into a
            context record via ``assemble_context``.
        session_id: Not used by this method.
        image_urls: Image URLs attached to the prompt; defaults to none.
        func_tool: Function-call registry forwarded to ``_query_stream``.
        contexts: Prior conversation records; defaults to none.
        system_prompt: When truthy, inserted as the first ``"system"`` message.
        tool_calls_result: When truthy, its ``to_openai_messages()`` output is
            appended to the conversation.
        **kwargs: ``temperature`` is read from here (default ``0.7``).

    Yields:
        Every ``LLMResponse`` produced by ``_query_stream``.

    Raises:
        Exception: When no usable API key remains, or every retry was used.
        APIError: Any API error other than a 429 or an invalid-key error.
    """
    # None replaces the mutable list defaults so calls never share one list.
    image_urls = [] if image_urls is None else image_urls
    contexts = [] if contexts is None else contexts

    new_record = await self.assemble_context(prompt, image_urls)
    context_query = [*contexts, new_record]
    if system_prompt:
        context_query.insert(0, {"role": "system", "content": system_prompt})

    # Strip the internal "_no_save" marker before the records are sent.
    for part in context_query:
        part.pop("_no_save", None)

    # tool calls result
    if tool_calls_result:
        context_query.extend(tool_calls_result.to_openai_messages())

    model_config = self.provider_config.get("model_config", {})
    model_config["model"] = self.get_model()

    payloads = {"messages": context_query, **model_config}

    retry = 10
    keys = self.api_keys.copy()
    temp = kwargs.get("temperature", 0.7)  # default temperature, provisionally 0.7

    for _ in range(retry):
        try:
            # NOTE(review): if an error arrives after some chunks were already
            # yielded, the retry restarts the stream and may repeat output.
            async for response in self._query_stream(payloads, func_tool, temp):
                yield response
            break
        except APIError as e:
            message = e.message or ""
            if e.code == 429 or "API key not valid" in message:
                # The active key may not be in this local copy; removing it
                # unconditionally would raise ValueError and hide the API error.
                if self.chosen_api_key in keys:
                    keys.remove(self.chosen_api_key)
                if keys:
                    self.set_key(random.choice(keys))
                    logger.info(
                        f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
                    )
                    await asyncio.sleep(1)
                    continue
                logger.error(
                    f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
                )
                raise Exception("达到了 Gemini 速率限制, 请稍后再试...") from e
            # NOTE(review): this logs the whole provider config; confirm it
            # does not contain API keys.
            logger.error(
                f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
            )
            raise
    else:
        # Every attempt ended in a key error; fail loudly instead of ending
        # the stream with no response.
        raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
def get_current_key(self) -> str:
    """Return the API key currently selected on this provider."""
    current_key = self.chosen_api_key
    return current_key