Compare commits

..

5 Commits

Author SHA1 Message Date
Soulter 853ac4c104 fix: 优化 update 提示 2024-07-27 04:58:15 -04:00
Soulter ed053acad6 update: version 2024-07-27 04:47:57 -04:00
Soulter f147634e51 fix: 修复update异常 2024-07-27 04:43:53 -04:00
Soulter e3b2a68341 Merge pull request #179 from Soulter/refactor-v3.3.0
feat: 新增 Provider 注册接口;新增 provider 指令
2024-07-27 16:31:03 +08:00
Soulter 84c450aef9 feat: 新增 Provider 注册接口;新增 provider 指令 2024-07-27 04:25:27 -04:00
14 changed files with 76 additions and 21 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
data/*
cookies.json
logs/
addons/plugins
+1
View File
@@ -101,6 +101,7 @@ class AstrBotBootstrap():
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance)
self.context.register_provider("internal_openai", self.llm_instance)
logger.info("已启用 OpenAI API 支持。")
def load_plugins(self):
+8 -2
View File
@@ -115,6 +115,9 @@ class MessageHandler():
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = self.context.reply_prefix
def set_provider(self, provider: Provider):
self.provider = provider
async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult:
'''
@@ -148,7 +151,8 @@ class MessageHandler():
assert(isinstance(cmd_res, CommandResult))
return MessageResult(
cmd_res.message_chain,
is_command_call=True
is_command_call=True,
use_t2i=cmd_res.is_use_t2i
)
# check if the message is a llm-wake-up command
@@ -178,7 +182,9 @@ class MessageHandler():
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
else:
llm_result = await provider.text_chat(
msg_plain, message.session_id, image_url
prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
)
except BaseException as e:
logger.error(traceback.format_exc())
+1 -1
View File
@@ -285,7 +285,7 @@ class AstrBotDashBoard():
ret = self.astrbot_updator.check_update(None, None)
return Response(
status="success",
message=str(ret),
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"has_new_version": ret is not None
}
+30 -1
View File
@@ -26,7 +26,36 @@ class InternalCommandHandler:
self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider)
def provider(self, message: AstrMessageEvent, context: Context):
if len(context.llms) == 0:
return CommandResult().message("当前没有加载任何 LLM 资源。")
tokens = self.manager.command_parser.parse(message.message_str)
if tokens.len == 1:
ret = "## 当前载入的 LLM 资源\n"
for idx, llm in enumerate(context.llms):
ret += f"{idx}. {llm.llm_name}"
if llm.origin:
ret += f" (来源: {llm.origin})"
if context.message_handler.provider == llm.llm_instance:
ret += " (当前使用)"
ret += "\n"
ret += "\n使用 provider <序号> 切换 LLM 资源。"
return CommandResult().message(ret)
else:
try:
idx = int(tokens.get(1))
if idx >= len(context.llms):
return CommandResult().message("provider: 无效的序号。")
context.message_handler.set_provider(context.llms[idx].llm_instance)
return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}")
except BaseException as e:
return CommandResult().message("provider: 参数错误。")
def set_nick(self, message: AstrMessageEvent, context: Context):
message_str = message.message_str
if message.role != "admin":
+2 -1
View File
@@ -73,7 +73,8 @@ class CommandManager():
if message_str.startswith(command):
logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context)
return command_result
if command_result.hit:
return command_result
async def execute_handler(self,
command: str,
+10
View File
@@ -24,6 +24,7 @@ class CommandResult():
self.success = success
self.message_chain = message_chain
self.command_name = command_name
self.is_use_t2i = None # default
def message(self, message: str):
'''
@@ -61,6 +62,15 @@ class CommandResult():
'''
self.message_chain = [Image.fromFileSystem(path), ]
return self
# def use_t2i(self, use_t2i: bool):
# '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
# CommandResult().use_t2i(False)
# '''
# self.is_use_t2i = use_t2i
# return self
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+1 -1
View File
@@ -1 +1 @@
VERSION = '3.3.0'
VERSION = '3.3.2'
+2 -5
View File
@@ -43,13 +43,10 @@ class AstrMessageEvent():
context,
session_id)
return ame
@dataclass
class MessageResult():
result_message: Union[str, list]
is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None
+9
View File
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
class Context:
@@ -68,6 +69,14 @@ class Context:
'''
task = asyncio.create_task(coro, name=task_name)
self.ext_tasks.append(task)
def register_provider(self, llm_name: str, provider: Provider, origin: str = ''):
'''
注册一个提供 LLM 资源的 Provider。
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
+2 -2
View File
@@ -122,7 +122,7 @@ class FuncCall():
_c = 0
while _c < 3:
try:
res = self.provider.text_chat(prompt, session_id)
res = self.provider.text_chat(prompt=prompt, session_id=session_id)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result",
@@ -187,7 +187,7 @@ class FuncCall():
_c = 0
while _c < 5:
try:
res = self.provider.text_chat(after_prompt, session_id)
res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
# 截取```之间的内容
gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
+5 -5
View File
@@ -127,7 +127,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = ""
if official_fc:
# we use official function-calling
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}")
func_obj = None
@@ -136,14 +136,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
func_obj = i["func_obj"]
break
if not func_obj:
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
else:
return result
else:
@@ -162,7 +162,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
has_func = True
if has_func:
await provider.forget(session_id)
await provider.forget(session_id=session_id, )
summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
@@ -178,6 +178,6 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(summary_prompt, session_id)
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
return ret
return function_invoked_ret
+1 -1
View File
@@ -32,7 +32,7 @@ class NetworkRenderStrategy(RenderStrategy):
"options": {
"full_page": True,
"type": "jpeg",
"quality": 25,
"quality": 40,
}
}
+3 -2
View File
@@ -3,7 +3,7 @@ from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
from SparkleLogging.utils.core import LogManager
from logging import Logger
from type.config import VERSION
from util.io import on_error
from util.io import on_error, download_file
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -58,7 +58,8 @@ class AstrBotUpdator(RepoZipUpdator):
raise Exception(f"未找到版本号为 {version} 的更新文件。")
try:
self.download_from_repo_url("temp", data['zipball_url'])
# self.download_from_repo_url("temp", file_url)
download_file(file_url, "temp.zip")
self.unzip_file("temp.zip", self.MAIN_PATH)
except BaseException as e:
raise e