diff --git a/astrbot/core.py b/astrbot/core.py index 229d55e91..e5a64adcd 100644 --- a/astrbot/core.py +++ b/astrbot/core.py @@ -51,16 +51,13 @@ llm_wake_prefix = "" # 百度内容审核实例 baidu_judge = None -# CLI -PLATFORM_CLI = 'cli' - # 全局对象 _global_object: GlobalObject = None def privider_chooser(cfg): l = [] - if 'openai' in cfg and len(cfg['openai']['key']) > 0 and cfg['openai']['key'][0] is not None: + if 'openai' in cfg and len(cfg['openai']['key']) and cfg['openai']['key'][0]: l.append('openai_official') return l @@ -157,7 +154,7 @@ def init(): logger.info("独立会话配置错误: "+str(e)) nick_qq = cc.get("nick_qq", None) - if nick_qq == None: + if not nick_qq: nick_qq = ("ai", "!", "!") if isinstance(nick_qq, str): nick_qq = (nick_qq,) @@ -177,28 +174,24 @@ def init(): logger.info( f"成功载入 {len(_global_object.cached_plugins)} 个插件") else: - logger.info(err) + logger.error(err) if chosen_provider is None: llm_command_instance[NONE_LLM] = _command chosen_provider = NONE_LLM logger.info("正在载入机器人消息平台") - # logger.info("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)") - platform_str = "" # GOCQ if 'gocqbot' in cfg and cfg['gocqbot']['enable']: logger.info("启用 QQ_GOCQ 机器人消息平台") threading.Thread(target=run_gocq_bot, args=( cfg, _global_object), daemon=True).start() - platform_str += "QQ_GOCQ," # QQ频道 if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None: logger.info("启用 QQ_OFFICIAL 机器人消息平台") threading.Thread(target=run_qqchan_bot, args=( cfg, _global_object), daemon=True).start() - platform_str += "QQ_OFFICIAL," # 初始化dashboard _global_object.dashboard_data = DashBoardData( @@ -219,19 +212,15 @@ def init(): logger.info( "如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。") logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。") - if platform_str == '': - platform_str = "(未启动任何平台,请前往面板添加)" logger.info(f"🎉 项目启动完成") dashboard_thread.join() -''' -运行 QQ_OFFICIAL 机器人 -''' - - def run_qqchan_bot(cfg: dict, global_object: GlobalObject): + ''' + 运行 QQ_OFFICIAL 机器人 + ''' try: from model.platform.qq_official import QQOfficial qqchannel_bot = QQOfficial( @@ -244,14 +233,11 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject): logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。") -''' -运行 QQ_GOCQ 机器人 -''' - - def run_gocq_bot(cfg: dict, _global_object: GlobalObject): + ''' + 运行 QQ_GOCQ 机器人 + ''' from model.platform.qq_gocq import QQGOCQ - noticed = False host = cc.get("gocq_host", "127.0.0.1") port = cc.get("gocq_websocket_port", 6700) @@ -278,12 +264,10 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject): input("启动QQ机器人出现错误"+str(e)) -''' -检查发言频率 -''' - - def check_frequency(id) -> bool: + ''' + 检查发言频率 + ''' ts = int(time.time()) if id in user_frequency: if ts-user_frequency[id]['time'] > frequency_time: @@ -324,11 +308,10 @@ async def oper_msg(message: AstrBotMessage, platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。 """ global chosen_provider, _global_object - message_str = '' - session_id = session_id - role = role + message_str = message.message_str hit = False # 是否命中指令 command_result = () # 调用指令返回的结果 + llm_result_str = "" # 获取平台实例 reg_platform: RegisteredPlatform = None @@ -342,35 +325,13 @@ async def oper_msg(message: AstrBotMessage, # 统计数据,如频道消息量 await record_message(platform, session_id) - for i in message.message: - if isinstance(i, Plain): - message_str += i.text.strip() - if message_str == "": + if not message_str: return MessageResult("Hi~") # 检查发言频率 if not check_frequency(message.sender.user_id): return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。') - # 检查是否是更换语言模型的请求 - temp_switch = "" - if message_str.startswith('/gpt'): - target = chosen_provider - if message_str.startswith('/gpt'): - target = OPENAI_OFFICIAL - l = message_str.split(' ') - if len(l) > 1 and l[1] != "": - # 临时对话模式,先记录下之前的语言模型,回答完毕后再切回 - temp_switch = chosen_provider - chosen_provider = target - message_str = l[1] - else: - chosen_provider = target - cc.put("chosen_provider", chosen_provider) - return MessageResult(f"已切换至【{chosen_provider}】") - - llm_result_str = "" - # check commands and plugins message_str_no_wake_prefix = message_str for wake_prefix in _global_object.nick: # nick: tuple @@ -400,7 +361,7 @@ async def oper_msg(message: AstrBotMessage, logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。") return try: - if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix): + if llm_wake_prefix and not message_str.startswith(llm_wake_prefix): return # check image url image_url = None @@ -418,7 +379,7 @@ async def oper_msg(message: AstrBotMessage, message_str = message_str[3:] web_sch_flag = True else: - message_str += " " + cc.get("llm_env_prompt", "") + message_str += "\n" + cc.get("llm_env_prompt", "") if chosen_provider == OPENAI_OFFICIAL: if _global_object.web_search or web_sch_flag: official_fc = chosen_provider == OPENAI_OFFICIAL @@ -431,32 +392,15 @@ async def oper_msg(message: AstrBotMessage, logger.error(f"调用异常:{traceback.format_exc()}") return MessageResult(f"调用异常。详细原因:{str(e)}") - # 切换回原来的语言模型 - if temp_switch != "": - chosen_provider = temp_switch - if hit: # 有指令或者插件触发 # command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) - if command_result == None: + if not command_result: return - command = command_result[2] - if not command_result[0]: return MessageResult(f"指令调用错误: \n{str(command_result[1])}") - - # 画图指令 - if command == 'draw': - # 保存到本地 - path = await gu.download_image_by_url(command_result[1]) - return MessageResult([Image.fromFileSystem(path)]) - # 其他指令 - else: - try: - return MessageResult(command_result[1]) - except BaseException as e: - return MessageResult(f"回复消息出错: {str(e)}") - return + if isinstance(command_result[1], (list, str)): + return MessageResult(command_result[1]) # 敏感过滤 # 过滤不合适的词 @@ -468,7 +412,4 @@ async def oper_msg(message: AstrBotMessage, if not check: return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}") # 发送信息 - try: - return MessageResult(llm_result_str) - except BaseException as e: - logger.info("回复消息错误: \n"+str(e)) + return MessageResult(llm_result_str) diff --git a/model/command/openai_official.py b/model/command/openai_official.py index bafac6806..34f3f18b5 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,11 +1,13 @@ from model.command.command import Command from model.provider.openai_official import ProviderOpenAIOfficial, MODELS from util.personality import personalities +from util.general_utils import download_image_by_url from type.types import GlobalObject from type.command import CommandItem from SparkleLogging.utils.core import LogManager from logging import Logger from openai._exceptions import NotFoundError +from nakuru.entities.components import Image logger: Logger = LogManager.GetLogger(log_name='astrbot-core') @@ -248,4 +250,6 @@ class CommandOpenAIOfficial(Command): return False, "未启用 OpenAI 官方 API", "draw" message = message.removeprefix("/").removeprefix("画") img_url = await self.provider.image_generate(message) - return True, img_url, "draw" \ No newline at end of file + p = await download_image_by_url(url=img_url) + with open(p, 'rb') as f: + return True, [Image.fromBytes(f.read())], "draw" \ No newline at end of file