Compare commits

...

5 Commits

Author SHA1 Message Date
Soulter 853ac4c104 fix: 优化 update 提示 2024-07-27 04:58:15 -04:00
Soulter ed053acad6 update: version 2024-07-27 04:47:57 -04:00
Soulter f147634e51 fix: 修复update异常 2024-07-27 04:43:53 -04:00
Soulter e3b2a68341 Merge pull request #179 from Soulter/refactor-v3.3.0
feat: 新增 Provider 注册接口;新增 provider 指令
2024-07-27 16:31:03 +08:00
Soulter 84c450aef9 feat: 新增 Provider 注册接口;新增 provider 指令 2024-07-27 04:25:27 -04:00
14 changed files with 76 additions and 21 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
data/* data/*
cookies.json cookies.json
logs/ logs/
addons/plugins
+1
View File
@@ -101,6 +101,7 @@ class AstrBotBootstrap():
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context) self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance) self.openai_command_handler.set_provider(self.llm_instance)
self.context.register_provider("internal_openai", self.llm_instance)
logger.info("已启用 OpenAI API 支持。") logger.info("已启用 OpenAI API 支持。")
def load_plugins(self): def load_plugins(self):
+8 -2
View File
@@ -115,6 +115,9 @@ class MessageHandler():
self.nicks = self.context.nick self.nicks = self.context.nick
self.provider = provider self.provider = provider
self.reply_prefix = self.context.reply_prefix self.reply_prefix = self.context.reply_prefix
def set_provider(self, provider: Provider):
self.provider = provider
async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult: async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult:
''' '''
@@ -148,7 +151,8 @@ class MessageHandler():
assert(isinstance(cmd_res, CommandResult)) assert(isinstance(cmd_res, CommandResult))
return MessageResult( return MessageResult(
cmd_res.message_chain, cmd_res.message_chain,
is_command_call=True is_command_call=True,
use_t2i=cmd_res.is_use_t2i
) )
# check if the message is a llm-wake-up command # check if the message is a llm-wake-up command
@@ -178,7 +182,9 @@ class MessageHandler():
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
else: else:
llm_result = await provider.text_chat( llm_result = await provider.text_chat(
msg_plain, message.session_id, image_url prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
) )
except BaseException as e: except BaseException as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
+1 -1
View File
@@ -285,7 +285,7 @@ class AstrBotDashBoard():
ret = self.astrbot_updator.check_update(None, None) ret = self.astrbot_updator.check_update(None, None)
return Response( return Response(
status="success", status="success",
message=str(ret), message=str(ret) if ret is not None else "已经是最新版本了。",
data={ data={
"has_new_version": ret is not None "has_new_version": ret is not None
} }
+30 -1
View File
@@ -26,7 +26,36 @@ class InternalCommandHandler:
self.manager.register("websearch", "网页搜索开关", 10, self.web_search) self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider)
def provider(self, message: AstrMessageEvent, context: Context):
if len(context.llms) == 0:
return CommandResult().message("当前没有加载任何 LLM 资源。")
tokens = self.manager.command_parser.parse(message.message_str)
if tokens.len == 1:
ret = "## 当前载入的 LLM 资源\n"
for idx, llm in enumerate(context.llms):
ret += f"{idx}. {llm.llm_name}"
if llm.origin:
ret += f" (来源: {llm.origin})"
if context.message_handler.provider == llm.llm_instance:
ret += " (当前使用)"
ret += "\n"
ret += "\n使用 provider <序号> 切换 LLM 资源。"
return CommandResult().message(ret)
else:
try:
idx = int(tokens.get(1))
if idx >= len(context.llms):
return CommandResult().message("provider: 无效的序号。")
context.message_handler.set_provider(context.llms[idx].llm_instance)
return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}")
except BaseException as e:
return CommandResult().message("provider: 参数错误。")
def set_nick(self, message: AstrMessageEvent, context: Context): def set_nick(self, message: AstrMessageEvent, context: Context):
message_str = message.message_str message_str = message.message_str
if message.role != "admin": if message.role != "admin":
+2 -1
View File
@@ -73,7 +73,8 @@ class CommandManager():
if message_str.startswith(command): if message_str.startswith(command):
logger.info(f"触发 {command} 指令。") logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context) command_result = await self.execute_handler(command, message_event, context)
return command_result if command_result.hit:
return command_result
async def execute_handler(self, async def execute_handler(self,
command: str, command: str,
+10
View File
@@ -24,6 +24,7 @@ class CommandResult():
self.success = success self.success = success
self.message_chain = message_chain self.message_chain = message_chain
self.command_name = command_name self.command_name = command_name
self.is_use_t2i = None # default
def message(self, message: str): def message(self, message: str):
''' '''
@@ -61,6 +62,15 @@ class CommandResult():
''' '''
self.message_chain = [Image.fromFileSystem(path), ] self.message_chain = [Image.fromFileSystem(path), ]
return self return self
# def use_t2i(self, use_t2i: bool):
# '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
# CommandResult().use_t2i(False)
# '''
# self.is_use_t2i = use_t2i
# return self
def _result_tuple(self): def _result_tuple(self):
return (self.success, self.message_chain, self.command_name) return (self.success, self.message_chain, self.command_name)
+1 -1
View File
@@ -1 +1 @@
VERSION = '3.3.0' VERSION = '3.3.2'
+2 -5
View File
@@ -43,13 +43,10 @@ class AstrMessageEvent():
context, context,
session_id) session_id)
return ame return ame
@dataclass @dataclass
class MessageResult(): class MessageResult():
result_message: Union[str, list] result_message: Union[str, list]
is_command_call: Optional[bool] = False is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None callback: Optional[callable] = None
+9
View File
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator from util.updator.plugin_updator import PluginUpdator
from model.plugin.command import PluginCommandBridge from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
class Context: class Context:
@@ -68,6 +69,14 @@ class Context:
''' '''
task = asyncio.create_task(coro, name=task_name) task = asyncio.create_task(coro, name=task_name)
self.ext_tasks.append(task) self.ext_tasks.append(task)
def register_provider(self, llm_name: str, provider: Provider, origin: str = ''):
'''
注册一个提供 LLM 资源的 Provider。
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
def find_platform(self, platform_name: str) -> RegisteredPlatform: def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms: for platform in self.platforms:
+2 -2
View File
@@ -122,7 +122,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 3: while _c < 3:
try: try:
res = self.provider.text_chat(prompt, session_id) res = self.provider.text_chat(prompt=prompt, session_id=session_id)
if res.find('```') != -1: if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')] res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result", gu.log("REVGPT func_call json result",
@@ -187,7 +187,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 5: while _c < 5:
try: try:
res = self.provider.text_chat(after_prompt, session_id) res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
# 截取```之间的内容 # 截取```之间的内容
gu.log( gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
+5 -5
View File
@@ -127,7 +127,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = "" function_invoked_ret = ""
if official_fc: if official_fc:
# we use official function-calling # we use official function-calling
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func()) result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
if isinstance(result, Function): if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}") logger.debug(f"web_searcher - function-calling: {result}")
func_obj = None func_obj = None
@@ -136,14 +136,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
func_obj = i["func_obj"] func_obj = i["func_obj"]
break break
if not func_obj: if not func_obj:
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
try: try:
args = json.loads(result.arguments) args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args) function_invoked_ret = await func_obj(**args)
has_func = True has_func = True
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
else: else:
return result return result
else: else:
@@ -162,7 +162,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
has_func = True has_func = True
if has_func: if has_func:
await provider.forget(session_id) await provider.forget(session_id=session_id, )
summary_prompt = f""" summary_prompt = f"""
你是一个专业且高效的助手,你的任务是 你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; 1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
@@ -178,6 +178,6 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
# 相关材料 # 相关材料
{function_invoked_ret}""" {function_invoked_ret}"""
ret = await provider.text_chat(summary_prompt, session_id) ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
return ret return ret
return function_invoked_ret return function_invoked_ret
+1 -1
View File
@@ -32,7 +32,7 @@ class NetworkRenderStrategy(RenderStrategy):
"options": { "options": {
"full_page": True, "full_page": True,
"type": "jpeg", "type": "jpeg",
"quality": 25, "quality": 40,
} }
} }
+3 -2
View File
@@ -3,7 +3,7 @@ from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from type.config import VERSION from type.config import VERSION
from util.io import on_error from util.io import on_error, download_file
logger: Logger = LogManager.GetLogger(log_name='astrbot') logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -58,7 +58,8 @@ class AstrBotUpdator(RepoZipUpdator):
raise Exception(f"未找到版本号为 {version} 的更新文件。") raise Exception(f"未找到版本号为 {version} 的更新文件。")
try: try:
self.download_from_repo_url("temp", data['zipball_url']) # self.download_from_repo_url("temp", file_url)
download_file(file_url, "temp.zip")
self.unzip_file("temp.zip", self.MAIN_PATH) self.unzip_file("temp.zip", self.MAIN_PATH)
except BaseException as e: except BaseException as e:
raise e raise e