From 97e4d169b3ee991e5e9aed3107c05c0182d2fc55 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 2 Feb 2025 11:23:33 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=9C=AA=E5=90=AF=E7=94=A8=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=8F=90=E4=BE=9B=E5=95=86=E6=97=B6=E7=9A=84=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 4 +- .../process_stage/method/dify_request.py | 4 + astrbot/core/pipeline/process_stage/stage.py | 5 ++ astrbot/core/platform/sources/wecom/client.py | 48 ++++++++++ .../platform/sources/wecom/webchat_adapter.py | 90 +++++++++++++++++++ .../platform/sources/wecom/webchat_event.py | 41 +++++++++ packages/astrbot/main.py | 31 +++++++ 7 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 astrbot/core/platform/sources/wecom/client.py create mode 100644 astrbot/core/platform/sources/wecom/webchat_adapter.py create mode 100644 astrbot/core/platform/sources/wecom/webchat_event.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6c262248f..0fc1c87f2 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -324,7 +324,7 @@ CONFIG_METADATA_2 = { "type": "openai_chat_completion", "enable": True, "key": [], - "api_base": "", + "api_base": "https://api.openai.com/v1", "model_config": { "model": "gpt-4o-mini", }, @@ -453,7 +453,7 @@ CONFIG_METADATA_2 = { "api_base": { "description": "API Base URL", "type": "string", - "hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。", + "hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1", "obvious_hint": True, }, "base_model_path": { diff --git a/astrbot/core/pipeline/process_stage/method/dify_request.py b/astrbot/core/pipeline/process_stage/method/dify_request.py index 9a52fbb03..acdb22010 100644 --- a/astrbot/core/pipeline/process_stage/method/dify_request.py +++ b/astrbot/core/pipeline/process_stage/method/dify_request.py @@ -21,6 +21,10 @@ class DifyRequestSubStage(Stage): req: ProviderRequest = None provider = self.ctx.plugin_manager.context.get_using_provider() + + if not provider: + return + if provider.meta().type != "dify": return diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index e8b3db18a..1d9af801f 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -49,6 +49,11 @@ class ProcessStage(Stage): if not event._has_send_oper and event.is_at_or_wake_command: if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): provider = self.ctx.plugin_manager.context.get_using_provider() + + if not provider: + logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") + return + match provider.meta().type: case "dify": async for _ in self.dify_request_sub_stage.process(event): diff --git a/astrbot/core/platform/sources/wecom/client.py b/astrbot/core/platform/sources/wecom/client.py new file mode 100644 index 000000000..1ede1fe46 --- /dev/null +++ b/astrbot/core/platform/sources/wecom/client.py @@ -0,0 +1,48 @@ + +import asyncio +import aiohttp +import quart +import base64 + +from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType +from astrbot.api.message_components import Plain, Image, At, Record +from astrbot.api import logger, sp +from .downloader import GeweDownloader +from astrbot.core.utils.io import download_image_by_url + + +class WeComClient(): + def __init__(self, config: dict): + self.base_url = base_url + if self.base_url.endswith('/'): + self.base_url = self.base_url[:-1] + + self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口 + self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/" + + self.base_url += "/v2/api" + + logger.info(f"wecom API: {self.base_url}") + logger.info(f"Gewechat 下载 API: {self.download_base_url}") + + if isinstance(port, str): + port = int(port) + + self.token = None + self.headers = {} + self.nickname = nickname + self.appid = sp.get(f"gewechat-appid-{nickname}", "") + + self.server = quart.Quart(__name__) + self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST']) + self.server.add_url_rule('/astrbot-gewechat/file/', view_func=self.handle_file, methods=['GET']) + + self.host = host + self.port = port + self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback" + self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file" + + self.event_queue = event_queue + + self.multimedia_downloader = None + \ No newline at end of file diff --git a/astrbot/core/platform/sources/wecom/webchat_adapter.py b/astrbot/core/platform/sources/wecom/webchat_adapter.py new file mode 100644 index 000000000..19c01c336 --- /dev/null +++ b/astrbot/core/platform/sources/wecom/webchat_adapter.py @@ -0,0 +1,90 @@ +import time +import asyncio +import uuid +import os +from typing import Awaitable, Any +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, Image, Record # noqa: F403 +from astrbot.api import logger +from astrbot.core import web_chat_queue, web_chat_back_queue +from .webchat_event import WebChatMessageEvent +from astrbot.core.platform.astr_message_event import MessageSesion +from ...register import register_platform_adapter + + +@register_platform_adapter("wecom", "wecom") +class WecomAdapter(Platform): + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + + self.config = platform_config + self.settings = platform_settings + self.unique_session = platform_settings['unique_session'] + + self.metadata = PlatformMetadata( + "wecom", + "wecom", + ) + + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + + await super().send_by_session(session, message_chain) + + async def convert_message(self, data: tuple) -> AstrBotMessage: + username, cid, payload = data + + + abm = AstrBotMessage() + abm.self_id = "webchat" + abm.tag = "webchat" + abm.sender = MessageMember(username, username) + + abm.type = MessageType.FRIEND_MESSAGE + + abm.session_id = f"webchat!{username}!{cid}" + + abm.message_id = str(uuid.uuid4()) + abm.message = [] + + if payload['message']: + abm.message.append(Plain(payload['message'])) + if payload['image_url']: + if isinstance(payload['image_url'], list): + for img in payload['image_url']: + abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img))) + else: + abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url']))) + if payload['audio_url']: + if isinstance(payload['audio_url'], list): + for audio in payload['audio_url']: + path = os.path.join(self.imgs_dir, audio) + abm.message.append(Record(file=path, path=path)) + else: + path = os.path.join(self.imgs_dir, payload['audio_url']) + abm.message.append(Record(file=path, path=path)) + + logger.debug(f"WebChatAdapter: {abm.message}") + + message_str = payload['message'] + abm.timestamp = int(time.time()) + abm.message_str = message_str + abm.raw_message = data + return abm + + def run(self) -> Awaitable[Any]: + pass + + def meta(self) -> PlatformMetadata: + return self.metadata + + async def handle_msg(self, message: AstrBotMessage): + + message_event = WebChatMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id + ) + + self.commit_event(message_event) \ No newline at end of file diff --git a/astrbot/core/platform/sources/wecom/webchat_event.py b/astrbot/core/platform/sources/wecom/webchat_event.py new file mode 100644 index 000000000..f447a616c --- /dev/null +++ b/astrbot/core/platform/sources/wecom/webchat_event.py @@ -0,0 +1,41 @@ +import os +import uuid +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import Plain, Image +from astrbot.core.utils.io import file_to_base64, download_image_by_url +from astrbot.core import web_chat_back_queue + +class WebChatMessageEvent(AstrMessageEvent): + def __init__(self, message_str, message_obj, platform_meta, session_id): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.imgs_dir = "data/webchat/imgs" + os.makedirs(self.imgs_dir, exist_ok=True) + + async def send(self, message: MessageChain): + if not message: + web_chat_back_queue.put_nowait(None) + return + + cid = self.session_id.split("!")[-1] + + for comp in message.chain: + if isinstance(comp, Plain): + web_chat_back_queue.put_nowait((comp.text, cid)) + elif isinstance(comp, Image): + # save image to local + filename = str(uuid.uuid4()) + ".jpg" + path = os.path.join(self.imgs_dir, filename) + if comp.file and comp.file.startswith("file:///"): + ph = comp.file[8:] + with open(path, "wb") as f: + with open(ph, "rb") as f2: + f.write(f2.read()) + elif comp.file and comp.file.startswith("http"): + await download_image_by_url(comp.file, path=path) + else: + with open(path, "wb") as f: + with open(comp.file, "rb") as f2: + f.write(f2.read()) + web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) + web_chat_back_queue.put_nowait(None) + await super().send(message) \ No newline at end of file diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index b52a6e824..568d8b13c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -203,6 +203,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo async def provider(self, event: AstrMessageEvent, idx: int = None): '''查看或者切换 LLM Provider''' + if not self.context.get_using_provider(): + event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return + if idx is None: ret = "## 当前载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): @@ -227,6 +231,11 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("reset") async def reset(self, message: AstrMessageEvent): + + if not self.context.get_using_provider(): + message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return + await self.context.get_using_provider().forget(message.session_id) ret = "清除会话 LLM 聊天历史成功。" if self.ltm: @@ -237,6 +246,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("model") async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None): + + if not self.context.get_using_provider(): + message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return + + if idx_or_name is None: models = [] try: @@ -277,6 +292,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("history") async def his(self, message: AstrMessageEvent, page: int = 1): + + + if not self.context.get_using_provider(): + message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return + size_per_page = 3 contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page) @@ -296,6 +317,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") async def key(self, message: AstrMessageEvent, index: int=None): + + if not self.context.get_using_provider(): + message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return if index is None: keys_data = self.context.get_using_provider().get_keys() @@ -324,6 +349,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("persona") async def persona(self, message: AstrMessageEvent): + + if not self.context.get_using_provider(): + message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + return + + l = message.message_str.split(" ") curr_persona_name = "无"