diff --git a/astrbot/core/star/command_management.py b/astrbot/core/star/command_management.py index c823695b4..a0b125d33 100644 --- a/astrbot/core/star/command_management.py +++ b/astrbot/core/star/command_management.py @@ -46,20 +46,14 @@ class CommandDescriptor: async def sync_command_configs() -> None: """同步指令配置,清理过期配置。""" - descriptors = _collect_raw_descriptors() + descriptors = _collect_descriptors(include_sub_commands=False) config_records = await db_helper.get_command_configs() - config_map = {cfg.handler_full_name: cfg for cfg in config_records} + config_map = _bind_configs_to_descriptors(descriptors, config_records) live_handlers = {desc.handler_full_name for desc in descriptors} stale_configs = [key for key in config_map if key not in live_handlers] if stale_configs: await db_helper.delete_command_configs(stale_configs) - for key in stale_configs: - config_map.pop(key, None) - - for desc in descriptors: - if cfg := config_map.get(desc.handler_full_name): - _bind_descriptor_with_config(desc, cfg) async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor: @@ -130,25 +124,14 @@ async def rename_command( async def list_commands() -> list[dict[str, Any]]: - descriptors = _collect_all_descriptors() + descriptors = _collect_descriptors(include_sub_commands=True) config_records = await db_helper.get_command_configs() - config_map = {cfg.handler_full_name: cfg for cfg in config_records} + _bind_configs_to_descriptors(descriptors, config_records) - for desc in descriptors: - if cfg := config_map.get(desc.handler_full_name): - _bind_descriptor_with_config(desc, cfg) - - # 检测冲突:按 effective_command 分组 - conflict_groups: dict[str, list[CommandDescriptor]] = defaultdict(list) - for desc in descriptors: - if desc.effective_command and desc.enabled: - conflict_groups[desc.effective_command].append(desc) - - conflict_handler_names: set[str] = set() - for key, group in conflict_groups.items(): - if len(group) > 1: - for desc in group: - conflict_handler_names.add(desc.handler_full_name) + conflict_groups = _group_conflicts(descriptors) + conflict_handler_names: set[str] = { + d.handler_full_name for group in conflict_groups.values() for d in group + } # 分类,设置冲突标志,将子指令挂载到父指令组 group_map: dict[str, CommandDescriptor] = {} @@ -180,60 +163,40 @@ async def list_commands() -> list[dict[str, Any]]: async def list_command_conflicts() -> list[dict[str, Any]]: """列出所有冲突的指令组。""" - descriptors = _collect_raw_descriptors() + descriptors = _collect_descriptors(include_sub_commands=False) config_records = await db_helper.get_command_configs() - config_map = {cfg.handler_full_name: cfg for cfg in config_records} - for desc in descriptors: - if cfg := config_map.get(desc.handler_full_name): - _bind_descriptor_with_config(desc, cfg) + _bind_configs_to_descriptors(descriptors, config_records) - conflicts = defaultdict(list) - for desc in descriptors: - if not desc.effective_command or not desc.enabled: - continue - conflicts[desc.effective_command].append(desc) - - details = [] - for key, group in conflicts.items(): - if len(group) <= 1: - continue - details.append( - { - "conflict_key": key, - "handlers": [ - { - "handler_full_name": item.handler_full_name, - "plugin": item.plugin_name, - "current_name": item.effective_command, - } - for item in group - ], - }, - ) + conflict_groups = _group_conflicts(descriptors) + details = [ + { + "conflict_key": key, + "handlers": [ + { + "handler_full_name": item.handler_full_name, + "plugin": item.plugin_name, + "current_name": item.effective_command, + } + for item in group + ], + } + for key, group in conflict_groups.items() + ] return details # Internal helpers ---------------------------------------------------------- -def _collect_raw_descriptors() -> list[CommandDescriptor]: - """收集所有根级指令(不含子指令)。""" - descriptors: list[CommandDescriptor] = [] - for handler in star_handlers_registry: - desc = _build_descriptor(handler) - if not desc or desc.is_sub_command: - continue - descriptors.append(desc) - return descriptors - - -def _collect_all_descriptors() -> list[CommandDescriptor]: - """收集所有指令,包括子指令。""" +def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]: + """收集指令,按需包含子指令。""" descriptors: list[CommandDescriptor] = [] for handler in star_handlers_registry: desc = _build_descriptor(handler) if not desc: continue + if not include_sub_commands and desc.is_sub_command: + continue descriptors.append(desc) return descriptors @@ -369,10 +332,20 @@ def _compose_command(parent_signature: str, fragment: str | None) -> str: return f"{parent_signature} {fragment}" -def _bind_descriptor_with_config(descriptor: CommandDescriptor, config: CommandConfig): +def _bind_descriptor_with_config( + descriptor: CommandDescriptor, + config: CommandConfig, +) -> None: + _apply_config_to_descriptor(descriptor, config) + _apply_config_to_runtime(descriptor, config) + + +def _apply_config_to_descriptor( + descriptor: CommandDescriptor, + config: CommandConfig, +) -> None: descriptor.config = config descriptor.enabled = config.enabled - descriptor.handler.enabled = config.enabled if config.original_command: descriptor.original_command = config.original_command @@ -384,8 +357,35 @@ def _bind_descriptor_with_config(descriptor: CommandDescriptor, config: CommandC new_fragment, ) - if descriptor.filter_ref and new_fragment: - _set_filter_fragment(descriptor.filter_ref, new_fragment) + +def _apply_config_to_runtime( + descriptor: CommandDescriptor, + config: CommandConfig, +) -> None: + descriptor.handler.enabled = config.enabled + if descriptor.filter_ref and descriptor.current_fragment: + _set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment) + + +def _bind_configs_to_descriptors( + descriptors: list[CommandDescriptor], + config_records: list[CommandConfig], +) -> dict[str, CommandConfig]: + config_map = {cfg.handler_full_name: cfg for cfg in config_records} + for desc in descriptors: + if cfg := config_map.get(desc.handler_full_name): + _bind_descriptor_with_config(desc, cfg) + return config_map + + +def _group_conflicts( + descriptors: list[CommandDescriptor], +) -> dict[str, list[CommandDescriptor]]: + conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list) + for desc in descriptors: + if desc.effective_command and desc.enabled: + conflicts[desc.effective_command].append(desc) + return {k: v for k, v in conflicts.items() if len(v) > 1} def _set_filter_fragment(