refactor: simplify runtime tls bootstrap and tighten confirm typing
This commit is contained in:
@@ -3,18 +3,17 @@ import ssl
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
from http_ssl_common import build_ssl_context_with_certifi as _build_ssl_context
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
_CERTIFI_WARNING_LOGGED = False
|
||||
_SHARED_TLS_CONTEXT: ssl.SSLContext | None = None
|
||||
_SHARED_TLS_CONTEXT_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def build_ssl_context_with_certifi() -> ssl.SSLContext:
|
||||
"""Build an SSL context from system trust store and add certifi CAs."""
|
||||
global _CERTIFI_WARNING_LOGGED
|
||||
global _SHARED_TLS_CONTEXT
|
||||
|
||||
if _SHARED_TLS_CONTEXT is not None:
|
||||
@@ -24,20 +23,7 @@ def build_ssl_context_with_certifi() -> ssl.SSLContext:
|
||||
if _SHARED_TLS_CONTEXT is not None:
|
||||
return _SHARED_TLS_CONTEXT
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
try:
|
||||
ssl_context.load_verify_locations(cafile=certifi.where())
|
||||
except Exception as exc:
|
||||
if not _CERTIFI_WARNING_LOGGED:
|
||||
logger.warning(
|
||||
"Failed to load certifi CA bundle into SSL context; "
|
||||
"falling back to system trust store only: %s",
|
||||
exc,
|
||||
)
|
||||
_CERTIFI_WARNING_LOGGED = True
|
||||
|
||||
_SHARED_TLS_CONTEXT = ssl_context
|
||||
_SHARED_TLS_CONTEXT = _build_ssl_context(log_obj=logger)
|
||||
return _SHARED_TLS_CONTEXT
|
||||
|
||||
|
||||
|
||||
Vendored
+11
@@ -0,0 +1,11 @@
|
||||
import 'vue'
|
||||
|
||||
import type { ConfirmDialogHandler } from '@/utils/confirmDialog'
|
||||
|
||||
declare module 'vue' {
|
||||
interface ComponentCustomProperties {
|
||||
$confirm?: ConfirmDialogHandler
|
||||
}
|
||||
}
|
||||
|
||||
export {}
|
||||
@@ -7,23 +7,17 @@ export type ConfirmDialogOptions = {
|
||||
|
||||
export type ConfirmDialogHandler = (options: ConfirmDialogOptions) => Promise<boolean>
|
||||
|
||||
export function resolveConfirmDialog(candidate: unknown): ConfirmDialogHandler | undefined {
|
||||
if (typeof candidate === 'function') {
|
||||
return candidate as ConfirmDialogHandler
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
export type ConfirmDialogCandidate = ConfirmDialogHandler | null | undefined
|
||||
|
||||
export function useConfirmDialog(): ConfirmDialogHandler | undefined {
|
||||
return resolveConfirmDialog(inject('$confirm', undefined))
|
||||
return inject<ConfirmDialogHandler | undefined>('$confirm', undefined)
|
||||
}
|
||||
|
||||
export async function askForConfirmation(
|
||||
message: string,
|
||||
candidate?: unknown
|
||||
candidate?: ConfirmDialogCandidate
|
||||
): Promise<boolean> {
|
||||
const confirmDialog = resolveConfirmDialog(candidate)
|
||||
const confirmDialog = candidate ?? undefined
|
||||
|
||||
if (confirmDialog) {
|
||||
try {
|
||||
|
||||
@@ -424,7 +424,7 @@ export default defineComponent({
|
||||
if (
|
||||
!(await askForConfirmationDialog(
|
||||
this.tm('messages.deleteConfirm', { id: persona.persona_id }),
|
||||
(this as any).$confirm,
|
||||
this.$confirm,
|
||||
))
|
||||
) {
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
import certifi
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_ssl_context_with_certifi(log_obj: Any | None = None) -> ssl.SSLContext:
|
||||
logger = log_obj or _LOGGER
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
try:
|
||||
ssl_context.load_verify_locations(cafile=certifi.where())
|
||||
except Exception as exc:
|
||||
if logger and hasattr(logger, "warning"):
|
||||
logger.warning(
|
||||
"Failed to load certifi CA bundle into SSL context; "
|
||||
"falling back to system trust store only: %s",
|
||||
exc,
|
||||
)
|
||||
|
||||
return ssl_context
|
||||
@@ -6,10 +6,13 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import runtime_bootstrap
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
|
||||
runtime_bootstrap.initialize_runtime_bootstrap()
|
||||
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger # noqa: E402
|
||||
from astrbot.core.config.default import VERSION # noqa: E402
|
||||
from astrbot.core.initial_loader import InitialLoader # noqa: E402
|
||||
from astrbot.core.utils.astrbot_path import ( # noqa: E402
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
@@ -17,7 +20,10 @@ from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_site_packages_path,
|
||||
get_astrbot_temp_path,
|
||||
)
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.utils.io import ( # noqa: E402
|
||||
download_dashboard,
|
||||
get_dashboard_version,
|
||||
)
|
||||
|
||||
# 将父目录添加到 sys.path
|
||||
sys.path.append(Path(__file__).parent.as_posix())
|
||||
@@ -94,8 +100,6 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runtime_bootstrap.initialize_runtime_bootstrap(logger)
|
||||
|
||||
parser = argparse.ArgumentParser(description="AstrBot")
|
||||
parser.add_argument(
|
||||
"--webui-dir",
|
||||
|
||||
+19
-49
@@ -1,80 +1,50 @@
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
import aiohttp.connector as aiohttp_connector
|
||||
import certifi
|
||||
|
||||
_BOOTSTRAP_RECORDS: list[tuple[str, str]] = []
|
||||
_TLS_BOOTSTRAP_DONE = False
|
||||
from http_ssl_common import build_ssl_context_with_certifi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _record(level: str, message: str) -> None:
|
||||
_BOOTSTRAP_RECORDS.append((level, message))
|
||||
|
||||
|
||||
def flush_bootstrap_records(log_obj: Any) -> None:
|
||||
if not _BOOTSTRAP_RECORDS:
|
||||
return
|
||||
|
||||
for level, message in _BOOTSTRAP_RECORDS:
|
||||
logger_method = getattr(log_obj, level, None) or getattr(log_obj, "info", None)
|
||||
if callable(logger_method):
|
||||
logger_method(message)
|
||||
|
||||
_BOOTSTRAP_RECORDS.clear()
|
||||
|
||||
|
||||
def _try_patch_aiohttp_ssl_context(ssl_context: ssl.SSLContext) -> bool:
|
||||
def _try_patch_aiohttp_ssl_context(
|
||||
ssl_context: ssl.SSLContext,
|
||||
log_obj: Any | None = None,
|
||||
) -> bool:
|
||||
log = log_obj or logger
|
||||
attr_name = "_SSL_CONTEXT_VERIFIED"
|
||||
|
||||
if not hasattr(aiohttp_connector, attr_name):
|
||||
_record(
|
||||
"warning",
|
||||
log.warning(
|
||||
"aiohttp connector does not expose _SSL_CONTEXT_VERIFIED; skipped patch.",
|
||||
)
|
||||
return False
|
||||
|
||||
current_value = getattr(aiohttp_connector, attr_name, None)
|
||||
if current_value is not None and not isinstance(current_value, ssl.SSLContext):
|
||||
_record(
|
||||
"warning",
|
||||
log.warning(
|
||||
"aiohttp connector exposes _SSL_CONTEXT_VERIFIED with unexpected type; skipped patch.",
|
||||
)
|
||||
return False
|
||||
|
||||
setattr(aiohttp_connector, attr_name, ssl_context)
|
||||
_record(
|
||||
"info",
|
||||
"Configured aiohttp verified SSL context with system+certifi trust chain.",
|
||||
)
|
||||
log.info("Configured aiohttp verified SSL context with system+certifi trust chain.")
|
||||
return True
|
||||
|
||||
|
||||
def configure_runtime_ca_bundle() -> bool:
|
||||
global _TLS_BOOTSTRAP_DONE
|
||||
|
||||
if _TLS_BOOTSTRAP_DONE:
|
||||
return True
|
||||
def configure_runtime_ca_bundle(log_obj: Any | None = None) -> bool:
|
||||
log = log_obj or logger
|
||||
|
||||
try:
|
||||
_record("info", "Bootstrapping runtime CA bundle.")
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.load_verify_locations(cafile=certifi.where())
|
||||
_TLS_BOOTSTRAP_DONE = _try_patch_aiohttp_ssl_context(ssl_context)
|
||||
return _TLS_BOOTSTRAP_DONE
|
||||
log.info("Bootstrapping runtime CA bundle.")
|
||||
ssl_context = build_ssl_context_with_certifi(log_obj=log)
|
||||
return _try_patch_aiohttp_ssl_context(ssl_context, log_obj=log)
|
||||
except Exception as exc:
|
||||
_record(
|
||||
"error",
|
||||
f"Failed to configure runtime CA bundle for aiohttp: {exc!r}",
|
||||
)
|
||||
log.error("Failed to configure runtime CA bundle for aiohttp: %r", exc)
|
||||
return False
|
||||
|
||||
|
||||
def initialize_runtime_bootstrap(log_obj: Any | None = None) -> bool:
|
||||
configured = configure_runtime_ca_bundle()
|
||||
if log_obj is not None:
|
||||
flush_bootstrap_records(log_obj)
|
||||
return configured
|
||||
|
||||
|
||||
configure_runtime_ca_bundle()
|
||||
return configure_runtime_ca_bundle(log_obj=log_obj)
|
||||
|
||||
Reference in New Issue
Block a user