perf: 优化 gemini_source 方法默认参数

This commit is contained in:
Raven95676
2025-05-07 19:04:24 +08:00
parent 54c0dc1b2b
commit 752d13b1b1
+21 -17
View File
@@ -3,7 +3,7 @@ import base64
import json
import logging
import random
from typing import Dict, List, Optional
from typing import Optional
from collections.abc import AsyncGenerator
from google import genai
@@ -15,7 +15,7 @@ from astrbot import logger
from astrbot.api.provider import Personality, Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.utils.io import download_image_by_url
@@ -65,7 +65,7 @@ class ProviderGoogleGenAI(Provider):
db_helper,
default_persona,
)
self.api_keys: List = provider_config.get("key", [])
self.api_keys: list = provider_config.get("key", [])
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout: int = int(provider_config.get("timeout", 180))
@@ -99,7 +99,7 @@ class ProviderGoogleGenAI(Provider):
and threshold_str in self.THRESHOLD_MAPPING
]
async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
"""处理API错误,返回是否需要重试"""
if e.code == 429 or "API key not valid" in e.message:
keys.remove(self.chosen_api_key)
@@ -126,7 +126,7 @@ class ProviderGoogleGenAI(Provider):
payloads: dict,
tools: Optional[FuncCall] = None,
system_instruction: Optional[str] = None,
modalities: Optional[List[str]] = None,
modalities: Optional[list[str]] = None,
temperature: float = 0.7,
) -> types.GenerateContentConfig:
"""准备查询配置"""
@@ -195,7 +195,7 @@ class ProviderGoogleGenAI(Provider):
),
)
def _prepare_conversation(self, payloads: Dict) -> List[types.Content]:
def _prepare_conversation(self, payloads: dict) -> list[types.Content]:
"""准备 Gemini SDK 的 Content 列表"""
def create_text_part(text: str) -> types.Part:
@@ -220,7 +220,7 @@ class ProviderGoogleGenAI(Provider):
else:
contents.append(content_cls(parts=part))
gemini_contents: List[types.Content] = []
gemini_contents: list[types.Content] = []
native_tool_enabled = any(
[
self.provider_config.get("gm_native_coderunner", False),
@@ -464,13 +464,15 @@ class ProviderGoogleGenAI(Provider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
@@ -504,13 +506,15 @@ class ProviderGoogleGenAI(Provider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
contexts: str = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
@@ -556,14 +560,14 @@ class ProviderGoogleGenAI(Provider):
def get_current_key(self) -> str:
return self.chosen_api_key
def get_keys(self) -> List[str]:
def get_keys(self) -> list[str]:
return self.api_keys
def set_key(self, key):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: List[str] = None):
async def assemble_context(self, text: str, image_urls: list[str] = None):
"""
组装上下文。
"""