chore: clean codes
This commit is contained in:
+21
-80
@@ -51,16 +51,13 @@ llm_wake_prefix = ""
|
||||
# 百度内容审核实例
|
||||
baidu_judge = None
|
||||
|
||||
# CLI
|
||||
PLATFORM_CLI = 'cli'
|
||||
|
||||
# 全局对象
|
||||
_global_object: GlobalObject = None
|
||||
|
||||
|
||||
def privider_chooser(cfg):
|
||||
l = []
|
||||
if 'openai' in cfg and len(cfg['openai']['key']) > 0 and cfg['openai']['key'][0] is not None:
|
||||
if 'openai' in cfg and len(cfg['openai']['key']) and cfg['openai']['key'][0]:
|
||||
l.append('openai_official')
|
||||
return l
|
||||
|
||||
@@ -157,7 +154,7 @@ def init():
|
||||
logger.info("独立会话配置错误: "+str(e))
|
||||
|
||||
nick_qq = cc.get("nick_qq", None)
|
||||
if nick_qq == None:
|
||||
if not nick_qq:
|
||||
nick_qq = ("ai", "!", "!")
|
||||
if isinstance(nick_qq, str):
|
||||
nick_qq = (nick_qq,)
|
||||
@@ -177,28 +174,24 @@ def init():
|
||||
logger.info(
|
||||
f"成功载入 {len(_global_object.cached_plugins)} 个插件")
|
||||
else:
|
||||
logger.info(err)
|
||||
logger.error(err)
|
||||
|
||||
if chosen_provider is None:
|
||||
llm_command_instance[NONE_LLM] = _command
|
||||
chosen_provider = NONE_LLM
|
||||
|
||||
logger.info("正在载入机器人消息平台")
|
||||
# logger.info("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)")
|
||||
platform_str = ""
|
||||
# GOCQ
|
||||
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
|
||||
logger.info("启用 QQ_GOCQ 机器人消息平台")
|
||||
threading.Thread(target=run_gocq_bot, args=(
|
||||
cfg, _global_object), daemon=True).start()
|
||||
platform_str += "QQ_GOCQ,"
|
||||
|
||||
# QQ频道
|
||||
if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None:
|
||||
logger.info("启用 QQ_OFFICIAL 机器人消息平台")
|
||||
threading.Thread(target=run_qqchan_bot, args=(
|
||||
cfg, _global_object), daemon=True).start()
|
||||
platform_str += "QQ_OFFICIAL,"
|
||||
|
||||
# 初始化dashboard
|
||||
_global_object.dashboard_data = DashBoardData(
|
||||
@@ -219,19 +212,15 @@ def init():
|
||||
logger.info(
|
||||
"如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。")
|
||||
logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。")
|
||||
if platform_str == '':
|
||||
platform_str = "(未启动任何平台,请前往面板添加)"
|
||||
logger.info(f"🎉 项目启动完成")
|
||||
|
||||
dashboard_thread.join()
|
||||
|
||||
|
||||
'''
|
||||
运行 QQ_OFFICIAL 机器人
|
||||
'''
|
||||
|
||||
|
||||
def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
|
||||
'''
|
||||
运行 QQ_OFFICIAL 机器人
|
||||
'''
|
||||
try:
|
||||
from model.platform.qq_official import QQOfficial
|
||||
qqchannel_bot = QQOfficial(
|
||||
@@ -244,14 +233,11 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
|
||||
logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。")
|
||||
|
||||
|
||||
'''
|
||||
运行 QQ_GOCQ 机器人
|
||||
'''
|
||||
|
||||
|
||||
def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
|
||||
'''
|
||||
运行 QQ_GOCQ 机器人
|
||||
'''
|
||||
from model.platform.qq_gocq import QQGOCQ
|
||||
|
||||
noticed = False
|
||||
host = cc.get("gocq_host", "127.0.0.1")
|
||||
port = cc.get("gocq_websocket_port", 6700)
|
||||
@@ -278,12 +264,10 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
|
||||
input("启动QQ机器人出现错误"+str(e))
|
||||
|
||||
|
||||
'''
|
||||
检查发言频率
|
||||
'''
|
||||
|
||||
|
||||
def check_frequency(id) -> bool:
|
||||
'''
|
||||
检查发言频率
|
||||
'''
|
||||
ts = int(time.time())
|
||||
if id in user_frequency:
|
||||
if ts-user_frequency[id]['time'] > frequency_time:
|
||||
@@ -324,11 +308,10 @@ async def oper_msg(message: AstrBotMessage,
|
||||
platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。
|
||||
"""
|
||||
global chosen_provider, _global_object
|
||||
message_str = ''
|
||||
session_id = session_id
|
||||
role = role
|
||||
message_str = message.message_str
|
||||
hit = False # 是否命中指令
|
||||
command_result = () # 调用指令返回的结果
|
||||
llm_result_str = ""
|
||||
|
||||
# 获取平台实例
|
||||
reg_platform: RegisteredPlatform = None
|
||||
@@ -342,35 +325,13 @@ async def oper_msg(message: AstrBotMessage,
|
||||
# 统计数据,如频道消息量
|
||||
await record_message(platform, session_id)
|
||||
|
||||
for i in message.message:
|
||||
if isinstance(i, Plain):
|
||||
message_str += i.text.strip()
|
||||
if message_str == "":
|
||||
if not message_str:
|
||||
return MessageResult("Hi~")
|
||||
|
||||
# 检查发言频率
|
||||
if not check_frequency(message.sender.user_id):
|
||||
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
|
||||
|
||||
# 检查是否是更换语言模型的请求
|
||||
temp_switch = ""
|
||||
if message_str.startswith('/gpt'):
|
||||
target = chosen_provider
|
||||
if message_str.startswith('/gpt'):
|
||||
target = OPENAI_OFFICIAL
|
||||
l = message_str.split(' ')
|
||||
if len(l) > 1 and l[1] != "":
|
||||
# 临时对话模式,先记录下之前的语言模型,回答完毕后再切回
|
||||
temp_switch = chosen_provider
|
||||
chosen_provider = target
|
||||
message_str = l[1]
|
||||
else:
|
||||
chosen_provider = target
|
||||
cc.put("chosen_provider", chosen_provider)
|
||||
return MessageResult(f"已切换至【{chosen_provider}】")
|
||||
|
||||
llm_result_str = ""
|
||||
|
||||
# check commands and plugins
|
||||
message_str_no_wake_prefix = message_str
|
||||
for wake_prefix in _global_object.nick: # nick: tuple
|
||||
@@ -400,7 +361,7 @@ async def oper_msg(message: AstrBotMessage,
|
||||
logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。")
|
||||
return
|
||||
try:
|
||||
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
|
||||
if llm_wake_prefix and not message_str.startswith(llm_wake_prefix):
|
||||
return
|
||||
# check image url
|
||||
image_url = None
|
||||
@@ -418,7 +379,7 @@ async def oper_msg(message: AstrBotMessage,
|
||||
message_str = message_str[3:]
|
||||
web_sch_flag = True
|
||||
else:
|
||||
message_str += " " + cc.get("llm_env_prompt", "")
|
||||
message_str += "\n" + cc.get("llm_env_prompt", "")
|
||||
if chosen_provider == OPENAI_OFFICIAL:
|
||||
if _global_object.web_search or web_sch_flag:
|
||||
official_fc = chosen_provider == OPENAI_OFFICIAL
|
||||
@@ -431,32 +392,15 @@ async def oper_msg(message: AstrBotMessage,
|
||||
logger.error(f"调用异常:{traceback.format_exc()}")
|
||||
return MessageResult(f"调用异常。详细原因:{str(e)}")
|
||||
|
||||
# 切换回原来的语言模型
|
||||
if temp_switch != "":
|
||||
chosen_provider = temp_switch
|
||||
|
||||
if hit:
|
||||
# 有指令或者插件触发
|
||||
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
|
||||
if command_result == None:
|
||||
if not command_result:
|
||||
return
|
||||
command = command_result[2]
|
||||
|
||||
if not command_result[0]:
|
||||
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
|
||||
|
||||
# 画图指令
|
||||
if command == 'draw':
|
||||
# 保存到本地
|
||||
path = await gu.download_image_by_url(command_result[1])
|
||||
return MessageResult([Image.fromFileSystem(path)])
|
||||
# 其他指令
|
||||
else:
|
||||
try:
|
||||
return MessageResult(command_result[1])
|
||||
except BaseException as e:
|
||||
return MessageResult(f"回复消息出错: {str(e)}")
|
||||
return
|
||||
if isinstance(command_result[1], (list, str)):
|
||||
return MessageResult(command_result[1])
|
||||
|
||||
# 敏感过滤
|
||||
# 过滤不合适的词
|
||||
@@ -468,7 +412,4 @@ async def oper_msg(message: AstrBotMessage,
|
||||
if not check:
|
||||
return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}")
|
||||
# 发送信息
|
||||
try:
|
||||
return MessageResult(llm_result_str)
|
||||
except BaseException as e:
|
||||
logger.info("回复消息错误: \n"+str(e))
|
||||
return MessageResult(llm_result_str)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from model.command.command import Command
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
from util.personality import personalities
|
||||
from util.general_utils import download_image_by_url
|
||||
from type.types import GlobalObject
|
||||
from type.command import CommandItem
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from openai._exceptions import NotFoundError
|
||||
from nakuru.entities.components import Image
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
@@ -248,4 +250,6 @@ class CommandOpenAIOfficial(Command):
|
||||
return False, "未启用 OpenAI 官方 API", "draw"
|
||||
message = message.removeprefix("/").removeprefix("画")
|
||||
img_url = await self.provider.image_generate(message)
|
||||
return True, img_url, "draw"
|
||||
p = await download_image_by_url(url=img_url)
|
||||
with open(p, 'rb') as f:
|
||||
return True, [Image.fromBytes(f.read())], "draw"
|
||||
Reference in New Issue
Block a user