Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0de14c4c8b | |||
| 51de0159fb | |||
| 37a756aeb3 | |||
| 353b6ed761 | |||
| 90815b1ac5 | |||
| 8a50786e61 |
+16
-7
@@ -11,15 +11,14 @@ from model.platform.manager import PlatformManager
|
||||
from typing import Union
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.cmd_config import AstrBotConfig, try_migrate
|
||||
from util.metrics import MetricUploader
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.log import LogManager
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
@@ -28,7 +27,11 @@ class AstrBotBootstrap():
|
||||
try_migrate()
|
||||
self.config_helper = AstrBotConfig()
|
||||
self.context.config_helper = self.config_helper
|
||||
# set log queue handler
|
||||
LogManager.set_queue_handler(logger, self.context._log_queue)
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
# set log level
|
||||
logger.setLevel(self.config_helper.log_level)
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.config_helper.http_proxy
|
||||
https_proxy = self.context.config_helper.https_proxy
|
||||
@@ -44,6 +47,12 @@ class AstrBotBootstrap():
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
# set t2i endpoint
|
||||
if self.context.config_helper.t2i_endpoint:
|
||||
self.context.image_renderer.set_network_endpoint(
|
||||
self.context.config_helper.t2i_endpoint
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
@@ -66,11 +75,11 @@ class AstrBotBootstrap():
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
|
||||
# load dashboard
|
||||
self.dashboard.run_http_server()
|
||||
dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
|
||||
dashboard_ws_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard")
|
||||
dashboard_log_task = asyncio.create_task(self.dashboard.log_consumer(), name="log")
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
|
||||
@@ -82,8 +91,8 @@ class AstrBotBootstrap():
|
||||
platform_tasks = self.load_platform()
|
||||
# load metrics uploader
|
||||
metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader")
|
||||
|
||||
tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks]
|
||||
|
||||
tasks = [metrics_upload_task, dashboard_ws_task, dashboard_log_task, *platform_tasks, *self.context.ext_tasks]
|
||||
tasks = [self.handle_task(task) for task in tasks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent, MessageResult
|
||||
from type.types import Context
|
||||
from type.command import CommandResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.func_call import FuncCall
|
||||
|
||||
+19
-1
@@ -2,7 +2,7 @@ from . import DashBoardData
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from dataclasses import dataclass, asdict
|
||||
from util.plugin_dev.api.v1.config import update_config
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.config import CONFIG_METADATA_2
|
||||
@@ -23,6 +23,20 @@ class DashBoardHelper():
|
||||
return float(value)
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
def get_default_val_by_type(self, type_: str):
|
||||
if type_ == "int":
|
||||
return 0
|
||||
elif type_ == "float":
|
||||
return 0.0
|
||||
elif type_ == "bool":
|
||||
return False
|
||||
elif type_ == "string":
|
||||
return ""
|
||||
elif type_ == "list":
|
||||
return []
|
||||
elif type_ == "object":
|
||||
return {}
|
||||
|
||||
|
||||
def validate_config(self, data):
|
||||
@@ -32,6 +46,10 @@ class DashBoardHelper():
|
||||
if key not in data:
|
||||
continue
|
||||
value = data[key]
|
||||
# null 转换
|
||||
if value is None:
|
||||
data[key] = self.get_default_val_by_type(meta["type"])
|
||||
continue
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and isinstance(value, list):
|
||||
for item in value:
|
||||
|
||||
+20
-17
@@ -13,7 +13,7 @@ from werkzeug.serving import make_server
|
||||
from astrbot.persist.helper import dbConn
|
||||
from type.types import Context
|
||||
from typing import List
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from dashboard.helper import DashBoardHelper
|
||||
from util.io import get_local_ip_addresses
|
||||
@@ -290,16 +290,6 @@ class AstrBotDashBoard():
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/log")
|
||||
def log():
|
||||
for item in self.ws_clients:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.ws_clients[item].send(request.data.decode()), self.loop).result()
|
||||
except Exception as e:
|
||||
pass
|
||||
return 'ok'
|
||||
|
||||
@self.dashboard_be.get("/api/check_update")
|
||||
def get_update_info():
|
||||
try:
|
||||
@@ -422,18 +412,22 @@ class AstrBotDashBoard():
|
||||
|
||||
async def get_log_history(self):
|
||||
try:
|
||||
with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f:
|
||||
return f.readlines()[-100:]
|
||||
dq = self.context._log_queue.get_cache()
|
||||
ret = ""
|
||||
for log in dq:
|
||||
ret += log + "\n\r"
|
||||
return ret
|
||||
except Exception as e:
|
||||
logger.warning(f"读取日志历史失败: {e.__str__()}")
|
||||
return []
|
||||
return ""
|
||||
|
||||
async def __handle_msg(self, websocket, path):
|
||||
address = websocket.remote_address
|
||||
self.ws_clients[address] = websocket
|
||||
data = await self.get_log_history()
|
||||
data = ''.join(data).replace('\n', '\r\n')
|
||||
await websocket.send(data)
|
||||
|
||||
# 发送日志历史
|
||||
await websocket.send(await self.get_log_history())
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
@@ -450,6 +444,15 @@ class AstrBotDashBoard():
|
||||
ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
|
||||
logger.info("WebSocket 服务器已启动。")
|
||||
await ws_server
|
||||
|
||||
async def log_consumer(self):
|
||||
while True:
|
||||
log = await self.context._log_queue.get()
|
||||
for ws in self.ws_clients.values():
|
||||
try:
|
||||
await ws.send(log)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def http_server(self):
|
||||
http_server = make_server(
|
||||
|
||||
@@ -6,8 +6,7 @@ import warnings
|
||||
import traceback
|
||||
import mimetypes
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
from util.log import LogManager
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logo_tmpl = r"""
|
||||
@@ -54,10 +53,5 @@ def check_env():
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_env()
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
main()
|
||||
|
||||
@@ -6,7 +6,7 @@ from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from type.command import CommandResult
|
||||
from type.register import RegisteredPlugins
|
||||
from model.command.parser import CommandParser
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from model.command.manager import CommandManager
|
||||
from type.message_event import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
|
||||
@@ -40,14 +40,18 @@ class Platform():
|
||||
'''
|
||||
pass
|
||||
|
||||
def parse_message_outline(self, message: AstrBotMessage) -> str:
|
||||
def parse_message_outline(self, message: Union[AstrBotMessage, list]) -> str:
|
||||
'''
|
||||
将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。
|
||||
'''
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
ret = ''
|
||||
parsed = message if isinstance(message, list) else message.message
|
||||
if isinstance(message, list):
|
||||
parsed = message
|
||||
elif isinstance(message, AstrBotMessage):
|
||||
parsed = message.message
|
||||
elif isinstance(message, str):
|
||||
return message
|
||||
|
||||
try:
|
||||
for node in parsed:
|
||||
if isinstance(node, Plain):
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from util.io import port_checker
|
||||
from type.register import RegisteredPlatform
|
||||
from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import (
|
||||
|
||||
@@ -10,7 +10,7 @@ from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
|
||||
@@ -209,7 +209,7 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
if isinstance(message, AstrBotMessage):
|
||||
logger.info(
|
||||
f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}")
|
||||
else:
|
||||
logger.info(f"回复消息: {message_chain}")
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from . import Platform
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
|
||||
@@ -171,7 +171,7 @@ class QQNakuru(Platform):
|
||||
(GroupMessage, FriendMessage, GuildMessage))
|
||||
|
||||
logger.info(
|
||||
f"{source.user_id} <- {self.parse_message_outline(res)}")
|
||||
f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(res)}")
|
||||
|
||||
if isinstance(res, str):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
@@ -16,7 +16,7 @@ from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from type.register import RegisteredPlugins
|
||||
from typing import List, Union, Callable
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -13,7 +13,7 @@ from types import ModuleType
|
||||
from type.types import Context
|
||||
from type.plugin import *
|
||||
from type.register import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -15,7 +15,7 @@ from util.io import download_image_by_url
|
||||
from astrbot.persist.helper import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util.cmd_config import LLMConfig
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from typing import List, Dict
|
||||
|
||||
@@ -336,12 +336,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if tools:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
tools=tools,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
completion_coro = self.client.chat.completions.create(
|
||||
messages=contexts,
|
||||
stream=False,
|
||||
**conf
|
||||
)
|
||||
try:
|
||||
|
||||
+1
-1
@@ -13,5 +13,5 @@ websockets
|
||||
flask
|
||||
psutil
|
||||
lxml_html_clean
|
||||
SparkleLogging
|
||||
colorlog
|
||||
aiocqhttp
|
||||
|
||||
@@ -11,16 +11,11 @@ from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
from util.log import LogManager
|
||||
|
||||
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from asyncio import Queue
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
class CachedQueue(Queue):
|
||||
def __init__(self, maxsize: int = 0, cachesize: int = 200):
|
||||
super().__init__(maxsize)
|
||||
self.cache = deque(maxlen=cachesize)
|
||||
|
||||
def put_nowait(self, item):
|
||||
self.cache.append(item)
|
||||
super().put_nowait(item)
|
||||
|
||||
def get_nowait(self):
|
||||
item = super().get_nowait()
|
||||
return item
|
||||
|
||||
def get(self):
|
||||
item = super().get()
|
||||
return item
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
with self.mutex:
|
||||
self._queue.clear()
|
||||
|
||||
def get_cache(self) -> Deque:
|
||||
return self.cache
|
||||
+5
-1
@@ -1,4 +1,4 @@
|
||||
VERSION = '3.3.12'
|
||||
VERSION = '3.3.15'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
@@ -169,6 +169,8 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"username": "",
|
||||
"password": "",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
}
|
||||
|
||||
# 这个是用于迁移旧版本配置文件的映射表
|
||||
@@ -350,4 +352,6 @@ CONFIG_METADATA_2 = {
|
||||
"password": {"description": "密码", "type": "string"},
|
||||
}
|
||||
},
|
||||
"log_level": {"description": "控制台日志级别(DEBUG, INFO, WARNING, ERROR)", "type": "string"},
|
||||
"t2i_endpoint": {"description": "文本转图像服务接口(为空时使用公共服务器)", "type": "string"},
|
||||
}
|
||||
|
||||
+3
-1
@@ -13,7 +13,7 @@ from type.middleware import Middleware
|
||||
from type.astrbot_message import MessageType
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
from util.agent.func_call import FuncCall
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -49,6 +49,8 @@ class Context:
|
||||
self.command_manager = None
|
||||
self.running = True
|
||||
|
||||
self._log_queue = CachedQueue()
|
||||
|
||||
# useless
|
||||
# self.reply_prefix = ""
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from util.websearch.bing import Bing
|
||||
from util.websearch.sogo import Sogo
|
||||
from util.websearch.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.message_event import AstrMessageEvent
|
||||
|
||||
@@ -133,6 +133,8 @@ class AstrBotConfig():
|
||||
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
log_level: str = "INFO"
|
||||
t2i_endpoint: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.init_configs()
|
||||
@@ -174,6 +176,8 @@ class AstrBotConfig():
|
||||
self.http_proxy=data.get("http_proxy", "")
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", [])
|
||||
self.log_level=data.get("log_level", "INFO")
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
|
||||
def migrate_config_1_2(self, old: dict) -> dict:
|
||||
'''将配置文件从版本 1 迁移至版本 2'''
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ import time
|
||||
import aiohttp
|
||||
|
||||
from PIL import Image
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
+51
@@ -0,0 +1,51 @@
|
||||
import logging, asyncio, colorlog
|
||||
from type.cached_queue import CachedQueue
|
||||
|
||||
log_color_config = {
|
||||
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
|
||||
'WARNING': 'bold_yellow', 'ERROR': 'red',
|
||||
'CRITICAL': 'bold_red', 'RESET': 'reset',
|
||||
'asctime': 'green'
|
||||
}
|
||||
|
||||
class LogQueueHandler(logging.Handler):
|
||||
def __init__(self, log_queue: CachedQueue):
|
||||
super().__init__()
|
||||
self.log_queue = log_queue
|
||||
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
try:
|
||||
self.log_queue.put_nowait(log_entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class LogManager:
|
||||
|
||||
@classmethod
|
||||
def GetLogger(cls, log_name: str = 'default'):
|
||||
logger = logging.getLogger(log_name)
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_formatter = colorlog.ColoredFormatter(
|
||||
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(funcName)s|%(filename)s:%(lineno)d]: %(message)s %(reset)s',
|
||||
datefmt='%H:%M:%S',
|
||||
log_colors=log_color_config
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_queue: CachedQueue):
|
||||
handler = LogQueueHandler(log_queue)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
if logger.handlers:
|
||||
handler.setFormatter(logger.handlers[0].formatter)
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
@@ -1,17 +1,23 @@
|
||||
from util.t2i.strategies.local_strategy import LocalRenderStrategy
|
||||
from util.t2i.strategies.network_strategy import NetworkRenderStrategy
|
||||
from util.t2i.context import RenderContext
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class TextToImageRenderer:
|
||||
def __init__(self):
|
||||
self.network_strategy = NetworkRenderStrategy()
|
||||
def __init__(self, endpoint_url: str = None):
|
||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||
self.local_strategy = LocalRenderStrategy()
|
||||
self.context = RenderContext(self.network_strategy)
|
||||
|
||||
def set_network_endpoint(self, endpoint_url: str):
|
||||
'''设置 t2i 的网络端点。
|
||||
'''
|
||||
logger.info("文本转图像服务接口: " + endpoint_url)
|
||||
self.network_strategy.set_endpoint(endpoint_url)
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False):
|
||||
'''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
||||
@param tmpl_str: HTML Jinja2 模板。
|
||||
|
||||
@@ -10,9 +10,16 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
|
||||
async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str:
|
||||
'''使用自定义文转图模板'''
|
||||
post_data = {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os, psutil, sys, time
|
||||
from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from type.config import VERSION
|
||||
from util.io import on_error, download_file
|
||||
|
||||
@@ -4,7 +4,7 @@ from util.updator.zip_updator import RepoZipUpdator
|
||||
from util.io import remove_dir
|
||||
from type.register import RegisteredPlugin
|
||||
from typing import Union
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.io import on_error
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import aiohttp, os, zipfile, shutil
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from util.log import LogManager
|
||||
from logging import Logger
|
||||
from util.io import on_error, download_file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user