Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e14ed804da | |||
| 8e4e49df20 | |||
| 5d856900ef | |||
| 380a68b96c | |||
| 8879bd7e9d | |||
| 2cce09400f | |||
| 54d26dcd38 |
@@ -0,0 +1,65 @@
|
||||
# CONTRIBUTING
|
||||
|
||||
## 贡献指南
|
||||
|
||||
首先,感谢您花时间做出贡献!❤️
|
||||
|
||||
所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
|
||||
|
||||
### 目录
|
||||
|
||||
- [报告问题](#报告问题)
|
||||
- [提交代码更改](#提交代码更改)
|
||||
|
||||
### 报告问题
|
||||
|
||||
如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
|
||||
|
||||
1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
|
||||
2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
|
||||
- 问题的简要描述
|
||||
- 重现问题的步骤
|
||||
- 预期结果和实际结果
|
||||
- 相关日志或错误消息
|
||||
|
||||
### 提交代码更改
|
||||
|
||||
#### 分支命名
|
||||
|
||||
我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。
|
||||
|
||||
#### PR 描述
|
||||
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
|
||||
All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- [Reporting Issues](#reporting-issues)
|
||||
- [Pull Requests](#pull-requests)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
If you encounter any issues while using AstrBot, please follow these steps to report them:
|
||||
1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
|
||||
2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
|
||||
- A brief description of the issue
|
||||
- Steps to reproduce the issue
|
||||
- Expected and actual results
|
||||
- Relevant logs or error messages
|
||||
|
||||
### Pull Requests
|
||||
|
||||
#### Branch Naming
|
||||
|
||||
We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.7.4"
|
||||
__version__ = "4.8.0"
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import os
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.db.sqlite import BaseDatabase
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
@@ -16,44 +14,13 @@ from .utils.astrbot_path import get_astrbot_data_path
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
|
||||
class AstrBotMySQLSettings(BaseSettings):
|
||||
host: str = "localhost"
|
||||
port: int = 3306
|
||||
user: str = "root"
|
||||
password: str = ""
|
||||
database: str = "astrbot"
|
||||
charset: str = "utf8mb4"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_prefix="ASTR_MYSQL_")
|
||||
|
||||
|
||||
def get_db_helper() -> BaseDatabase:
|
||||
db_type = os.getenv("ASTR_DB_TYPE", "sqlite")
|
||||
match db_type:
|
||||
case "sqlite":
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
|
||||
return SQLiteDatabase(DB_PATH)
|
||||
case "mysql":
|
||||
from astrbot.core.db.mysql import MySQLDatabase
|
||||
|
||||
mysql_settings = AstrBotMySQLSettings()
|
||||
|
||||
return MySQLDatabase(**mysql_settings.model_dump())
|
||||
case _:
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
|
||||
return SQLiteDatabase(DB_PATH)
|
||||
|
||||
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
db_helper = get_db_helper()
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences(db_helper=db_helper)
|
||||
# 文件令牌服务
|
||||
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
@@ -72,7 +73,20 @@ async def run_agent(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
|
||||
error_llm_response = LLMResponse(
|
||||
role="err",
|
||||
completion_text=err_msg,
|
||||
)
|
||||
try:
|
||||
await agent_runner.agent_hooks.on_agent_done(
|
||||
agent_runner.run_context, error_llm_response
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in on_agent_done hook")
|
||||
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.7.4"
|
||||
VERSION = "4.8.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
|
||||
@@ -3,7 +3,6 @@ import datetime
|
||||
import typing as T
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -21,17 +20,11 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
|
||||
class DatabaseType(Enum):
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
"""数据库基类"""
|
||||
|
||||
DATABASE_URL = ""
|
||||
database_type: DatabaseType
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_async_engine(
|
||||
@@ -90,7 +83,7 @@ class BaseDatabase(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_platform_stats(self) -> int:
|
||||
"""Sum the count of platform statistics records."""
|
||||
"""Count the number of platform statistics records."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -98,16 +91,6 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_stats_time_series(
|
||||
self, offset_sec: int = 86400
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Get platform statistics time series data grouped by hour.
|
||||
|
||||
Returns a list of tuples (hour_timestamp, count) sorted by timestamp ascending.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_conversations(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .migra_3_to_4 import (
|
||||
@@ -24,10 +24,6 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
|
||||
if not os.path.exists(data_v3_db):
|
||||
return False
|
||||
|
||||
if db_helper.database_type == DatabaseType.MYSQL:
|
||||
return False
|
||||
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global",
|
||||
"global",
|
||||
|
||||
@@ -38,7 +38,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
query = (
|
||||
select(
|
||||
col(PlatformMessageHistory.user_id),
|
||||
func.max(PlatformMessageHistory.sender_name).label("sender_name"),
|
||||
col(PlatformMessageHistory.sender_name),
|
||||
func.min(PlatformMessageHistory.created_at).label("earliest"),
|
||||
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
||||
)
|
||||
|
||||
@@ -1,875 +0,0 @@
|
||||
import asyncio
|
||||
import typing as T
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
SQLModel,
|
||||
)
|
||||
from astrbot.core.db.po import Stats as DeprecatedStats
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
|
||||
class MySQLDatabase(BaseDatabase):
|
||||
"""MySQL 数据库实现
|
||||
|
||||
使用方式:
|
||||
db = MySQLDatabase(
|
||||
host="localhost",
|
||||
port=3306,
|
||||
user="root",
|
||||
password="password",
|
||||
database="astrbot"
|
||||
)
|
||||
await db.initialize()
|
||||
"""
|
||||
|
||||
database_type = DatabaseType.MYSQL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 3306,
|
||||
user: str = "root",
|
||||
password: str = "",
|
||||
database: str = "astrbot",
|
||||
charset: str = "utf8mb4",
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.database = database
|
||||
self.charset = charset
|
||||
self.DATABASE_URL = (
|
||||
f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}"
|
||||
f"?charset={charset}"
|
||||
)
|
||||
self.inited = False
|
||||
self._current_loop: asyncio.AbstractEventLoop | None = None
|
||||
super().__init__()
|
||||
|
||||
def _recreate_engine(self) -> None:
|
||||
"""重新创建数据库引擎和会话工厂,用于处理事件循环切换的情况"""
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session.
|
||||
|
||||
此方法会检查当前事件循环,如果事件循环发生变化会重新创建引擎,
|
||||
以解决 aiomysql 的 "attached to a different loop" 问题。
|
||||
"""
|
||||
try:
|
||||
current_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
current_loop = None
|
||||
|
||||
# 检查事件循环是否变化,如果变化则重新创建引擎
|
||||
if current_loop is not None and self._current_loop != current_loop:
|
||||
self._recreate_engine()
|
||||
self._current_loop = current_loop
|
||||
self.inited = False # 需要重新初始化
|
||||
|
||||
if not self.inited:
|
||||
await self.initialize()
|
||||
self.inited = True
|
||||
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the database by creating tables if they do not exist."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.commit()
|
||||
|
||||
# ====
|
||||
# Platform Statistics
|
||||
# ====
|
||||
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id,
|
||||
platform_type,
|
||||
count=1,
|
||||
timestamp=None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now().replace(
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
current_hour = timestamp
|
||||
await session.execute(
|
||||
text("""
|
||||
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
||||
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
count = count + VALUES(count)
|
||||
"""),
|
||||
{
|
||||
"timestamp": current_hour,
|
||||
"platform_id": platform_id,
|
||||
"platform_type": platform_type,
|
||||
"count": count,
|
||||
},
|
||||
)
|
||||
|
||||
async def count_platform_stats(self) -> int:
|
||||
"""Count the number of platform statistics records."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||
)
|
||||
count = result.scalar_one_or_none()
|
||||
return count or 0
|
||||
|
||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
|
||||
FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
GROUP BY platform_id, platform_type
|
||||
ORDER BY latest_ts DESC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
return [
|
||||
PlatformStat(
|
||||
id=0,
|
||||
platform_id=row.platform_id,
|
||||
platform_type=row.platform_type,
|
||||
count=row.total_count,
|
||||
timestamp=row.latest_ts,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_platform_stats_time_series(
|
||||
self, offset_sec: int = 86400
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Get platform statistics time series data grouped by hour."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT UNIX_TIMESTAMP(DATE_FORMAT(timestamp, '%Y-%m-%d %H:00:00')) as hour_ts, SUM(count) as total_count
|
||||
FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
GROUP BY hour_ts
|
||||
ORDER BY hour_ts ASC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
return [(int(row.hour_ts), row.total_count) for row in rows]
|
||||
|
||||
# ====
|
||||
# Conversation Management
|
||||
# ====
|
||||
|
||||
async def get_conversations(self, user_id=None, platform_id=None):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(ConversationV2)
|
||||
|
||||
if user_id:
|
||||
query = query.where(ConversationV2.user_id == user_id)
|
||||
if platform_id:
|
||||
query = query.where(ConversationV2.platform_id == platform_id)
|
||||
# order by
|
||||
query = query.order_by(desc(ConversationV2.created_at))
|
||||
result = await session.execute(query)
|
||||
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_conversation_by_id(self, cid):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_all_conversations(self, page=1, page_size=20):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
result = await session.execute(
|
||||
select(ConversationV2)
|
||||
.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size),
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page=1,
|
||||
page_size=20,
|
||||
platform_ids=None,
|
||||
search_query="",
|
||||
**kwargs,
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
# Build the base query with filters
|
||||
base_query = select(ConversationV2)
|
||||
|
||||
if platform_ids:
|
||||
base_query = base_query.where(
|
||||
col(ConversationV2.platform_id).in_(platform_ids),
|
||||
)
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
||||
),
|
||||
)
|
||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||
for msg_type in kwargs["message_types"]:
|
||||
base_query = base_query.where(
|
||||
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
|
||||
)
|
||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||
base_query = base_query.where(
|
||||
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
|
||||
)
|
||||
|
||||
# Get total count matching the filters
|
||||
count_query = select(func.count()).select_from(base_query.subquery())
|
||||
total_count = await session.execute(count_query)
|
||||
total = total_count.scalar_one()
|
||||
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
result_query = (
|
||||
base_query.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
result = await session.execute(result_query)
|
||||
conversations = result.scalars().all()
|
||||
|
||||
return conversations, total
|
||||
|
||||
async def create_conversation(
|
||||
self,
|
||||
user_id,
|
||||
platform_id,
|
||||
content=None,
|
||||
title=None,
|
||||
persona_id=None,
|
||||
cid=None,
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
):
|
||||
kwargs = {}
|
||||
if cid:
|
||||
kwargs["conversation_id"] = cid
|
||||
if created_at:
|
||||
kwargs["created_at"] = created_at
|
||||
if updated_at:
|
||||
kwargs["updated_at"] = updated_at
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_conversation = ConversationV2(
|
||||
user_id=user_id,
|
||||
content=content or [],
|
||||
platform_id=platform_id,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
**kwargs,
|
||||
)
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(ConversationV2).where(
|
||||
col(ConversationV2.conversation_id) == cid,
|
||||
)
|
||||
values = {}
|
||||
if title is not None:
|
||||
values["title"] = title
|
||||
if persona_id is not None:
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
await session.execute(query)
|
||||
return await self.get_conversation_by_id(cid)
|
||||
|
||||
async def delete_conversation(self, cid):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.conversation_id) == cid,
|
||||
),
|
||||
)
|
||||
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.user_id) == user_id
|
||||
),
|
||||
)
|
||||
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page=1,
|
||||
page_size=20,
|
||||
search_query=None,
|
||||
platform=None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# MySQL 使用 JSON_EXTRACT 函数(与 SQLite 的 json_extract 兼容)
|
||||
base_query = (
|
||||
select(
|
||||
col(Preference.scope_id).label("session_id"),
|
||||
func.json_extract(Preference.value, "$.val").label(
|
||||
"conversation_id",
|
||||
), # type: ignore
|
||||
col(ConversationV2.persona_id).label("persona_id"),
|
||||
col(ConversationV2.title).label("title"),
|
||||
col(Persona.persona_id).label("persona_name"),
|
||||
)
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona,
|
||||
col(ConversationV2.persona_id) == Persona.persona_id,
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 搜索筛选
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
),
|
||||
)
|
||||
|
||||
# 平台筛选
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
base_query = base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern),
|
||||
)
|
||||
|
||||
# 排序
|
||||
base_query = base_query.order_by(Preference.scope_id)
|
||||
|
||||
# 分页结果
|
||||
result_query = base_query.offset(offset).limit(page_size)
|
||||
result = await session.execute(result_query)
|
||||
rows = result.fetchall()
|
||||
|
||||
# 查询总数(应用相同的筛选条件)
|
||||
count_base_query = (
|
||||
select(func.count(col(Preference.scope_id)))
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona,
|
||||
col(ConversationV2.persona_id) == Persona.persona_id,
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 应用相同的搜索和平台筛选条件到计数查询
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
count_base_query = count_base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
),
|
||||
)
|
||||
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
count_base_query = count_base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern),
|
||||
)
|
||||
|
||||
total_result = await session.execute(count_base_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
sessions_data = [
|
||||
{
|
||||
"session_id": row.session_id,
|
||||
"conversation_id": row.conversation_id,
|
||||
"persona_id": row.persona_id,
|
||||
"title": row.title,
|
||||
"persona_name": row.persona_name,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return sessions_data, total
|
||||
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
user_id,
|
||||
content,
|
||||
sender_id=None,
|
||||
sender_name=None,
|
||||
):
|
||||
"""Insert a new platform message history record."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_history = PlatformMessageHistory(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
sender_id=sender_id,
|
||||
sender_name=sender_name,
|
||||
)
|
||||
session.add(new_history)
|
||||
return new_history
|
||||
|
||||
async def delete_platform_message_offset(
|
||||
self,
|
||||
platform_id,
|
||||
user_id,
|
||||
offset_sec=86400,
|
||||
):
|
||||
"""Delete platform message history records newer than the specified offset."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
now = datetime.now()
|
||||
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||
await session.execute(
|
||||
delete(PlatformMessageHistory).where(
|
||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||
col(PlatformMessageHistory.user_id) == user_id,
|
||||
col(PlatformMessageHistory.created_at) >= cutoff_time,
|
||||
),
|
||||
)
|
||||
|
||||
async def get_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
user_id,
|
||||
page=1,
|
||||
page_size=20,
|
||||
):
|
||||
"""Get platform message history records."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
query = (
|
||||
select(PlatformMessageHistory)
|
||||
.where(
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
)
|
||||
.order_by(desc(PlatformMessageHistory.created_at))
|
||||
)
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_platform_message_history_by_id(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
"""Get a platform message history record by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.id == message_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def insert_attachment(self, path, type, mime_type):
|
||||
"""Insert a new attachment record."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_attachment = Attachment(
|
||||
path=path,
|
||||
type=type,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
session.add(new_attachment)
|
||||
return new_attachment
|
||||
|
||||
async def get_attachment_by_id(self, attachment_id):
|
||||
"""Get an attachment by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_attachments(self, attachment_ids: list[str]) -> list:
|
||||
"""Get multiple attachments by their IDs."""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
Attachment.attachment_id.in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_attachment(self, attachment_id: str) -> bool:
|
||||
"""Delete an attachment by its ID.
|
||||
|
||||
Returns True if the attachment was deleted, False if it was not found.
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
"""Delete multiple attachments by their IDs.
|
||||
|
||||
Returns the number of attachments deleted.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return 0
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id,
|
||||
system_prompt,
|
||||
begin_dialogs=None,
|
||||
tools=None,
|
||||
):
|
||||
"""Insert a new persona record."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_persona = Persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs or [],
|
||||
tools=tools,
|
||||
)
|
||||
session.add(new_persona)
|
||||
return new_persona
|
||||
|
||||
async def get_persona_by_id(self, persona_id):
|
||||
"""Get a persona by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Persona).where(Persona.persona_id == persona_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_personas(self):
|
||||
"""Get all personas for a specific bot."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Persona)
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id,
|
||||
system_prompt=None,
|
||||
begin_dialogs=None,
|
||||
tools=NOT_GIVEN,
|
||||
):
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
values = {}
|
||||
if system_prompt is not None:
|
||||
values["system_prompt"] = system_prompt
|
||||
if begin_dialogs is not None:
|
||||
values["begin_dialogs"] = begin_dialogs
|
||||
if tools is not NOT_GIVEN:
|
||||
values["tools"] = tools
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
await session.execute(query)
|
||||
return await self.get_persona_by_id(persona_id)
|
||||
|
||||
async def delete_persona(self, persona_id):
|
||||
"""Delete a persona by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Persona).where(col(Persona.persona_id) == persona_id),
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||
"""Insert a new preference record or update if it exists."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = select(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
existing_preference = result.scalar_one_or_none()
|
||||
if existing_preference:
|
||||
existing_preference.value = value
|
||||
else:
|
||||
new_preference = Preference(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
session.add(new_preference)
|
||||
return existing_preference or new_preference
|
||||
|
||||
async def get_preference(self, scope, scope_id, key):
|
||||
"""Get a preference by key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_preferences(self, scope, scope_id=None, key=None):
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Preference).where(Preference.scope == scope)
|
||||
if scope_id is not None:
|
||||
query = query.where(Preference.scope_id == scope_id)
|
||||
if key is not None:
|
||||
query = query.where(Preference.key == key)
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def remove_preference(self, scope, scope_id, key):
|
||||
"""Remove a preference by scope ID and key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
col(Preference.key) == key,
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def clear_preferences(self, scope, scope_id):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Platform Session Management
|
||||
# ====
|
||||
|
||||
async def create_platform_session(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str = "webchat",
|
||||
session_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
is_group: int = 0,
|
||||
) -> PlatformSession:
|
||||
"""Create a new Platform session."""
|
||||
kwargs = {}
|
||||
if session_id:
|
||||
kwargs["session_id"] = session_id
|
||||
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_session = PlatformSession(
|
||||
creator=creator,
|
||||
platform_id=platform_id,
|
||||
display_name=display_name,
|
||||
is_group=is_group,
|
||||
**kwargs,
|
||||
)
|
||||
session.add(new_session)
|
||||
await session.flush()
|
||||
await session.refresh(new_session)
|
||||
return new_session
|
||||
|
||||
async def get_platform_session_by_id(
|
||||
self, session_id: str
|
||||
) -> PlatformSession | None:
|
||||
"""Get a Platform session by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformSession).where(
|
||||
PlatformSession.session_id == session_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_platform_sessions_by_creator(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformSession]:
|
||||
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
query = select(PlatformSession).where(PlatformSession.creator == creator)
|
||||
|
||||
if platform_id:
|
||||
query = query.where(PlatformSession.platform_id == platform_id)
|
||||
|
||||
query = (
|
||||
query.order_by(desc(PlatformSession.updated_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_platform_session(
|
||||
self,
|
||||
session_id: str,
|
||||
display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
||||
if display_name is not None:
|
||||
values["display_name"] = display_name
|
||||
|
||||
await session.execute(
|
||||
update(PlatformSession)
|
||||
.where(col(PlatformSession.session_id) == session_id)
|
||||
.values(**values),
|
||||
)
|
||||
|
||||
async def delete_platform_session(self, session_id: str) -> None:
|
||||
"""Delete a Platform session by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(PlatformSession).where(
|
||||
col(PlatformSession.session_id) == session_id,
|
||||
),
|
||||
)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
def get_base_stats(self, offset_sec=86400):
|
||||
"""Get base statistics within the specified offset in seconds."""
|
||||
return DeprecatedStats()
|
||||
|
||||
def get_total_message_count(self):
|
||||
"""Get the total message count from platform statistics."""
|
||||
return 0
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec=86400):
|
||||
# group by platform_id
|
||||
return DeprecatedStats()
|
||||
+105
-137
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase, DatabaseType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
ConversationV2,
|
||||
@@ -28,8 +28,6 @@ NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
database_type = DatabaseType.SQLITE
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
@@ -90,10 +88,12 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
||||
PlatformStat,
|
||||
),
|
||||
)
|
||||
count = result.scalar_one_or_none()
|
||||
return count or 0
|
||||
return count if count is not None else 0
|
||||
|
||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
@@ -103,46 +103,14 @@ class SQLiteDatabase(BaseDatabase):
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts
|
||||
FROM platform_stats
|
||||
SELECT * FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
GROUP BY platform_id, platform_type
|
||||
ORDER BY latest_ts DESC
|
||||
GROUP BY platform_id
|
||||
ORDER BY timestamp DESC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
return [
|
||||
PlatformStat(
|
||||
id=0,
|
||||
platform_id=row.platform_id,
|
||||
platform_type=row.platform_type,
|
||||
count=row.total_count,
|
||||
timestamp=row.latest_ts,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_platform_stats_time_series(
|
||||
self, offset_sec: int = 86400
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Get platform statistics time series data grouped by hour."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT strftime('%s', datetime(timestamp, 'start of hour')) as hour_ts, SUM(count) as total_count
|
||||
FROM platform_stats
|
||||
WHERE timestamp >= :start_time
|
||||
GROUP BY hour_ts
|
||||
ORDER BY hour_ts ASC
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
return [(int(row.hour_ts), row.total_count) for row in rows]
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ====
|
||||
# Conversation Management
|
||||
@@ -701,6 +669,102 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
def get_base_stats(self, offset_sec=86400):
|
||||
"""Get base statistics within the specified offset in seconds."""
|
||||
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
||||
)
|
||||
all_datas = result.scalars().all()
|
||||
deprecated_stats = DeprecatedStats()
|
||||
for data in all_datas:
|
||||
deprecated_stats.platform.append(
|
||||
DeprecatedPlatformStat(
|
||||
name=data.platform_id,
|
||||
count=data.count,
|
||||
timestamp=int(data.timestamp.timestamp()),
|
||||
),
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
def get_total_message_count(self):
|
||||
"""Get the total message count from platform statistics."""
|
||||
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||
)
|
||||
total_count = result.scalar_one_or_none()
|
||||
return total_count if total_count is not None else 0
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec=86400):
|
||||
# group by platform_id
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
||||
.where(PlatformStat.timestamp >= start_time)
|
||||
.group_by(PlatformStat.platform_id),
|
||||
)
|
||||
grouped_stats = result.all()
|
||||
deprecated_stats = DeprecatedStats()
|
||||
for platform_id, count in grouped_stats:
|
||||
deprecated_stats.platform.append(
|
||||
DeprecatedPlatformStat(
|
||||
name=platform_id,
|
||||
count=count,
|
||||
timestamp=int(start_time.timestamp()),
|
||||
),
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
# ====
|
||||
# Platform Session Management
|
||||
# ====
|
||||
@@ -798,99 +862,3 @@ class SQLiteDatabase(BaseDatabase):
|
||||
col(PlatformSession.session_id) == session_id,
|
||||
),
|
||||
)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
def get_base_stats(self, offset_sec=86400):
|
||||
"""Get base statistics within the specified offset in seconds."""
|
||||
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
||||
)
|
||||
all_datas = result.scalars().all()
|
||||
deprecated_stats = DeprecatedStats()
|
||||
for data in all_datas:
|
||||
deprecated_stats.platform.append(
|
||||
DeprecatedPlatformStat(
|
||||
name=data.platform_id,
|
||||
count=data.count,
|
||||
timestamp=int(data.timestamp.timestamp()),
|
||||
),
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
def get_total_message_count(self):
|
||||
"""Get the total message count from platform statistics."""
|
||||
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
||||
)
|
||||
total_count = result.scalar_one_or_none()
|
||||
return total_count if total_count is not None else 0
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec=86400):
|
||||
# group by platform_id
|
||||
async def _inner():
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(seconds=offset_sec)
|
||||
result = await session.execute(
|
||||
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
||||
.where(PlatformStat.timestamp >= start_time)
|
||||
.group_by(PlatformStat.platform_id),
|
||||
)
|
||||
grouped_stats = result.all()
|
||||
deprecated_stats = DeprecatedStats()
|
||||
for platform_id, count in grouped_stats:
|
||||
deprecated_stats.platform.append(
|
||||
DeprecatedPlatformStat(
|
||||
name=platform_id,
|
||||
count=count,
|
||||
timestamp=int(start_time.timestamp()),
|
||||
),
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
@@ -57,7 +57,7 @@ async def run_third_party_agent(
|
||||
logger.error(f"Third party agent runner error: {e}")
|
||||
err_msg = (
|
||||
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
||||
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
)
|
||||
yield MessageChain().message(err_msg)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class PlatformManager:
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.tencent_record_helper import (
|
||||
convert_to_pcm_wav,
|
||||
tencent_silk_to_wav,
|
||||
)
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import STTProvider
|
||||
@@ -35,18 +38,28 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
self.set_model(provider_config.get("model"))
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
async def _get_audio_format(self, file_path):
|
||||
# 定义要检测的头部字节
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
amr_header = b"#!AMR"
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
return False
|
||||
return "silk"
|
||||
|
||||
if amr_header in file_header:
|
||||
return "amr"
|
||||
return None
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
"""Only supports mp3, mp4, mpeg, m4a, wav, webm"""
|
||||
is_tencent = False
|
||||
output_path = None
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
@@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
file_format = await self._get_audio_format(audio_url)
|
||||
|
||||
# 判断是否需要转换
|
||||
if file_format in ["silk", "amr"]:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
|
||||
if file_format == "silk":
|
||||
logger.info(
|
||||
"Converting silk file to wav using tencent_silk_to_wav..."
|
||||
)
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
elif file_format == "amr":
|
||||
logger.info(
|
||||
"Converting amr file to wav using convert_to_pcm_wav..."
|
||||
)
|
||||
await convert_to_pcm_wav(audio_url, output_path)
|
||||
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=("audio.wav", open(audio_url, "rb")),
|
||||
)
|
||||
|
||||
# remove temp file
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(audio_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
||||
return result.text
|
||||
|
||||
@@ -36,7 +36,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
import pilk
|
||||
except (ImportError, ModuleNotFoundError) as _:
|
||||
raise Exception(
|
||||
"pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库",
|
||||
"pilk 模块未安装,请前往管理面板->平台日志->安装pip库 安装 pilk 这个库",
|
||||
)
|
||||
# with wave.open(wav_path, 'rb') as wav:
|
||||
# wav_data = wav.readframes(wav.getnframes())
|
||||
|
||||
@@ -274,7 +274,7 @@ class KnowledgeBaseRoute(Route):
|
||||
except Exception as e:
|
||||
return (
|
||||
Response()
|
||||
.error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。")
|
||||
.error(f"测试重排序模型失败: {e!s},请检查平台日志输出。")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
|
||||
@@ -84,14 +84,23 @@ class StatRoute(Route):
|
||||
offset_sec = request.args.get("offset_sec", 86400)
|
||||
offset_sec = int(offset_sec)
|
||||
try:
|
||||
platform_stats = await self.db_helper.get_platform_stats(offset_sec)
|
||||
stat = self.db_helper.get_base_stats(offset_sec)
|
||||
now = int(time.time())
|
||||
start_time = now - offset_sec
|
||||
message_time_based_stats = []
|
||||
|
||||
# 获取时间序列数据(按小时分桶)
|
||||
time_series = await self.db_helper.get_platform_stats_time_series(
|
||||
offset_sec
|
||||
)
|
||||
message_time_based_stats = [[ts, cnt] for ts, cnt in time_series]
|
||||
message_count = await self.db_helper.count_platform_stats()
|
||||
idx = 0
|
||||
for bucket_end in range(start_time, now, 3600):
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(stat.platform)
|
||||
and stat.platform[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += stat.platform[idx].count
|
||||
idx += 1
|
||||
message_time_based_stats.append([bucket_end, cnt])
|
||||
|
||||
stat_dict = stat.__dict__
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=0.5)
|
||||
thread_count = threading.active_count()
|
||||
@@ -112,36 +121,28 @@ class StatRoute(Route):
|
||||
int(time.time()) - self.core_lifecycle.start_time,
|
||||
)
|
||||
|
||||
# 构建平台统计数据
|
||||
platform_data = [
|
||||
stat_dict.update(
|
||||
{
|
||||
"name": stat.platform_id,
|
||||
"count": stat.count,
|
||||
"timestamp": int(stat.timestamp.timestamp())
|
||||
if stat.timestamp
|
||||
else 0,
|
||||
}
|
||||
for stat in platform_stats
|
||||
]
|
||||
|
||||
stat_dict = {
|
||||
"platform": platform_data,
|
||||
"message_count": message_count,
|
||||
"platform_count": len(
|
||||
self.core_lifecycle.platform_manager.get_insts(),
|
||||
),
|
||||
"plugin_count": len(plugins),
|
||||
"plugins": plugin_info,
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20,
|
||||
"platform": self.db_helper.get_grouped_base_stats(
|
||||
offset_sec,
|
||||
).platform,
|
||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
||||
"platform_count": len(
|
||||
self.core_lifecycle.platform_manager.get_insts(),
|
||||
),
|
||||
"plugin_count": len(plugins),
|
||||
"plugins": plugin_info,
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20,
|
||||
},
|
||||
"cpu_percent": round(cpu_percent, 1),
|
||||
"thread_count": thread_count,
|
||||
"start_time": self.core_lifecycle.start_time,
|
||||
},
|
||||
"cpu_percent": round(cpu_percent, 1),
|
||||
"thread_count": thread_count,
|
||||
"start_time": self.core_lifecycle.start_time,
|
||||
}
|
||||
)
|
||||
|
||||
return Response().ok(stat_dict).__dict__
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
## What's Changed
|
||||
|
||||
**新增:**
|
||||
- 对部分需要 Webhook 的适配器(QQ 官方机器人、Slack、企业微信、微信客服、企业微信智能机器人、微信公众号)支持统一的 Webhook 链接模式,避免开多个端口。并支持在 WebUI 机器人卡片中查看和复制 Webhook 链接。详情请看:[统一 Webhook 模式](https://docs.astrbot.app/use/unified-webhook.html)
|
||||
- 新增 Kubernetes 部署文档。
|
||||
|
||||
**修复:**
|
||||
- 修复:Telegram 和 QQ 场景下,使用 Whisper API 报错。
|
||||
- 修复:部分情况下 Slack 输出消息段代码的问题。
|
||||
- 修复:当启动了流式输出时,QQ 官方机器人适配器无法正常回复消息。
|
||||
- 修复:对话数据页的对话详情在暗夜模式下显示异常的问题。
|
||||
|
||||
**优化:**
|
||||
- 重构:WebChat 的消息数据结构,支持引用回复、文件发送、时间显示等功能,优化思考内容显示的部分 Bug。
|
||||
- 优化:机器人页面支持显示报错信息,方便排查问题。
|
||||
+2
-3
@@ -9,10 +9,9 @@ services:
|
||||
restart: always
|
||||
ports: # mappings description: https://github.com/AstrBotDevs/AstrBot/issues/497
|
||||
- "6185:6185" # 必选,AstrBot WebUI 端口
|
||||
- "6195:6195" # 可选, 企业微信 Webhook 端口
|
||||
- "6199:6199" # 可选, QQ 个人号 WebSocket 端口
|
||||
- "6196:6196" # 可选, QQ 官方接口 Webhook 端口
|
||||
- "11451:11451" # 可选, 微信个人号 Webhook 端口
|
||||
# - "6195:6195" # 可选, 企业微信 Webhook 端口
|
||||
# - "6196:6196" # 可选, QQ 官方接口 Webhook 端口
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
volumes:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
<script setup>
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { storeToRefs } from 'pinia';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -23,6 +24,8 @@ import { useCommonStore } from '@/stores/common';
|
||||
export default {
|
||||
name: 'ConsoleDisplayer',
|
||||
data() {
|
||||
const commonStore = useCommonStore();
|
||||
const { log_cache } = storeToRefs(commonStore);
|
||||
return {
|
||||
autoScroll: true, // 默认开启自动滚动
|
||||
logColorAnsiMap: {
|
||||
@@ -35,7 +38,7 @@ export default {
|
||||
'\u001b[32m': 'color: #00FF00;', // green
|
||||
'default': 'color: #FFFFFF;'
|
||||
},
|
||||
logCache: useCommonStore().getLogCache(),
|
||||
logCache: log_cache,
|
||||
historyNum_: -1,
|
||||
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"chat": "聊天",
|
||||
"conversation": "对话数据",
|
||||
"sessionManagement": "自定义规则",
|
||||
"console": "控制台",
|
||||
"console": "平台日志",
|
||||
"alkaid": "Alkaid",
|
||||
"knowledgeBase": "知识库",
|
||||
"about": "关于",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"title": "控制台",
|
||||
"title": "平台日志",
|
||||
"autoScroll": {
|
||||
"enabled": "自动滚动已开启",
|
||||
"disabled": "自动滚动已关闭"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"title": "控制台",
|
||||
"title": "平台日志",
|
||||
"subtitle": "实时监控和统计数据",
|
||||
"lastUpdate": "最后更新",
|
||||
"status": {
|
||||
|
||||
@@ -94,7 +94,7 @@
|
||||
"dialogs": {
|
||||
"error": {
|
||||
"title": "错误信息",
|
||||
"checkConsole": "详情请检查控制台"
|
||||
"checkConsole": "详情请检查平台日志"
|
||||
},
|
||||
"config": {
|
||||
"title": "插件配置",
|
||||
|
||||
@@ -37,13 +37,21 @@
|
||||
:config_data="config_data"
|
||||
/>
|
||||
|
||||
<v-btn icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
|
||||
color="darkprimary" @click="updateConfig">
|
||||
</v-btn>
|
||||
<v-tooltip :text="tm('actions.save')" location="left">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
|
||||
color="darkprimary" @click="updateConfig">
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-btn icon="mdi-code-json" size="x-large" style="position: fixed; right: 52px; bottom: 124px;" color="primary"
|
||||
@click="configToString(); codeEditorDialog = true">
|
||||
</v-btn>
|
||||
<v-tooltip :text="tm('codeEditor.title')" location="left">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon="mdi-code-json" size="x-large" style="position: fixed; right: 52px; bottom: 124px;" color="primary"
|
||||
@click="configToString(); codeEditorDialog = true">
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-tooltip text="测试当前配置" location="left" v-if="!isSystemConfig">
|
||||
<template v-slot:activator="{ props }">
|
||||
|
||||
@@ -13,10 +13,11 @@ const { tm } = useModuleI18n('features/console');
|
||||
<h4>{{ tm('title') }}</h4>
|
||||
<div class="d-flex align-center">
|
||||
<v-switch
|
||||
v-model="autoScrollDisabled"
|
||||
:label="autoScrollDisabled ? tm('autoScroll.disabled') : tm('autoScroll.enabled')"
|
||||
v-model="autoScrollEnabled"
|
||||
:label="autoScrollEnabled ? tm('autoScroll.enabled') : tm('autoScroll.disabled')"
|
||||
hide-details
|
||||
density="compact"
|
||||
color="primary"
|
||||
style="margin-right: 16px;"
|
||||
></v-switch>
|
||||
<v-dialog v-model="pipDialog" width="400">
|
||||
@@ -57,7 +58,7 @@ export default {
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
autoScrollDisabled: false,
|
||||
autoScrollEnabled: true,
|
||||
pipDialog: false,
|
||||
pipInstallPayload: {
|
||||
package: '',
|
||||
@@ -68,9 +69,9 @@ export default {
|
||||
}
|
||||
},
|
||||
watch: {
|
||||
autoScrollDisabled(val) {
|
||||
autoScrollEnabled(val) {
|
||||
if (this.$refs.consoleDisplayer) {
|
||||
this.$refs.consoleDisplayer.autoScroll = !val;
|
||||
this.$refs.consoleDisplayer.autoScroll = val;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: astrbot-standalone-ns
|
||||
@@ -0,0 +1,14 @@
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: astrbot-data-pvc
|
||||
namespace: astrbot-standalone-ns
|
||||
labels:
|
||||
app: astrbot-standalone
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 10Gi
|
||||
# storageClassName: standard # uncomment and set proper StorageClass
|
||||
@@ -0,0 +1,49 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: astrbot-standalone
|
||||
namespace: astrbot-standalone-ns
|
||||
labels:
|
||||
app: astrbot-standalone
|
||||
spec:
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate
|
||||
selector:
|
||||
matchLabels:
|
||||
app: astrbot-standalone
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: astrbot-standalone
|
||||
spec:
|
||||
containers:
|
||||
- name: astrbot
|
||||
image: soulter/astrbot:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: TZ
|
||||
value: "Asia/Shanghai"
|
||||
ports:
|
||||
- containerPort: 6185
|
||||
name: webui
|
||||
- containerPort: 6199
|
||||
name: qq-ws
|
||||
# - containerPort: 6195
|
||||
# name: wecom-wh
|
||||
# - containerPort: 6196
|
||||
# name: qq-off-wh
|
||||
volumeMounts:
|
||||
- name: data
|
||||
mountPath: /AstrBot/data
|
||||
- name: localtime
|
||||
mountPath: /etc/localtime
|
||||
readOnly: true
|
||||
volumes:
|
||||
- name: data
|
||||
persistentVolumeClaim:
|
||||
claimName: astrbot-data-pvc
|
||||
- name: localtime
|
||||
hostPath:
|
||||
path: /etc/localtime
|
||||
type: File
|
||||
@@ -0,0 +1,28 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: astrbot-standalone-nodeport
|
||||
namespace: astrbot-standalone-ns
|
||||
labels:
|
||||
app: astrbot-standalone
|
||||
spec:
|
||||
type: NodePort
|
||||
selector:
|
||||
app: astrbot-standalone
|
||||
ports:
|
||||
- name: webui
|
||||
port: 6185
|
||||
targetPort: 6185
|
||||
nodePort: 30185
|
||||
- name: qq-ws
|
||||
port: 6199
|
||||
targetPort: 6199
|
||||
nodePort: 30199
|
||||
# - name: wecom-wh
|
||||
# port: 6195
|
||||
# targetPort: 6195
|
||||
# nodePort: 30195
|
||||
# - name: qq-off-wh
|
||||
# port: 6196
|
||||
# targetPort: 6196
|
||||
# nodePort: 30196
|
||||
@@ -0,0 +1,24 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: astrbot-standalone-lb
|
||||
namespace: astrbot-standalone-ns
|
||||
labels:
|
||||
app: astrbot-standalone
|
||||
spec:
|
||||
type: LoadBalancer
|
||||
selector:
|
||||
app: astrbot-standalone
|
||||
ports:
|
||||
- name: webui
|
||||
port: 6185
|
||||
targetPort: 6185
|
||||
- name: qq-ws
|
||||
port: 6199
|
||||
targetPort: 6199
|
||||
# - name: wecom-wh
|
||||
# port: 6195
|
||||
# targetPort: 6195
|
||||
# - name: qq-off-wh
|
||||
# port: 6196
|
||||
# targetPort: 6196
|
||||
@@ -0,0 +1,4 @@
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: astrbot-ns
|
||||
@@ -0,0 +1,46 @@
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: astrbot-data-shared-pvc
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteMany
|
||||
resources:
|
||||
requests:
|
||||
storage: 10Gi
|
||||
# storageClassName: nfs-client # Uncomment and set your RWX storage class if needed
|
||||
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: napcat-config-pvc
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 5Gi
|
||||
# storageClassName: standard
|
||||
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: napcat-qq-pvc
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 5Gi
|
||||
# storageClassName: standard
|
||||
@@ -0,0 +1,64 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: astrbot-stack
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate # Use Recreate strategy for stateful applications
|
||||
selector:
|
||||
matchLabels:
|
||||
app: astrbot-stack
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
containers:
|
||||
- name: napcat
|
||||
image: mlikiowa/napcat-docker:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: NAPCAT_UID
|
||||
value: "1000"
|
||||
- name: NAPCAT_GID
|
||||
value: "1000"
|
||||
- name: MODE
|
||||
value: "astrbot"
|
||||
ports:
|
||||
- containerPort: 6099
|
||||
name: napcat-web
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /AstrBot/data
|
||||
- name: napcat-config
|
||||
mountPath: /app/napcat/config
|
||||
- name: napcat-qq
|
||||
mountPath: /app/.config/QQ
|
||||
|
||||
- name: astrbot
|
||||
image: soulter/astrbot:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
env:
|
||||
- name: TZ
|
||||
value: "Asia/Shanghai"
|
||||
ports:
|
||||
- containerPort: 6185
|
||||
name: astrbot-web
|
||||
volumeMounts:
|
||||
- name: shared-data
|
||||
mountPath: /AstrBot/data
|
||||
|
||||
volumes:
|
||||
- name: shared-data
|
||||
persistentVolumeClaim:
|
||||
claimName: astrbot-data-shared-pvc
|
||||
- name: napcat-config
|
||||
persistentVolumeClaim:
|
||||
claimName: napcat-config-pvc
|
||||
- name: napcat-qq
|
||||
persistentVolumeClaim:
|
||||
claimName: napcat-qq-pvc
|
||||
@@ -0,0 +1,20 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: astrbot-service-nodeport
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
type: NodePort
|
||||
selector:
|
||||
app: astrbot-stack
|
||||
ports:
|
||||
- name: napcat-web
|
||||
port: 6099
|
||||
targetPort: 6099
|
||||
# nodePort: 30099 # Optional: Specify a fixed NodePort if needed, otherwise remove this line
|
||||
- name: astrbot-web
|
||||
port: 6185
|
||||
targetPort: 6185
|
||||
# nodePort: 30185 # Optional: Specify a fixed NodePort if needed, otherwise remove this line
|
||||
@@ -0,0 +1,18 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: astrbot-service-lb
|
||||
namespace: astrbot-ns
|
||||
labels:
|
||||
app: astrbot-stack
|
||||
spec:
|
||||
type: LoadBalancer
|
||||
selector:
|
||||
app: astrbot-stack
|
||||
ports:
|
||||
- name: napcat-web
|
||||
port: 6099
|
||||
targetPort: 6099
|
||||
- name: astrbot-web
|
||||
port: 6185
|
||||
targetPort: 6185
|
||||
+1
-2
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.7.4"
|
||||
version = "4.8.0"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -60,7 +60,6 @@ dependencies = [
|
||||
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
||||
"xinference-client",
|
||||
"tenacity>=9.1.2",
|
||||
"aiomysql>=0.3.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
Reference in New Issue
Block a user