From 752d13b1b1d4fdb830e3d2c36f93607602dcbe69 Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Wed, 7 May 2025 19:04:24 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20gemini=5Fsource=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index bf2349533..fb47143d4 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,7 @@ import base64 import json import logging import random -from typing import Dict, List, Optional +from typing import Optional from collections.abc import AsyncGenerator from google import genai @@ -15,7 +15,7 @@ from astrbot import logger from astrbot.api.provider import Personality, Provider from astrbot.core.db import BaseDatabase from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.utils.io import download_image_by_url @@ -65,7 +65,7 @@ class ProviderGoogleGenAI(Provider): db_helper, default_persona, ) - self.api_keys: List = provider_config.get("key", []) + self.api_keys: list = provider_config.get("key", []) self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout: int = int(provider_config.get("timeout", 180)) @@ -99,7 +99,7 @@ class ProviderGoogleGenAI(Provider): and threshold_str in self.THRESHOLD_MAPPING ] - async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool: + async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: """处理API错误,返回是否需要重试""" if e.code == 429 or "API key not valid" in e.message: keys.remove(self.chosen_api_key) @@ -126,7 +126,7 @@ class ProviderGoogleGenAI(Provider): payloads: dict, tools: Optional[FuncCall] = None, system_instruction: Optional[str] = None, - modalities: Optional[List[str]] = None, + modalities: Optional[list[str]] = None, temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" @@ -195,7 +195,7 @@ class ProviderGoogleGenAI(Provider): ), ) - def _prepare_conversation(self, payloads: Dict) -> List[types.Content]: + def _prepare_conversation(self, payloads: dict) -> list[types.Content]: """准备 Gemini SDK 的 Content 列表""" def create_text_part(text: str) -> types.Part: @@ -220,7 +220,7 @@ class ProviderGoogleGenAI(Provider): else: contents.append(content_cls(parts=part)) - gemini_contents: List[types.Content] = [] + gemini_contents: list[types.Content] = [] native_tool_enabled = any( [ self.provider_config.get("gm_native_coderunner", False), @@ -464,13 +464,15 @@ class ProviderGoogleGenAI(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts=[], - system_prompt=None, - tool_calls_result=None, + contexts: list = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -504,13 +506,15 @@ class ProviderGoogleGenAI(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = [], + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts=[], - system_prompt=None, - tool_calls_result=None, + contexts: str = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -556,14 +560,14 @@ class ProviderGoogleGenAI(Provider): def get_current_key(self) -> str: return self.chosen_api_key - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key): self.chosen_api_key = key self._init_client() - async def assemble_context(self, text: str, image_urls: List[str] = None): + async def assemble_context(self, text: str, image_urls: list[str] = None): """ 组装上下文。 """