From 750a93a1aa4a28b4fc6054c24dede46b816024a0 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 2 Dec 2024 19:31:33 +0800 Subject: [PATCH] =?UTF-8?q?remove:=20=E7=A7=BB=E9=99=A4=E4=BA=86=20nakuru-?= =?UTF-8?q?project=20=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 但仍然使用其对 OneBot 的数据格式封装。 --- astrbot/api/__init__.py | 5 +- astrbot/api/message_components.py | 1 + astrbot/core/__init__.py | 2 +- astrbot/core/config/astrbot_config.py | 49 +- astrbot/core/config/default.py | 66 ++- astrbot/core/core_lifecycle.py | 16 +- astrbot/core/db/__init__.py | 2 +- astrbot/core/db/sqlite.py | 2 +- astrbot/core/event_bus.py | 6 +- astrbot/core/message/components.py | 443 ++++++++++++++++++ .../{ => message}/message_event_handler.py | 12 +- .../{ => message}/message_event_result.py | 12 +- astrbot/core/platform/astr_message_event.py | 8 +- astrbot/core/platform/astrbot_message.py | 2 +- astrbot/core/platform/platform.py | 4 +- astrbot/core/plugin/__init__.py | 2 +- astrbot/core/plugin/context.py | 23 +- astrbot/core/plugin/plugin_manager.py | 8 +- astrbot/core/plugin/updator.py | 4 +- astrbot/core/provider/__init__.py | 2 +- astrbot/core/provider/provider.py | 54 ++- astrbot/core/updator.py | 6 +- astrbot/core/utils/func_call.py | 2 +- astrbot/core/utils/metrics.py | 2 +- astrbot/core/utils/t2i/local_strategy.py | 2 +- astrbot/core/utils/t2i/network_strategy.py | 4 +- astrbot/core/utils/t2i/renderer.py | 2 +- astrbot/core/zip_updator.py | 4 +- astrbot/dashboard/dashboard_lifecycle.py | 6 +- astrbot/dashboard/routes/auth.py | 2 +- astrbot/dashboard/routes/config.py | 8 +- astrbot/dashboard/routes/log.py | 4 +- astrbot/dashboard/routes/plugin.py | 8 +- astrbot/dashboard/routes/route.py | 2 +- astrbot/dashboard/routes/stat.py | 10 +- astrbot/dashboard/routes/static_file.py | 2 +- astrbot/dashboard/routes/update.py | 6 +- astrbot/dashboard/server.py | 19 +- astrbot/main.py => main.py | 10 +- .../aiocqhttp_message_event.py | 16 +- .../aiocqhttp_platform_adapter.py | 4 +- .../qqofficial_message_event.py | 2 +- .../qqofficial_platform_adapter.py | 2 +- .../wechat_message_event.py | 7 +- .../wechat_platform_adapter.py | 21 +- packages/astrbot_plugin_openai/commands.py | 29 +- packages/astrbot_plugin_openai/main.py | 43 +- .../astrbot_plugin_openai/openai_adapter.py | 374 +++++---------- requirements.txt | 4 +- 49 files changed, 904 insertions(+), 420 deletions(-) create mode 100644 astrbot/api/message_components.py create mode 100644 astrbot/core/message/components.py rename astrbot/core/{ => message}/message_event_handler.py (96%) rename astrbot/core/{ => message}/message_event_result.py (75%) rename astrbot/main.py => main.py (93%) diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index 783ff9fff..40a31adbb 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -1,10 +1,9 @@ from astrbot.core.plugin import Context from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata -from astrbot.core.message_event_result import MessageEventResult, MessageChain, CommandResult -from astrbot.core.provider import Provider +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult +from astrbot.core.provider import Provider, Personality from astrbot.core.config.astrbot_config import AstrBotConfig -from nakuru.entities.components import * from astrbot import logger from astrbot.core.utils.personality import personalities diff --git a/astrbot/api/message_components.py b/astrbot/api/message_components.py new file mode 100644 index 000000000..97d456b39 --- /dev/null +++ b/astrbot/api/message_components.py @@ -0,0 +1 @@ +from astrbot.core.message.components import * \ No newline at end of file diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index d289c6e99..e87e58524 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,5 +1,5 @@ from .log import LogManager, LogBroker -from core.utils.t2i.renderer import HtmlRenderer +from astrbot.core.utils.t2i.renderer import HtmlRenderer html_renderer = HtmlRenderer() logger = LogManager.GetLogger(log_name='astrbot') \ No newline at end of file diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index bc3fd97eb..3e8b572dd 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -67,6 +67,11 @@ class ImageGenerationModelConfig: style: str = "vivid" quality: str = "standard" +@dataclass +class EmbeddingModel: + enable: bool = False + model: str = "" + @dataclass class LLMConfig: id: str = "" @@ -77,12 +82,17 @@ class LLMConfig: prompt_prefix: str = "" default_personality: str = "" model_config: ModelConfig = field(default_factory=ModelConfig) - image_generation_model_config: Optional[ImageGenerationModelConfig] = None + image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig) + embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel) def __post_init__(self): - self.model_config = ModelConfig(**self.model_config) - if self.image_generation_model_config: - self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) + if isinstance(self.model_config, dict): + self.model_config = ModelConfig(**self.model_config) + if isinstance(self.image_generation_model_config, dict): + self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None + if isinstance(self.embedding_model, dict): + self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None + @dataclass class LLMSettings: wake_prefix: str = "" @@ -115,6 +125,35 @@ class DashboardConfig: enable: bool = True username: str = "" password: str = "" + + +@dataclass +class ATRILongTermMemory: + enable: bool = False + summary_threshold_cnt: int = 5 + +@dataclass +class ATRIActiveMessage: + enable: bool = False + +@dataclass +class ProjectATRI: + enable: bool = False + long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory) + active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage) + persona: str = "" + embedding_provider_id: str = "" + summarize_provider_id: str = "" + chat_provider_id: str = "" + chat_base_model_path: str = "" + chat_adapter_model_path: str = "" + quantization_bit: int = 4 + + def __post_init__(self): + if isinstance(self.long_term_memory, dict): + self.long_term_memory = ATRILongTermMemory(**self.long_term_memory) + if isinstance(self.active_message, dict): + self.active_message = ATRIActiveMessage(**self.active_message) @dataclass class AstrBotConfig(): @@ -134,6 +173,7 @@ class AstrBotConfig(): t2i_endpoint: str = "" pip_install_arg: str = "" plugin_repo_mirror: str = "" + project_atri: ProjectATRI = field(default_factory=ProjectATRI) def __init__(self) -> None: self.init_configs() @@ -190,6 +230,7 @@ class AstrBotConfig(): self.t2i_endpoint=data.get("t2i_endpoint", "") self.pip_install_arg=data.get("pip_install_arg", "") self.plugin_repo_mirror=data.get("plugin_repo_mirror", "") + self.project_atri=ProjectATRI(**data.get("project_atri", {})) def flush_config(self, config: dict = None): '''将配置写入文件, 如果没有传入配置,则写入默认配置''' diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 007e98c45..aeb7312f3 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -27,6 +27,10 @@ PROVIDER_CONFIG_TEMPLATE = { "size": "1024x1024", "style": "vivid", "quality": "standard", + }, + "embedding_model": { + "enable": False, + "model": "text-embedding-3-small" } }, "ollama": { @@ -147,6 +151,23 @@ DEFAULT_CONFIG_VERSION_2 = { "t2i_endpoint": "", "pip_install_arg": "", "plugin_repo_mirror": "default", + "project_atri": { + "enable": False, + "long_term_memory": { + "enable": False, + "summary_threshold_cnt": 6, + }, + "active_message": { + "enable": False, + }, + "persona": "", + "embedding_provider_id": "", + "summarize_provider_id": "", + "chat_provider_id": "", + "chat_base_model_path": "", + "chat_adapter_model_path": "", + "quantization_bit": 4 + } } # 配置项的中文描述、值类型 @@ -167,7 +188,7 @@ CONFIG_METADATA_2 = { "ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"}, "qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"}, "qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"}, - "wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。"}, + "wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"}, } }, "platform_settings": { @@ -200,17 +221,17 @@ CONFIG_METADATA_2 = { "prompt_prefix": {"description": "Prompt 前缀", "type": "text", "hint": "每次与 LLM 对话时在对话前加上的自定义文本。默认为空。"}, "default_personality": {"description": "默认人格", "type": "text", "hint": "在当前版本下,默认人格文本会被添加到 LLM 对话的 `system` 字段中。"}, "model_config": { - "description": "模型配置", + "description": "文本生成模型", "type": "object", "items": { "model": {"description": "模型名称", "type": "string", "hint": "大语言模型的名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat 等。"}, - "max_tokens": {"description": "最大令牌数", "type": "int"}, + "max_tokens": {"description": "模型最大输出长度(tokens)", "type": "int"}, "temperature": {"description": "温度", "type": "float"}, "top_p": {"description": "Top P值", "type": "float"}, } }, "image_generation_model_config": { - "description": "图像生成模型配置", + "description": "图像生成模型", "type": "object", "items": { "enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持图像生成。如 dall-e-3"}, @@ -220,6 +241,14 @@ CONFIG_METADATA_2 = { "quality": {"description": "图像质量", "type": "string"}, } }, + "embedding_model": { + "description": "文本嵌入模型", + "type": "object", + "items": { + "enable": {"description": "启用", "type": "bool", "hint": "启用该功能需要提供商支持文本嵌入。"}, + "model": {"description": "模型名称", "type": "string", "hint": "文本嵌入模型的名称,一般是小写的英文。如 text-embedding-3-small"}, + } + } } }, "llm_settings": { @@ -273,6 +302,35 @@ CONFIG_METADATA_2 = { "t2i_endpoint": {"description": "文本转图像服务接口", "type": "string", "hint": "为空时使用 AstrBot API 服务"}, "pip_install_arg": {"description": "pip 安装参数", "type": "string", "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。"}, "plugin_repo_mirror": {"description": "插件仓库镜像", "type": "string", "hint": "插件仓库的镜像地址,用于加速插件的下载。", "options": ["default", "https://ghp.ci/", "https://github-mirror.us.kg/"]}, + "project_atri": { + "description": "Project ATRI 配置", + "type": "object", + "items": { + "enable": {"description": "启用", "type": "bool"}, + "long_term_memory": { + "description": "长期记忆", + "type": "object", + "items": { + "enable": {"description": "启用", "type": "bool"}, + "summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"}, + } + }, + "active_message": { + "description": "主动消息", + "type": "object", + "items": { + "enable": {"description": "启用", "type": "bool"}, + } + }, + "persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True}, + "embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True}, + "summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, + "chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, + "chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True}, + "chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True}, + "quantization_bit": {"description": "量化位数", "type": "int", "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。", "obvious_hint": True}, + } + } } DEFAULT_VALUE_MAP = { diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 9fdb32420..b8deef970 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -2,14 +2,14 @@ import asyncio, time, threading from .event_bus import EventBus from asyncio import Queue from typing import List -from core.config.astrbot_config import AstrBotConfig -from core.message_event_handler import MessageEventHandler -from core.plugin import PluginManager -from core import LogBroker -from core.db import BaseDatabase -from core.updator import AstrBotUpdator -from core import logger -from core.config.default import VERSION +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.message.message_event_handler import MessageEventHandler +from astrbot.core.plugin import PluginManager +from astrbot.core import LogBroker +from astrbot.core.db import BaseDatabase +from astrbot.core.updator import AstrBotUpdator +from astrbot.core import logger +from astrbot.core.config.default import VERSION class AstrBotCoreLifecycle: def __init__(self, log_broker: LogBroker, db: BaseDatabase): diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index e536598d9..f53bb3478 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from typing import List -from core.db.po import Stats, LLMHistory +from astrbot.core.db.po import Stats, LLMHistory @dataclass class BaseDatabase(abc.ABC): diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c9a4699fc..ed2f0c81e 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,7 @@ import sqlite3 import os import time -from core.db.po import ( +from astrbot.core.db.po import ( Platform, Command, Provider, diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index d6da1d467..1ab5edb1e 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -2,10 +2,10 @@ import asyncio from asyncio import Queue from collections import defaultdict from typing import List -from .message_event_handler import MessageEventHandler -from core import logger +from astrbot.core.message.message_event_handler import MessageEventHandler +from astrbot.core import logger from .platform import AstrMessageEvent -from nakuru.entities.components import Plain, Image +from astrbot.core.message.components import Image, Plain class EventBus: def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler): diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py new file mode 100644 index 000000000..b8e8afb8a --- /dev/null +++ b/astrbot/core/message/components.py @@ -0,0 +1,443 @@ +''' +MIT License + +Copyright (c) 2021 Lxns-Network + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' + +import base64 +import json +import os +import typing as T +from enum import Enum +from pydantic.v1 import BaseModel + +class ComponentType(Enum): + Plain = "Plain" + Face = "Face" + Record = "Record" + Video = "Video" + At = "At" + RPS = "RPS" # TODO + Dice = "Dice" # TODO + Shake = "Shake" # TODO + Anonymous = "Anonymous" # TODO + Share = "Share" + Contact = "Contact" # TODO + Location = "Location" # TODO + Music = "Music" + Image = "Image" + Reply = "Reply" + RedBag = "RedBag" + Poke = "Poke" + Forward = "Forward" + Node = "Node" + Xml = "Xml" + Json = "Json" + CardImage = "CardImage" + TTS = "TTS" + Unknown = "Unknown" + + +class BaseMessageComponent(BaseModel): + type: ComponentType + + def toString(self): + output = f"[CQ:{self.type.lower()}" + for k, v in self.__dict__.items(): + if k == "type" or v is None: + continue + if k == "_type": + k = "type" + if isinstance(v, bool): + v = 1 if v else 0 + output += ",%s=%s" % (k, str(v).replace("&", "&") \ + .replace(",", ",") \ + .replace("[", "[") \ + .replace("]", "]")) + output += "]" + return output + + def toDict(self): + data = dict() + for k, v in self.__dict__.items(): + if k == "type" or v is None: + continue + if k == "_type": + k = "type" + data[k] = v + return { + "type": self.type.lower(), + "data": data + } + + +class Plain(BaseMessageComponent): + type: ComponentType = "Plain" + text: str + convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息 + + def __init__(self, text: str, convert: bool = True, **_): + super().__init__(text=text, convert=convert, **_) + + def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 + if not self.convert: + return self.text + return self.text.replace("&", "&") \ + .replace("[", "[") \ + .replace("]", "]") + + +class Face(BaseMessageComponent): + type: ComponentType = "Face" + id: int + + def __init__(self, **_): + super().__init__(**_) + + +class Record(BaseMessageComponent): + type: ComponentType = "Record" + file: T.Optional[str] = "" + magic: T.Optional[bool] = False + url: T.Optional[str] = "" + cache: T.Optional[bool] = True + proxy: T.Optional[bool] = True + timeout: T.Optional[int] = 0 + # 额外 + path: T.Optional[str] + + def __init__(self, file: T.Optional[str], **_): + for k in _.keys(): + if k == "url": + pass + # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}") + super().__init__(file=file, **_) + + @staticmethod + def fromFileSystem(path, **_): + return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) + + @staticmethod + def fromURL(url: str, **_): + if url.startswith("http://") or url.startswith("https://"): + return Record(file=url, **_) + raise Exception("not a valid url") + + +class Video(BaseMessageComponent): + type: ComponentType = "Video" + file: str + cover: T.Optional[str] = "" + c: T.Optional[int] = 2 + # 额外 + path: T.Optional[str] = "" + + def __init__(self, file: str, **_): + # for k in _.keys(): + # if k == "c" and _[k] not in [2, 3]: + # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") + super().__init__(file=file, **_) + + @staticmethod + def fromFileSystem(path, **_): + return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) + + @staticmethod + def fromURL(url: str, **_): + if url.startswith("http://") or url.startswith("https://"): + return Video(file=url, **_) + raise Exception("not a valid url") + + +class At(BaseMessageComponent): + type: ComponentType = "At" + qq: T.Union[int, str] # 此处str为all时代表所有人 + name: T.Optional[str] = "" + + def __init__(self, **_): + super().__init__(**_) + + +class AtAll(At): + qq: str = "all" + + def __init__(self, **_): + super().__init__(**_) + + +class RPS(BaseMessageComponent): # TODO + type: ComponentType = "RPS" + + def __init__(self, **_): + super().__init__(**_) + + +class Dice(BaseMessageComponent): # TODO + type: ComponentType = "Dice" + + def __init__(self, **_): + super().__init__(**_) + + +class Shake(BaseMessageComponent): # TODO + type: ComponentType = "Shake" + + def __init__(self, **_): + super().__init__(**_) + + +class Anonymous(BaseMessageComponent): # TODO + type: ComponentType = "Anonymous" + ignore: T.Optional[bool] = False + + def __init__(self, **_): + super().__init__(**_) + + +class Share(BaseMessageComponent): + type: ComponentType = "Share" + url: str + title: str + content: T.Optional[str] = "" + image: T.Optional[str] = "" + + def __init__(self, **_): + super().__init__(**_) + + +class Contact(BaseMessageComponent): # TODO + type: ComponentType = "Contact" + _type: str # type 字段冲突 + id: T.Optional[int] = 0 + + def __init__(self, **_): + super().__init__(**_) + + +class Location(BaseMessageComponent): # TODO + type: ComponentType = "Location" + lat: float + lon: float + title: T.Optional[str] = "" + content: T.Optional[str] = "" + + def __init__(self, **_): + super().__init__(**_) + + +class Music(BaseMessageComponent): + type: ComponentType = "Music" + _type: str + id: T.Optional[int] = 0 + url: T.Optional[str] = "" + audio: T.Optional[str] = "" + title: T.Optional[str] = "" + content: T.Optional[str] = "" + image: T.Optional[str] = "" + + def __init__(self, **_): + # for k in _.keys(): + # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: + # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") + super().__init__(**_) + + +class Image(BaseMessageComponent): + type: ComponentType = "Image" + file: T.Optional[str] = "" + _type: T.Optional[str] = "" + subType: T.Optional[int] = 0 + url: T.Optional[str] = "" + cache: T.Optional[bool] = True + id: T.Optional[int] = 40000 + c: T.Optional[int] = 2 + # 额外 + path: T.Optional[str] = "" + + def __init__(self, file: T.Optional[str], **_): + # for k in _.keys(): + # if (k == "_type" and _[k] not in ["flash", "show", None]) or \ + # (k == "c" and _[k] not in [2, 3]): + # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") + super().__init__(file=file, **_) + + @staticmethod + def fromURL(url: str, **_): + if url.startswith("http://") or url.startswith("https://"): + return Image(file=url, **_) + raise Exception("not a valid url") + + @staticmethod + def fromFileSystem(path, **_): + return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) + + @staticmethod + def fromBase64(base64: str, **_): + return Image(f"base64://{base64}", **_) + + @staticmethod + def fromBytes(byte: bytes): + return Image.fromBase64(base64.b64encode(byte).decode()) + + @staticmethod + def fromIO(IO): + return Image.fromBytes(IO.read()) + + +class Reply(BaseMessageComponent): + type: ComponentType = "Reply" + id: int + text: T.Optional[str] = "" + qq: T.Optional[int] = 0 + time: T.Optional[int] = 0 + seq: T.Optional[int] = 0 + + def __init__(self, **_): + super().__init__(**_) + + +class RedBag(BaseMessageComponent): + type: ComponentType = "RedBag" + title: str + + def __init__(self, **_): + super().__init__(**_) + + +class Poke(BaseMessageComponent): + type: ComponentType = "Poke" + qq: int + + def __init__(self, **_): + super().__init__(**_) + + +class Forward(BaseMessageComponent): + type: ComponentType = "Forward" + id: str + + def __init__(self, **_): + super().__init__(**_) + + +class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送 + type: ComponentType = "Node" + id: T.Optional[int] = 0 + name: T.Optional[str] = "" + uin: T.Optional[int] = 0 + content: T.Optional[T.Union[str, list]] = "" + seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么 + time: T.Optional[int] = 0 + + def __init__(self, content: T.Union[str, list], **_): + if isinstance(content, list): + _content = "" + for chain in content: + _content += chain.toString() + content = _content + super().__init__(content=content, **_) + + def toString(self): + # logger.warn("Protocol: node doesn't support stringify") + return "" + + +class Xml(BaseMessageComponent): + type: ComponentType = "Xml" + data: str + resid: T.Optional[int] = 0 + + def __init__(self, **_): + super().__init__(**_) + + +class Json(BaseMessageComponent): + type: ComponentType = "Json" + data: T.Union[str, dict] + resid: T.Optional[int] = 0 + + def __init__(self, data, **_): + if isinstance(data, dict): + data = json.dumps(data) + super().__init__(data=data, **_) + + +class CardImage(BaseMessageComponent): + type: ComponentType = "CardImage" + file: str + cache: T.Optional[bool] = True + minwidth: T.Optional[int] = 400 + minheight: T.Optional[int] = 400 + maxwidth: T.Optional[int] = 500 + maxheight: T.Optional[int] = 500 + source: T.Optional[str] = "" + icon: T.Optional[str] = "" + + def __init__(self, **_): + super().__init__(**_) + + @staticmethod + def fromFileSystem(path, **_): + return CardImage(file=f"file:///{os.path.abspath(path)}", **_) + + +class TTS(BaseMessageComponent): + type: ComponentType = "TTS" + text: str + + def __init__(self, **_): + super().__init__(**_) + + +class Unknown(BaseMessageComponent): + type: ComponentType = "Unknown" + text: str + + def toString(self): + return "" + + +ComponentTypes = { + "plain": Plain, + "face": Face, + "record": Record, + "video": Video, + "at": At, + "rps": RPS, + "dice": Dice, + "shake": Shake, + "anonymous": Anonymous, + "share": Share, + "contact": Contact, + "location": Location, + "music": Music, + "image": Image, + "reply": Reply, + "redbag": RedBag, + "poke": Poke, + "forward": Forward, + "node": Node, + "xml": Xml, + "json": Json, + "cardimage": CardImage, + "tts": TTS, + "unknown": Unknown +} diff --git a/astrbot/core/message_event_handler.py b/astrbot/core/message/message_event_handler.py similarity index 96% rename from astrbot/core/message_event_handler.py rename to astrbot/core/message/message_event_handler.py index f1323b4df..45e2ecb1f 100644 --- a/astrbot/core/message_event_handler.py +++ b/astrbot/core/message/message_event_handler.py @@ -2,13 +2,13 @@ import asyncio, re, time import inspect import traceback from typing import List, Union -from .platform import AstrMessageEvent -from .config.astrbot_config import AstrBotConfig +from astrbot.core.platform import AstrMessageEvent +from astrbot.core.config.astrbot_config import AstrBotConfig from .message_event_result import MessageEventResult, CommandResult, MessageChain -from .plugin import PluginManager, Context, CommandMetadata -from nakuru.entities.components import * -from core import logger -from core import html_renderer +from astrbot.core.plugin import PluginManager, Context, CommandMetadata +from .components import * +from astrbot.core import logger +from astrbot.core import html_renderer class CommandTokens(): def __init__(self) -> None: diff --git a/astrbot/core/message_event_result.py b/astrbot/core/message/message_event_result.py similarity index 75% rename from astrbot/core/message_event_result.py rename to astrbot/core/message/message_event_result.py index 89136871b..a7629cbe0 100644 --- a/astrbot/core/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,11 +1,12 @@ from typing import List, Union, Optional from dataclasses import dataclass, field -from nakuru.entities.components import * +from astrbot.core.message.components import * @dataclass class MessageChain(): chain: List[BaseMessageComponent] = field(default_factory=list) use_t2i_: Optional[bool] = None # None 为跟随用户设置 + is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。 def message(self, message: str): ''' @@ -49,6 +50,15 @@ class MessageChain(): ''' self.use_t2i_ = use_t2i return self + + def is_split(self, is_split: bool): + ''' + 设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 + + 具体的效果以各适配器实现为准。 + ''' + self.is_split_ = is_split + return self @dataclass class MessageEventResult(MessageChain): diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 84803523e..8b319462a 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -2,11 +2,11 @@ import abc from dataclasses import dataclass from .astrbot_message import AstrBotMessage from .platform_metadata import PlatformMetadata -from core.message_event_result import MessageEventResult, MessageChain -from core.platform.message_type import MessageType +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.platform.message_type import MessageType from typing import List -from nakuru.entities.components import BaseMessageComponent, Plain, Image -from core.utils.metrics import Metric +from astrbot.core.message.components import BaseMessageComponent, Plain, Image +from astrbot.core.utils.metrics import Metric @dataclass class MessageSesion: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 2b3fb6519..7c8c9d5d0 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -1,7 +1,7 @@ import time from typing import List from dataclasses import dataclass -from nakuru.entities.components import BaseMessageComponent +from astrbot.core.message.components import BaseMessageComponent from .message_type import MessageType @dataclass diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index d8214e01a..1dd356e89 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -3,9 +3,9 @@ from typing import Awaitable, Any from asyncio import Queue from .platform_metadata import PlatformMetadata from .astr_message_event import AstrMessageEvent -from core.message_event_result import MessageChain +from astrbot.core.message.message_event_result import MessageChain from .astr_message_event import MessageSesion -from core.utils.metrics import Metric +from astrbot.core.utils.metrics import Metric class Platform(abc.ABC): def __init__(self, event_queue: Queue): diff --git a/astrbot/core/plugin/__init__.py b/astrbot/core/plugin/__init__.py index f054edcd3..61e1860d7 100644 --- a/astrbot/core/plugin/__init__.py +++ b/astrbot/core/plugin/__init__.py @@ -1,4 +1,4 @@ from .plugin import Plugin, RegisteredPlugin, PluginMetadata from .plugin_manager import PluginManager from .context import CommandMetadata, Context -from core.provider import Provider \ No newline at end of file +from astrbot.core.provider import Provider \ No newline at end of file diff --git a/astrbot/core/plugin/context.py b/astrbot/core/plugin/context.py index cf3d4c409..a1d341fbe 100644 --- a/astrbot/core/plugin/context.py +++ b/astrbot/core/plugin/context.py @@ -4,12 +4,12 @@ from . import RegisteredPlugin, PluginMetadata from typing import List, Dict, Awaitable, Union from dataclasses import dataclass -from core.platform import Platform -from core.db import BaseDatabase -from core.config.astrbot_config import AstrBotConfig -from core.utils.func_call import FuncCall -from core.platform.astr_message_event import MessageSesion -from core.message_event_result import MessageChain +from astrbot.core.platform import Platform +from astrbot.core.db import BaseDatabase +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.utils.func_call import FuncCall +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.message.message_event_result import MessageChain @dataclass class CommandMetadata(): @@ -67,6 +67,9 @@ class Context: # 维护了 LLM Tools 信息 llm_tools: FuncCall = FuncCall() + # 维护插件存储的数据 + plugin_data: Dict[str, Dict[str, any]] = {} + def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase): self._event_queue = event_queue self._config = config @@ -205,4 +208,10 @@ class Context: if platform.meta().name == session.platform_name: await platform.send_by_session(session, message_chain) return True - return False \ No newline at end of file + return False + + def set_data(self, plugin_name: str, key: str, value: any): + ''' + 设置插件数据。 + ''' + self.plugin_data[plugin_name][key] = value \ No newline at end of file diff --git a/astrbot/core/plugin/plugin_manager.py b/astrbot/core/plugin/plugin_manager.py index 93bbb82f9..34fdbc930 100644 --- a/astrbot/core/plugin/plugin_manager.py +++ b/astrbot/core/plugin/plugin_manager.py @@ -10,13 +10,13 @@ from asyncio import Queue from types import ModuleType from typing import List, Awaitable from pip import main as pip_main -from core.config.astrbot_config import AstrBotConfig -from core import logger +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core import logger from .context import Context from . import RegisteredPlugin, PluginMetadata from .updator import PluginUpdator -from core.db import BaseDatabase -from core.utils.io import remove_dir +from astrbot.core.db import BaseDatabase +from astrbot.core.utils.io import remove_dir class PluginManager: def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase): diff --git a/astrbot/core/plugin/updator.py b/astrbot/core/plugin/updator.py index 33032173a..3d358f834 100644 --- a/astrbot/core/plugin/updator.py +++ b/astrbot/core/plugin/updator.py @@ -1,10 +1,10 @@ import os, zipfile, shutil from ..updator import RepoZipUpdator -from core.utils.io import remove_dir, on_error +from astrbot.core.utils.io import remove_dir, on_error from ..plugin import RegisteredPlugin from typing import Union -from core import logger +from astrbot.core import logger class PluginUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index bf9be469a..d7f09bd65 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1 +1 @@ -from .provider import Provider \ No newline at end of file +from .provider import Provider, Personality \ No newline at end of file diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 1efb5d58a..f8d2289bb 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,11 +1,30 @@ -import abc +import abc, json, threading, time from collections import defaultdict from typing import List -# from core.utils.func_call import FuncCall +from astrbot.core.db import BaseDatabase +from astrbot.core import logger +from typing import TypedDict + +class Personality(TypedDict): + prompt: str + name: str class Provider(abc.ABC): - def __init__(self) -> None: + def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None: self.model_name = "unknown" + # 维护了 session_id 的上下文,不包含 system 指令 + self.session_memory = defaultdict(list) + self.curr_personality = Personality(prompt=default_personality, name="") + + if persistant_history: + # 读取历史记录 + self.db_helper = db_helper + try: + for history in db_helper.get_llm_history(): + self.session_memory[history.session_id] = json.loads(history.content) + except BaseException as e: + logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。") + def set_model(self, model_name: str): self.model_name = model_name @@ -13,12 +32,32 @@ class Provider(abc.ABC): def get_model(self): return self.model_name + async def get_human_readable_context(self, session_id: str) -> List[str]: + ''' + 获取人类可读的上下文 + + example: + ["User: 你好", "Assistant: 你好"] + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + contexts = [] + for record in self.session_memory[session_id]: + if record['role'] == "user": + contexts.append(f"User: {record['content']}") + elif record['role'] == "assistant": + contexts.append(f"Assistant: {record['content']}") + + return contexts + @abc.abstractmethod async def text_chat(self, prompt: str, session_id: str, image_urls: List[str] = None, - tool = None, + tools = None, + contexts=None, **kwargs) -> str: ''' prompt: 提示词 @@ -38,6 +77,13 @@ class Provider(abc.ABC): ''' raise NotImplementedError() + @abc.abstractmethod + async def get_embedding(self, text: str) -> List[float]: + ''' + 获取文本的嵌入 + ''' + raise NotImplementedError() + @abc.abstractmethod async def forget(self, session_id: str) -> bool: ''' diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 99a31324d..91cf1f1eb 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -1,8 +1,8 @@ import os, psutil, sys, time from .zip_updator import ReleaseInfo, RepoZipUpdator -from core import logger -from core.config.default import VERSION -from core.utils.io import download_file +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.utils.io import download_file class AstrBotUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: diff --git a/astrbot/core/utils/func_call.py b/astrbot/core/utils/func_call.py index 20e349d62..581e2cc3e 100644 --- a/astrbot/core/utils/func_call.py +++ b/astrbot/core/utils/func_call.py @@ -1,4 +1,4 @@ -from core.provider import Provider +from astrbot.core.provider import Provider from typing import Awaitable import json import textwrap diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 371bf86b4..577a1e872 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -1,7 +1,7 @@ import aiohttp import sys import logging -from core.config import VERSION +from astrbot.core.config import VERSION logger = logging.getLogger("astrbot") diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 33e0088ba..a4e0680fe 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -4,7 +4,7 @@ from io import BytesIO from . import RenderStrategy from PIL import ImageFont, Image, ImageDraw -from core.utils.io import save_temp_img +from astrbot.core.utils.io import save_temp_img class LocalRenderStrategy(RenderStrategy): diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index b305a6fcc..6d6b62e4e 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -2,8 +2,8 @@ import aiohttp import os from . import RenderStrategy -from core.config import VERSION -from core.utils.io import download_image_by_url +from astrbot.core.config import VERSION +from astrbot.core.utils.io import download_image_by_url ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 94acd82ad..f3298d3cb 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -1,6 +1,6 @@ from .network_strategy import NetworkRenderStrategy from .local_strategy import LocalRenderStrategy -from core.log import LogManager +from astrbot.core.log import LogManager logger = LogManager.GetLogger(log_name='astrbot') diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 439ad9b92..72c87507f 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -1,6 +1,6 @@ import aiohttp, os, zipfile, shutil -from core.utils.io import on_error, download_file -from core import logger +from astrbot.core.utils.io import on_error, download_file +from astrbot.core import logger class ReleaseInfo(): version: str diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py index b084d88ef..dc0be1a35 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -1,9 +1,9 @@ import asyncio from multiprocessing import Process -from core import logger -from core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core import logger +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .server import AstrBotDashboard -from core.db import BaseDatabase +from astrbot.core.db import BaseDatabase class AstrBotDashBoardLifecycle: def __init__(self, db: BaseDatabase): diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 2da21ea81..b76a9eeaa 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -1,6 +1,6 @@ from .route import Route, Response from quart import Quart, request -from core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.astrbot_config import AstrBotConfig class AuthRoute(Route): def __init__(self, config: AstrBotConfig, app: Quart) -> None: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 63cd0c17a..55f3379c1 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,10 +1,10 @@ import os, json from .route import Route, Response from quart import Quart, request -from core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE -from core.config.astrbot_config import AstrBotConfig -from core.plugin.config import update_config -from core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.plugin.config import update_config +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from dataclasses import asdict def try_cast(value: str, type_: str): diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index bda4af55f..8dabccf6f 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,8 +1,8 @@ import asyncio from quart import websocket from quart import Quart -from core.config.astrbot_config import AstrBotConfig -from core import logger, LogBroker +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core import logger, LogBroker from .route import Route, Response class LogRoute(Route): diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index e8f82222e..1a112ff2a 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,10 +1,10 @@ import threading, traceback, uuid from .route import Route, Response -from core import logger +from astrbot.core import logger from quart import Quart, request -from core.config.astrbot_config import AstrBotConfig -from core.plugin.plugin_manager import PluginManager -from core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.plugin.plugin_manager import PluginManager +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle class PluginRoute(Route): def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None: diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 5cc8160c6..5b2ff41c2 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,4 +1,4 @@ -from core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass from quart import Quart diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index cc5bd95e0..d69e474d1 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,11 +1,11 @@ import traceback, psutil, time, aiohttp from .route import Route, Response -from core import logger +from astrbot.core import logger from quart import Quart, request -from core.config.astrbot_config import AstrBotConfig -from core.core_lifecycle import AstrBotCoreLifecycle -from core.db import BaseDatabase -from core.config import VERSION +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase +from astrbot.core.config import VERSION class StatRoute(Route): def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 53ccf907d..4e8b835c2 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -1,6 +1,6 @@ from .route import Route from quart import Quart -from core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.astrbot_config import AstrBotConfig class StaticFileRoute(Route): def __init__(self, config: AstrBotConfig, app: Quart) -> None: diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index c810e1163..a4566b190 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,9 +1,9 @@ import threading, traceback from .route import Route, Response from quart import Quart, request -from core.config.astrbot_config import AstrBotConfig -from core.updator import AstrBotUpdator -from core import logger +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.updator import AstrBotUpdator +from astrbot.core import logger class UpdateRoute(Route): def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None: diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 3e36feb57..32d5713f6 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -2,22 +2,23 @@ import logging import asyncio, os from quart import Quart from quart.logging import default_handler -from core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .routes import * -from core import logger -from core.db import BaseDatabase -from core.plugin.plugin_manager import PluginManager -from core.updator import AstrBotUpdator -from core.utils.io import get_local_ip_addresses -from core.config import AstrBotConfig -from core.db import BaseDatabase +from astrbot.core import logger +from astrbot.core.db import BaseDatabase +from astrbot.core.plugin.plugin_manager import PluginManager +from astrbot.core.updator import AstrBotUpdator +from astrbot.core.utils.io import get_local_ip_addresses +from astrbot.core.config import AstrBotConfig +from astrbot.core.db import BaseDatabase class AstrBotDashboard(): def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist")) - self.app = Quart("dashboard", static_folder="dist", static_url_path="/") + logger.info(f"Dashboard data path: {self.data_path}") + self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") self.app.json.sort_keys = False logging.getLogger(self.app.name).removeHandler(default_handler) diff --git a/astrbot/main.py b/main.py similarity index 93% rename from astrbot/main.py rename to main.py index 4c6eeb4b9..182233bdb 100644 --- a/astrbot/main.py +++ b/main.py @@ -6,12 +6,12 @@ import mimetypes import aiohttp import zipfile from typing import List -from core.core_lifecycle import AstrBotCoreLifecycle -from core.db.sqlite import SQLiteDatabase -from core.config import DB_PATH -from dashboard import AstrBotDashBoardLifecycle +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.config import DB_PATH +from astrbot.dashboard import AstrBotDashBoardLifecycle -from core import logger, LogManager, LogBroker +from astrbot.core import logger, LogManager, LogBroker # add parent path to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py index 84943900c..dadcc51a7 100644 --- a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py +++ b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py @@ -1,7 +1,7 @@ -import os, traceback +import os, traceback, random, asyncio from astrbot.api import AstrMessageEvent, MessageChain, logger -from astrbot.api import Plain, Image +from astrbot.api.message_components import Plain, Image from aiocqhttp import CQHttp from astrbot.core.utils.io import file_to_base64, download_image_by_url @@ -11,7 +11,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): self.bot = bot @staticmethod - async def _parse_onebot_josn(message_chain: MessageChain): + async def _parse_onebot_json(message_chain: MessageChain): '''解析成 OneBot json 格式''' ret = [] for segment in message_chain.chain: @@ -31,8 +31,14 @@ class AiocqhttpMessageEvent(AstrMessageEvent): return ret async def send(self, message: MessageChain): - ret = await AiocqhttpMessageEvent._parse_onebot_josn(message) + ret = await AiocqhttpMessageEvent._parse_onebot_json(message) if os.environ.get('TEST_MODE', 'off') == 'on': return - await self.bot.send(self.message_obj.raw_message, ret) + + if message.is_split_: # 分条发送 + for m in ret: + await self.bot.send(self.message_obj.raw_message, [m]) + await asyncio.sleep(random.uniform(0.75, 2.5)) + else: + await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py index 1158cf727..95b7eec71 100644 --- a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py +++ b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py @@ -7,7 +7,7 @@ from aiocqhttp import CQHttp, Event from astrbot.api import Platform from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from .aiocqhttp_message_event import * -from nakuru.entities.components import * +from astrbot.api.message_components import * from astrbot.api import logger from .aiocqhttp_message_event import AiocqhttpMessageEvent from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings @@ -29,7 +29,7 @@ class AiocqhttpAdapter(Platform): ) async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): - ret = await AiocqhttpMessageEvent._parse_onebot_josn(message_chain) + ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) match session.message_type.value: case MessageType.GROUP_MESSAGE.value: if "_" in session.session_id: diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py b/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py index 48e2826e0..2b02e3455 100644 --- a/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py +++ b/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py @@ -4,7 +4,7 @@ import botpy.types import botpy.types.message from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType -from astrbot.api import Plain, Image +from astrbot.api.message_components import Plain, Image from botpy import Client from botpy.http import Route diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py b/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py index 6bcc07000..11ba4f53b 100644 --- a/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py +++ b/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py @@ -9,7 +9,7 @@ from botpy import Client from astrbot.api import Platform from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from typing import Union, List, Dict -from nakuru.entities.components import * +from astrbot.api.message_components import * from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion from .qqofficial_message_event import QQOfficialMessageEvent diff --git a/packages/astrbot_adapter_wechat/wechat_message_event.py b/packages/astrbot_adapter_wechat/wechat_message_event.py index 91cbf0180..9451e043c 100644 --- a/packages/astrbot_adapter_wechat/wechat_message_event.py +++ b/packages/astrbot_adapter_wechat/wechat_message_event.py @@ -1,7 +1,7 @@ import random, asyncio from astrbot.core.utils.io import download_image_by_url from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata -from astrbot.api import Plain, Image +from astrbot.api.message_components import Plain, Image from vchat import Core class WechatPlatformEvent(AstrMessageEvent): @@ -14,7 +14,10 @@ class WechatPlatformEvent(AstrMessageEvent): plain = "" for comp in message.chain: if isinstance(comp, Plain): - plain += comp.text + if message.is_split_: + await client.send_msg(comp.text, user_name) + else: + plain += comp.text elif isinstance(comp, Image): if comp.file and comp.file.startswith("file:///"): file_path = comp.file.replace("file:///", "") diff --git a/packages/astrbot_adapter_wechat/wechat_platform_adapter.py b/packages/astrbot_adapter_wechat/wechat_platform_adapter.py index 8d08b99b7..d743ddf45 100644 --- a/packages/astrbot_adapter_wechat/wechat_platform_adapter.py +++ b/packages/astrbot_adapter_wechat/wechat_platform_adapter.py @@ -4,7 +4,7 @@ import asyncio from astrbot.api import Platform from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata from typing import Union, List, Dict -from nakuru.entities.components import * +from astrbot.api.message_components import * from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion from .wechat_message_event import WechatPlatformEvent @@ -24,6 +24,7 @@ class WechatPlatformAdapter(Platform): def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: super().__init__(event_queue) self.config = platform_config + self.settingss = platform_settings self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' self.client_self_id = uuid.uuid4().hex[:8] @@ -51,6 +52,7 @@ class WechatPlatformAdapter(Platform): if msg.create_time < self.start_time: logger.debug(f"忽略旧消息: {msg}") return + logger.debug(f"收到消息: {msg.todict()}") if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist: logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}") return @@ -80,7 +82,11 @@ class WechatPlatformAdapter(Platform): sender = msg.chatroom_sender or msg.from_ amsg.sender = MessageMember(sender.username, sender.nickname) - amsg.message_str = msg.content.content + + if msg.content.is_at_me: + amsg.message_str = msg.content.content.split("\u2005")[1].strip() + else: + amsg.message_str = msg.content.content amsg.message_id = msg.message_id if isinstance(msg.from_, model.User): amsg.type = MessageType.FRIEND_MESSAGE @@ -91,10 +97,13 @@ class WechatPlatformAdapter(Platform): amsg.raw_message = msg - session_id = msg.from_.username + "$$" + msg.to.username - if msg.chatroom_sender is not None: - session_id += '$$' + msg.chatroom_sender.username - + if self.settingss.unique_session: + session_id = msg.from_.username + "$$" + msg.to.username + if msg.chatroom_sender is not None: + session_id += '$$' + msg.chatroom_sender.username + else: + session_id = msg.from_.username + amsg.session_id = session_id return amsg diff --git a/packages/astrbot_plugin_openai/commands.py b/packages/astrbot_plugin_openai/commands.py index 31a613eb5..885b570fe 100644 --- a/packages/astrbot_plugin_openai/commands.py +++ b/packages/astrbot_plugin_openai/commands.py @@ -1,9 +1,10 @@ from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain from . import PLUGIN_NAME -from astrbot.api import logger, Image, Plain +from astrbot.api import logger +from astrbot.api.message_components import Image, Plain from astrbot.api import personalities from astrbot.api import command_parser -from astrbot.api import Provider +from astrbot.api import Provider, Personality class OpenAIAdapterCommand: @@ -25,7 +26,7 @@ class OpenAIAdapterCommand: async def reset(self, message: AstrMessageEvent): tokens = command_parser.parse(message.message_str) if tokens.len == 1: - await self.provider.forget(message.session_id, keep_system_prompt=True) + await self.provider.forget(message.session_id) message.set_result(MessageEventResult().message("重置成功")) elif tokens.get(1) == 'p': await self.provider.forget(message.session_id) @@ -81,17 +82,13 @@ class OpenAIAdapterCommand: message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n\n*输入 /his 2 跳转到第 2 页")) def status(self, message: AstrMessageEvent): - keys_data = self.provider.get_keys_data() - ret = "OpenAI Key" + keys_data = self.provider.get_all_keys() + ret = "{} Key" for k in keys_data: - status = "🟢" if keys_data[k] else "🔴" - ret += "\n|- " + k[:8] + " " + status + ret += "\n|- " + k[:8] ret += "\n当前模型: " + self.provider.get_model() - if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]): - ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens" - message.set_result(MessageEventResult().message(ret).use_t2i(False)) async def switch(self, message: AstrMessageEvent): @@ -160,18 +157,10 @@ class OpenAIAdapterCommand: else: ps = "".join(l[1:]).strip() if ps in personalities: - self.provider.curr_personality = { - 'name': ps, - 'prompt': personalities[ps] - } - self.provider.personality_set(self.provider.curr_personality, message.session_id) + self.provider.curr_personality = Personality(name=ps, prompt=personalities[ps]) message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) else: - self.provider.curr_personality = { - 'name': '自定义人格', - 'prompt': ps - } - self.provider.personality_set(self.provider.curr_personality, message.session_id) + self.provider.curr_personality = Personality(name="自定义人格", prompt=ps) message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) async def draw(self, message: AstrMessageEvent): diff --git a/packages/astrbot_plugin_openai/main.py b/packages/astrbot_plugin_openai/main.py index 875b8ee67..eb24200b0 100644 --- a/packages/astrbot_plugin_openai/main.py +++ b/packages/astrbot_plugin_openai/main.py @@ -5,28 +5,34 @@ from .openai_adapter import ProviderOpenAIOfficial from .commands import OpenAIAdapterCommand from astrbot.api import logger from . import PLUGIN_NAME -from astrbot.api import Image, Plain, MessageChain +from astrbot.api import MessageChain +from astrbot.api.message_components import Image, Plain from openai._exceptions import * from openai.types.chat.chat_completion_message_tool_call import Function from astrbot.api import command_parser from .web_searcher import search_from_bing, fetch_website_content from astrbot.core.utils.metrics import Metric from astrbot.core.config.astrbot_config import LLMConfig +from .atri import ATRI class Main: def __init__(self, context: Context) -> None: supported_provider_names = ["openai", "ollama", "gemini", "deepseek", "zhipu"] - self.context = context + # 各 Provider 实例 self.provider_insts: List[ProviderOpenAIOfficial] = [] + # Provider 的配置 self.provider_llm_configs: List[LLMConfig] = [] + # 当前使用的 Provider self.provider = None + # 当前使用的 Provider 的配置 self.provider_config = None - llms_config = self.context.get_config().llm + atri_config = self.context.get_config().project_atri + loaded = False - for llm in llms_config: + for llm in self.context.get_config().llm: if llm.enable: if llm.name in supported_provider_names: if not llm.key or not llm.enable: @@ -36,20 +42,33 @@ class Main: self.provider_llm_configs.append(llm) loaded = True logger.info(f"已启用 LLM Provider(OpenAI API 适配器): {llm.id}({llm.name})。") - if loaded: self.command_handler = OpenAIAdapterCommand(self.context) self.command_handler.set_provider(self.provider_insts[0]) - self.context.register_listener(PLUGIN_NAME, "openai_adapter_chat", self.chat, "OpenAI Adapter LLM 调用监听器", after_commands=True) + self.context.register_listener(PLUGIN_NAME, "llm_chat_listener", self.chat, "llm_chat_listener", after_commands=True) self.provider = self.command_handler.provider self.provider_config = self.provider_llm_configs[0] - self.context.register_commands(PLUGIN_NAME, "provider", "查看当前 LLM Provider", 10, self.provider_info) self.context.register_commands(PLUGIN_NAME, "websearch", "启用/关闭网页搜索", 10, self.web_search) if self.context.get_config().llm_settings.web_search: self.add_web_search_tools() - + + # load atri + self.atri = None + if atri_config.enable: + try: + self.atri = ATRI(self.provider_llm_configs, atri_config, self.context) + self.command_handler.provider = self.atri.atri_chat_provider + except ImportError as e: + logger.error(traceback.format_exc()) + logger.error("载入 ATRI 失败。请确保使用 pip 安装了 requirements_atri.txt 下的库。") + self.atri = None + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error("载入 ATRI 失败。") + self.atri = None + def add_web_search_tools(self): self.context.register_llm_tool("web_search", [{ "type": "string", @@ -121,7 +140,10 @@ class Main: async def chat(self, event: AstrMessageEvent): if not event.is_wake_up(): return - + if self.atri: + await self.atri.chat(event) + return + # prompt 前缀 if self.provider_config.prompt_prefix: event.message_str = self.provider_config.prompt_prefix + event.message_str @@ -131,6 +153,8 @@ class Main: if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file break + + tool_use_flag = False llm_result = None try: if not self.context.llm_tools.empty(): @@ -177,7 +201,6 @@ class Main: return else: # normal chat - tool_use_flag = False # add user info to the prompt if self.context.get_config().llm_settings.identifier: user_id = event.message_obj.sender.user_id diff --git a/packages/astrbot_plugin_openai/openai_adapter.py b/packages/astrbot_plugin_openai/openai_adapter.py index 7f06fb98c..5987083c3 100644 --- a/packages/astrbot_plugin_openai/openai_adapter.py +++ b/packages/astrbot_plugin_openai/openai_adapter.py @@ -1,11 +1,8 @@ import os import asyncio -import json -import time -import tiktoken -import threading import traceback import base64 +import json from openai import AsyncOpenAI from openai.types.chat.chat_completion import ChatCompletion @@ -17,90 +14,38 @@ from astrbot.api import Provider from astrbot.core.config.astrbot_config import LLMConfig from astrbot import logger from typing import List, Dict - from dataclasses import asdict class ProviderOpenAIOfficial(Provider): - def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase) -> None: - super().__init__() + def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase, persistant_history = True) -> None: + super().__init__(db_helper, llm_config.default_personality, persistant_history) self.api_keys = [] self.chosen_api_key = None self.base_url = None self.llm_config = llm_config - self.keys_data = {} # 记录超额 - if llm_config.key: self.api_keys = llm_config.key - if llm_config.api_base: self.base_url = llm_config.api_base - if not self.api_keys: - logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") - else: - self.chosen_api_key = self.api_keys[0] - - for key in self.api_keys: - self.keys_data[key] = True + self.api_keys = llm_config.key + if llm_config.api_base: + self.base_url = llm_config.api_base + self.chosen_api_key = self.api_keys[0] self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=self.base_url ) self.set_model(llm_config.model_config.model) - if llm_config.image_generation_model_config: - self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config) - self.session_memory: Dict[str, List] = {} # 会话记忆 - self.session_memory_lock = threading.Lock() - self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小 - logger.info("正在载入分词器 cl100k_base...") - self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 - - self.DEFAULT_PERSONALITY = { - "prompt": self.llm_config.default_personality, - "name": "default" - } - self.curr_personality = self.DEFAULT_PERSONALITY - self.session_personality = {} # 记录了某个session是否已设置人格。 - # 读取历史记录 - self.db_helper = db_helper - try: - for history in db_helper.get_llm_history(): - self.session_memory_lock.acquire() - self.session_memory[history.session_id] = json.loads(history.content) - self.session_memory_lock.release() - except BaseException as e: - logger.warning(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。") - - # 定时保存历史记录 - threading.Thread(target=self.dump_history, daemon=True).start() - - def dump_history(self): - '''转储历史记录''' - time.sleep(30) - while True: - try: - for session_id, content in self.session_memory.items(): - self.db_helper.update_llm_history(session_id, json.dumps(content)) - except BaseException as e: - logger.error("保存 LLM 历史记录失败: " + str(e)) - finally: - time.sleep(10*60) - - def personality_set(self, personality: dict, session_id: str): - if not personality or not personality['prompt']: return - if session_id not in self.session_memory: - self.session_memory[session_id] = [] - self.curr_personality = personality - self.session_personality = {} # 重置 - - new_record = { - "user": { - "role": "system", - "content": personality['prompt'], - }, - 'usage_tokens': 0, # 到该条目的总 token 数 - 'single-tokens': 0 # 该条目的 token 数 - } - - self.session_memory[session_id] = [new_record] + # 各类模型的配置 + self.image_generator_model_configs = None + self.embedding_model_configs = None + if llm_config.image_generation_model_config and llm_config.image_generation_model_config.enable: + self.image_generator_model_configs: Dict = asdict( + llm_config.image_generation_model_config) + self.image_generator_model_configs.pop("enable") + if llm_config.embedding_model and llm_config.embedding_model.enable: + self.embedding_model_configs: Dict = asdict( + llm_config.embedding_model) + self.embedding_model_configs.pop("enable") async def encode_image_bs64(self, image_url: str) -> str: ''' @@ -108,29 +53,12 @@ class ProviderOpenAIOfficial(Provider): ''' if image_url.startswith("http"): image_url = await download_image_by_url(image_url) - + with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode('utf-8') return "data:image/jpeg;base64," + image_bs64 return '' - async def retrieve_context(self, session_id: str): - ''' - 根据 session_id 获取保存的 OpenAI 格式的上下文 - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - # 转换为 openai 要求的格式 - context = [] - for record in self.session_memory[session_id]: - if "user" in record and record['user']: - context.append(record['user']) - if "AI" in record and record['AI']: - context.append(record['AI']) - - return context - async def get_models(self): models = [] try: @@ -140,47 +68,6 @@ class ProviderOpenAIOfficial(Provider): self.client.base_url = bu + "/v1" models = await self.client.models.list() return models - - async def assemble_context(self, session_id: str, prompt: str, image_url: str = None): - ''' - 组装上下文,并且根据当前上下文窗口大小截断 - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - tokens_num = len(self.tokenizer.encode(prompt)) - previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens'] - - message = { - "usage_tokens": previous_total_tokens_num + tokens_num, - "single_tokens": tokens_num, - "AI": None - } - if image_url: - base_64_image = await self.encode_image_bs64(image_url) - user_content = { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": base_64_image - } - } - ] - } - else: - user_content = { - "role": "user", - "content": prompt - } - - message["user"] = user_content - self.session_memory[session_id].append(message) async def pop_record(self, session_id: str, pop_system_prompt: bool = False): ''' @@ -188,10 +75,10 @@ class ProviderOpenAIOfficial(Provider): ''' if session_id not in self.session_memory: raise Exception("会话 ID 不存在") - + if len(self.session_memory[session_id]) == 0: return None - + for i in range(len(self.session_memory[session_id])): # 检查是否是 system prompt if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": @@ -206,156 +93,111 @@ class ProviderOpenAIOfficial(Provider): record = self.session_memory[session_id].pop(i) break - # 更新之后所有记录的 usage_tokens - for i in range(len(self.session_memory[session_id])): - self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens'] - logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}。") return record - async def text_chat(self, - prompt: str, - session_id: str, - image_url=None, - tools=None, + async def assemble_context(self, contexts: List, text: str, image_urls: List[str] = None): + ''' + 组装上下文。 + ''' + if image_urls: + for image_url in image_urls: + base_64_image = await self.encode_image_bs64(image_url) + user_content = {"role": "user","content": [ + {"type": "text", "text": text}, + {"type": "image_url", "image_url": {"url": base_64_image}} + ]} + contexts.append(user_content) + else: + user_content = {"role": "user","content": text} + contexts.append(user_content) + + async def text_chat(self, + prompt: str, + session_id: str, + image_urls=None, + tools=None, + contexts=None, **kwargs - ) -> str: + ) -> str: + ''' + 调用 LLM 进行文本对话。 + + @param tools: LLM Function-calling 的工具函数 + @param contexts: 如果不为 None,则会原封不动地使用这个上下文进行对话。 + ''' if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on": return "这是一个测试消息。" - if not session_id: - session_id = "unknown" - if "unknown" in self.session_memory: - del self.session_memory["unknown"] - - if session_id not in self.session_memory: - self.session_memory[session_id] = [] - - if session_id not in self.session_personality or not self.session_personality[session_id]: - self.personality_set(self.curr_personality, session_id) - self.session_personality[session_id] = True - - # 组装上下文,并且根据当前上下文窗口大小截断 - await self.assemble_context(session_id, prompt, image_url) - - # 获取上下文,openai 格式 - contexts = await self.retrieve_context(session_id) - - logger.debug(f"OpenAI 请求上下文:{contexts}") - + + await self.assemble_context(self.session_memory[session_id], prompt, image_urls) + if not contexts: + contexts = [*self.session_memory[session_id]] + if self.curr_personality["prompt"]: + contexts.insert(0, {"role": "system", "content": self.curr_personality["prompt"]}) + + + logger.debug(f"请求上下文:{contexts}") conf = asdict(self.llm_config.model_config) + if tools: + conf['tools'] = tools # start request retry = 0 - rate_limit_retry = 0 - while retry < 3 or rate_limit_retry < 5: - if tools: - completion_coro = self.client.chat.completions.create( - messages=contexts, - stream=False, - tools=tools, - **conf - ) - else: - completion_coro = self.client.chat.completions.create( - messages=contexts, - stream=False, - **conf - ) + while retry < 3: + completion_coro = self.client.chat.completions.create( + messages=contexts, + stream=False, + **conf + ) try: completion = await completion_coro break - except AuthenticationError as e: - api_key = self.chosen_api_key[10:] + "..." - logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)") - self.keys_data[self.chosen_api_key] = False - ok = await self.switch_to_next_key() - if ok: continue - else: raise Exception("所有 OpenAI API Key 目前都不可用。") - except RateLimitError as e: - if "You exceeded your current quota" in str(e): - self.keys_data[self.chosen_api_key] = False - ok = await self.switch_to_next_key() - if ok: continue - else: raise Exception("所有 OpenAI API Key 目前都不可用。") - logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}") - await self.switch_to_next_key() - rate_limit_retry += 1 - await asyncio.sleep(1) - except BadRequestError as e: - raise e - except NotFoundError as e: - raise e except Exception as e: retry += 1 if retry >= 3: logger.error(traceback.format_exc()) - raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。") + raise Exception(f"请求失败:{e}。重试次数已达到上限。") if "maximum context length" in str(e): - logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") self.pop_record(session_id) - + logger.warning(traceback.format_exc()) - logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") + logger.warning(f"请求失败:{e}。重试第 {retry} 次。") await asyncio.sleep(1) assert isinstance(completion, ChatCompletion) - logger.debug(f"openai completion: {completion.usage}") - + logger.debug(f"completion: {completion.usage}") + if len(completion.choices) == 0: - raise Exception("OpenAI API 返回的 completion 为空。") + raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - - usage_tokens = completion.usage.total_tokens - completion_tokens = completion.usage.completion_tokens - self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens - self.session_memory[session_id][-1]['single_tokens'] += completion_tokens - + if choice.message.content: # 返回文本 completion_text = str(choice.message.content).strip() + self.session_memory[session_id].append({ + "role": "assistant", + "content": completion_text + }) + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id])) + return completion_text elif choice.message.tool_calls and choice.message.tool_calls: # tools call (function calling) return choice.message.tool_calls[0].function - - self.session_memory[session_id][-1]['AI'] = { - "role": "assistant", - "content": completion_text - } + else: + raise Exception("Internal Error") - return completion_text - - async def switch_to_next_key(self): - ''' - 切换到下一个 API Key - ''' - if not self.api_keys: - logger.error("OpenAI API Key 不存在。") - return False - - for key in self.keys_data: - if self.keys_data[key]: - # 没超额 - self.chosen_api_key = key - self.client.api_key = key - logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。") - return True - - return False - async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str: ''' 生成图片 ''' retry = 0 - conf = self.image_generator_model_configs - if not conf: - logger.error("图片生成模型配置不存在。") - raise Exception("图片生成模型配置不存在。") - conf.pop("enable") + if not self.image_generator_model_configs: + return while retry < 3: try: images_response = await self.client.images.generate( prompt=prompt, - **conf + **self.image_generator_model_configs ) image_url = images_response.data[0].url return image_url @@ -367,15 +209,25 @@ class ProviderOpenAIOfficial(Provider): logger.warning(f"图片生成请求失败:{e}。重试第 {retry} 次。") await asyncio.sleep(1) - async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: - if session_id is None: return False + async def get_embedding(self, text) -> List[float]: + ''' + 获取文本的嵌入 + ''' + if not self.embedding_model_configs: + return + try: + embedding = await self.client.embeddings.create( + input=text, + **self.embedding_model_configs + ) + return embedding.data[0].embedding + except Exception as e: + logger.error(f"获取文本嵌入失败:{e}") + + async def forget(self, session_id: str) -> bool: self.session_memory[session_id] = [] - if keep_system_prompt: - self.personality_set(self.curr_personality, session_id) - else: - self.curr_personality = self.DEFAULT_PERSONALITY return True - + def dump_contexts_page(self, session_id: str, size=5, page=1,): ''' 获取缓存的会话 @@ -383,25 +235,21 @@ class ProviderOpenAIOfficial(Provider): contexts_str = "" if session_id in self.session_memory: for record in self.session_memory[session_id]: - if "user" in record and record['user']: - text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content'] + if record['role'] == "user": + text = record['content'][:100] + "..." if len( + record['content']) > 100 else record['content'] contexts_str += f"User: {text}\n\n" - if "AI" in record and record['AI']: - text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content'] + elif record['role'] == "assistant": + text = record['content'][:100] + "..." if len( + record['content']) > 100 else record['content'] contexts_str += f"Assistant: {text}\n\n" else: contexts_str = "会话 ID 不存在。" return contexts_str, len(self.session_memory[session_id]) - - def get_configs(self): - return asdict(self.llm_config) - - def get_keys_data(self): - return self.keys_data def get_curr_key(self): return self.chosen_api_key - def set_key(self, key): - self.client.api_key = key \ No newline at end of file + def get_all_keys(self): + return self.api_keys \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2703d5140..e301c3904 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,12 @@ -pydantic~=1.10.4 +pydantic vchat aiohttp openai qq-botpy chardet~=5.1.0 Pillow -nakuru-project beautifulsoup4 googlesearch-python -tiktoken readability-lxml quart psutil