feat: 初步完成原生流式请求逻辑
This commit is contained in:
@@ -2,10 +2,11 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
@@ -87,8 +88,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
for m in models
|
||||
if "generateContent" in m.supported_actions
|
||||
]
|
||||
except Exception as e:
|
||||
raise Exception(f"获取模型列表失败: {e}")
|
||||
except APIError as e:
|
||||
raise Exception(f"获取模型列表失败: {e.message}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_conversation(payloads: Dict) -> List[types.Content]:
|
||||
@@ -168,17 +169,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
|
||||
) -> LLMResponse:
|
||||
"""非流式请求 Gemini API"""
|
||||
tool_list = []
|
||||
if func_desc := tools.get_func_desc_google_genai_style() if tools else None:
|
||||
tool_list.append(
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
)
|
||||
tool_list = None
|
||||
if tools:
|
||||
func_desc = tools.get_func_desc_google_genai_style()
|
||||
if func_desc:
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "system":
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
@@ -217,20 +219,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "Developer instruction is not enabled" in error_msg:
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in error_msg:
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tool_list = None
|
||||
elif (
|
||||
"Multi-modal output is not supported"
|
||||
or "Model does not support the requested response modalities"
|
||||
in error_msg
|
||||
in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
|
||||
@@ -241,8 +242,95 @@ class ProviderGoogleGenAI(Provider):
|
||||
continue
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall, temperature: float = 0.7
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
tool_list = None
|
||||
if tools:
|
||||
func_desc = tools.get_func_desc_google_genai_style()
|
||||
if func_desc:
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings
|
||||
if self.safety_settings
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tool_list = None
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
|
||||
if not result:
|
||||
raise Exception("API 返回异常")
|
||||
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
if chunk.candidates[0].content.parts and any(
|
||||
part.function_call for part in chunk.candidates[0].content.parts
|
||||
):
|
||||
response = LLMResponse("assistant", is_chunk=False)
|
||||
response.result_chain = self._process_content_parts(chunk, response)
|
||||
yield response
|
||||
break
|
||||
|
||||
if chunk.text:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
@@ -259,20 +347,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.debug(result.candidates)
|
||||
raise Exception("API 返回的内容为空。")
|
||||
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
result_parts, llm_response
|
||||
)
|
||||
|
||||
return llm_response
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
parts: types.Part, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
chain = []
|
||||
part: types.Part
|
||||
for part in parts:
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
@@ -322,19 +399,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool, temp)
|
||||
break
|
||||
except Exception as e:
|
||||
if "429" in str(e) or "API key not valid" in str(e):
|
||||
except APIError as e:
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
if len(keys) > 0:
|
||||
self.set_key(random.choice(keys))
|
||||
logger.info(
|
||||
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
@@ -347,30 +424,62 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
temp = kwargs.get("temperature", 0.7) # 暂定默认温度为0.7
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
async for response in self._query_stream(payloads, func_tool, temp):
|
||||
yield response
|
||||
break
|
||||
except APIError as e:
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
if len(keys) > 0:
|
||||
self.set_key(random.choice(keys))
|
||||
logger.info(
|
||||
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
raise e
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.chosen_api_key
|
||||
|
||||
Reference in New Issue
Block a user