diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index af4e0fb37..9a63350fb 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -869,6 +869,18 @@ CONFIG_METADATA_2 = { "timeout": 60, "hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!", }, + "Coze": { + "id": "coze", + "provider": "coze", + "provider_type": "chat_completion", + "type": "coze", + "enable": True, + "coze_api_key": "", + "bot_id": "", + "coze_api_base": "https://api.coze.cn", + "timeout": 60, + "auto_save_history": True, + }, "阿里云百炼应用": { "id": "dashscope", "provider": "dashscope", @@ -1735,6 +1747,26 @@ CONFIG_METADATA_2 = { "hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。", "obvious": True, }, + "coze_api_key": { + "description": "Coze API Key", + "type": "string", + "hint": "Coze API 密钥,用于访问 Coze 服务。", + }, + "bot_id": { + "description": "Bot ID", + "type": "string", + "hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。", + }, + "coze_api_base": { + "description": "API Base URL", + "type": "string", + "hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn", + }, + "auto_save_history": { + "description": "由 Coze 管理对话记录", + "type": "bool", + "hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。", + }, }, }, "provider_settings": { diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 562fd178d..4219ae518 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -536,7 +536,23 @@ class LLMRequestSubStage(Stage): latest_pair = messages[-2:] if not latest_pair: return - cleaned_text = "User: " + latest_pair[0].get("content", "").strip() + content = latest_pair[0].get("content", "") + if isinstance(content, list): + # 多模态 + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif item.get("type") == "image": + text_parts.append("[图片]") + elif isinstance(item, str): + text_parts.append(item) + cleaned_text = "User: " + " ".join(text_parts).strip() + elif isinstance(content, str): + cleaned_text = "User: " + content.strip() + else: + return logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}") llm_resp = await prov.text_chat( system_prompt="You are expert in summarizing user's query.", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index a23788a76..6666b33e6 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -234,6 +234,8 @@ class ProviderManager: ) case "dify": from .sources.dify_source import ProviderDify as ProviderDify + case "coze": + from .sources.coze_source import ProviderCoze as ProviderCoze case "dashscope": from .sources.dashscope_source import ( ProviderDashscope as ProviderDashscope, diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/provider/sources/coze_api_client.py new file mode 100644 index 000000000..a768979c6 --- /dev/null +++ b/astrbot/core/provider/sources/coze_api_client.py @@ -0,0 +1,314 @@ +import json +import asyncio +import aiohttp +import io +from typing import Dict, List, Any, AsyncGenerator +from astrbot.core import logger + + +class CozeAPIClient: + def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"): + self.api_key = api_key + self.api_base = api_base + self.session = None + + async def _ensure_session(self): + """确保HTTP session存在""" + if self.session is None: + connector = aiohttp.TCPConnector( + ssl=False if self.api_base.startswith("http://") else True, + limit=100, + limit_per_host=30, + keepalive_timeout=30, + enable_cleanup_closed=True, + ) + timeout = aiohttp.ClientTimeout( + total=120, # 默认超时时间 + connect=30, + sock_read=120, + ) + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "text/event-stream", + } + self.session = aiohttp.ClientSession( + headers=headers, timeout=timeout, connector=connector + ) + return self.session + + async def upload_file( + self, + file_data: bytes, + ) -> str: + """上传文件到 Coze 并返回 file_id + + Args: + file_data (bytes): 文件的二进制数据 + Returns: + str: 上传成功后返回的 file_id + """ + session = await self._ensure_session() + url = f"{self.api_base}/v1/files/upload" + + try: + file_io = io.BytesIO(file_data) + async with session.post( + url, + data={ + "file": file_io, + }, + timeout=aiohttp.ClientTimeout(total=60), + ) as response: + if response.status == 401: + raise Exception("Coze API 认证失败,请检查 API Key 是否正确") + + response_text = await response.text() + logger.debug( + f"文件上传响应状态: {response.status}, 内容: {response_text}" + ) + + if response.status != 200: + raise Exception( + f"文件上传失败,状态码: {response.status}, 响应: {response_text}" + ) + + try: + result = await response.json() + except json.JSONDecodeError: + raise Exception(f"文件上传响应解析失败: {response_text}") + + if result.get("code") != 0: + raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}") + + file_id = result["data"]["id"] + logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}") + return file_id + + except asyncio.TimeoutError: + logger.error("文件上传超时") + raise Exception("文件上传超时") + except Exception as e: + logger.error(f"文件上传失败: {str(e)}") + raise Exception(f"文件上传失败: {str(e)}") + + async def download_image(self, image_url: str) -> bytes: + """下载图片并返回字节数据 + + Args: + image_url (str): 图片的URL + Returns: + bytes: 图片的二进制数据 + """ + session = await self._ensure_session() + + try: + async with session.get(image_url) as response: + if response.status != 200: + raise Exception(f"下载图片失败,状态码: {response.status}") + + image_data = await response.read() + return image_data + + except Exception as e: + logger.error(f"下载图片失败 {image_url}: {str(e)}") + raise Exception(f"下载图片失败: {str(e)}") + + async def chat_messages( + self, + bot_id: str, + user_id: str, + additional_messages: List[Dict] | None = None, + conversation_id: str | None = None, + auto_save_history: bool = True, + stream: bool = True, + timeout: float = 120, + ) -> AsyncGenerator[Dict[str, Any], None]: + """发送聊天消息并返回流式响应 + + Args: + bot_id: Bot ID + user_id: 用户ID + additional_messages: 额外消息列表 + conversation_id: 会话ID + auto_save_history: 是否自动保存历史 + stream: 是否流式响应 + timeout: 超时时间 + """ + session = await self._ensure_session() + url = f"{self.api_base}/v3/chat" + + payload = { + "bot_id": bot_id, + "user_id": user_id, + "stream": stream, + "auto_save_history": auto_save_history, + } + + if additional_messages: + payload["additional_messages"] = additional_messages + + params = {} + if conversation_id: + params["conversation_id"] = conversation_id + + logger.debug(f"Coze chat_messages payload: {payload}, params: {params}") + + try: + async with session.post( + url, + json=payload, + params=params, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 401: + raise Exception("Coze API 认证失败,请检查 API Key 是否正确") + + if response.status != 200: + raise Exception(f"Coze API 流式请求失败,状态码: {response.status}") + + # SSE + buffer = "" + event_type = None + event_data = None + + async for chunk in response.content: + if chunk: + buffer += chunk.decode("utf-8", errors="ignore") + lines = buffer.split("\n") + buffer = lines[-1] + + for line in lines[:-1]: + line = line.strip() + + if not line: + if event_type and event_data: + yield {"event": event_type, "data": event_data} + event_type = None + event_data = None + elif line.startswith("event:"): + event_type = line[6:].strip() + elif line.startswith("data:"): + data_str = line[5:].strip() + if data_str and data_str != "[DONE]": + try: + event_data = json.loads(data_str) + except json.JSONDecodeError: + event_data = {"content": data_str} + + except asyncio.TimeoutError: + raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") + except Exception as e: + raise Exception(f"Coze API 流式请求失败: {str(e)}") + + async def clear_context(self, conversation_id: str): + """清空会话上下文 + + Args: + conversation_id: 会话ID + Returns: + dict: API响应结果 + """ + session = await self._ensure_session() + url = f"{self.api_base}/v3/conversation/message/clear_context" + payload = {"conversation_id": conversation_id} + + try: + async with session.post(url, json=payload) as response: + response_text = await response.text() + + if response.status == 401: + raise Exception("Coze API 认证失败,请检查 API Key 是否正确") + + if response.status != 200: + raise Exception(f"Coze API 请求失败,状态码: {response.status}") + + try: + return json.loads(response_text) + except json.JSONDecodeError: + raise Exception("Coze API 返回非JSON格式") + + except asyncio.TimeoutError: + raise Exception("Coze API 请求超时") + except aiohttp.ClientError as e: + raise Exception(f"Coze API 请求失败: {str(e)}") + + async def get_message_list( + self, + conversation_id: str, + order: str = "desc", + limit: int = 10, + offset: int = 0, + ): + """获取消息列表 + + Args: + conversation_id: 会话ID + order: 排序方式 (asc/desc) + limit: 限制数量 + offset: 偏移量 + Returns: + dict: API响应结果 + """ + session = await self._ensure_session() + url = f"{self.api_base}/v3/conversation/message/list" + params = { + "conversation_id": conversation_id, + "order": order, + "limit": limit, + "offset": offset, + } + + try: + async with session.get(url, params=params) as response: + response.raise_for_status() + return await response.json() + + except Exception as e: + logger.error(f"获取Coze消息列表失败: {str(e)}") + raise Exception(f"获取Coze消息列表失败: {str(e)}") + + async def close(self): + """关闭会话""" + if self.session: + await self.session.close() + self.session = None + + +if __name__ == "__main__": + import os + import asyncio + + async def test_coze_api_client(): + api_key = os.getenv("COZE_API_KEY", "") + bot_id = os.getenv("COZE_BOT_ID", "") + client = CozeAPIClient(api_key=api_key) + + try: + with open("README.md", "rb") as f: + file_data = f.read() + file_id = await client.upload_file(file_data) + print(f"Uploaded file_id: {file_id}") + async for event in client.chat_messages( + bot_id=bot_id, + user_id="test_user", + additional_messages=[ + { + "role": "user", + "content": json.dumps( + [ + {"type": "text", "text": "这是什么"}, + {"type": "file", "file_id": file_id}, + ], + ensure_ascii=False, + ), + "content_type": "object_string", + }, + ], + stream=True, + ): + print(f"Event: {event}") + + finally: + await client.close() + + asyncio.run(test_coze_api_client()) diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py new file mode 100644 index 000000000..639af0814 --- /dev/null +++ b/astrbot/core/provider/sources/coze_source.py @@ -0,0 +1,635 @@ +import json +import os +import base64 +import hashlib +from typing import AsyncGenerator, Dict +from astrbot.core.message.message_event_result import MessageChain +import astrbot.core.message.components as Comp +from astrbot.api.provider import Provider +from astrbot import logger +from astrbot.core.provider.entities import LLMResponse +from ..register import register_provider_adapter +from .coze_api_client import CozeAPIClient + + +@register_provider_adapter("coze", "Coze (扣子) 智能体适配器") +class ProviderCoze(Provider): + def __init__( + self, + provider_config, + provider_settings, + default_persona=None, + ) -> None: + super().__init__( + provider_config, + provider_settings, + default_persona, + ) + self.api_key = provider_config.get("coze_api_key", "") + if not self.api_key: + raise Exception("Coze API Key 不能为空。") + self.bot_id = provider_config.get("bot_id", "") + if not self.bot_id: + raise Exception("Coze Bot ID 不能为空。") + self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") + + if not isinstance(self.api_base, str) or not self.api_base.startswith( + ("http://", "https://") + ): + raise Exception( + "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。" + ) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + self.auto_save_history = provider_config.get("auto_save_history", True) + self.conversation_ids: Dict[str, str] = {} + self.file_id_cache: Dict[str, Dict[str, str]] = {} + + # 创建 API 客户端 + self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) + + def _generate_cache_key(self, data: str, is_base64: bool = False) -> str: + """生成统一的缓存键 + + Args: + data: 图片数据或路径 + is_base64: 是否是 base64 数据 + + Returns: + str: 缓存键 + """ + + try: + if is_base64 and data.startswith("data:image/"): + try: + header, encoded = data.split(",", 1) + image_bytes = base64.b64decode(encoded) + cache_key = hashlib.md5(image_bytes).hexdigest() + return cache_key + except Exception: + cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest() + return cache_key + else: + if data.startswith(("http://", "https://")): + # URL图片,使用URL作为缓存键 + cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() + return cache_key + else: + clean_path = ( + data.split("_")[0] + if "_" in data and len(data.split("_")) >= 3 + else data + ) + + if os.path.exists(clean_path): + with open(clean_path, "rb") as f: + file_content = f.read() + cache_key = hashlib.md5(file_content).hexdigest() + return cache_key + else: + cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest() + return cache_key + + except Exception as e: + cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() + logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}") + return cache_key + + async def _upload_file( + self, + file_data: bytes, + session_id: str | None = None, + cache_key: str | None = None, + ) -> str: + """上传文件到 Coze 并返回 file_id""" + # 使用 API 客户端上传文件 + file_id = await self.api_client.upload_file(file_data) + + # 缓存 file_id + if session_id and cache_key: + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + self.file_id_cache[session_id][cache_key] = file_id + logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") + + return file_id + + async def _download_and_upload_image( + self, image_url: str, session_id: str | None = None + ) -> str: + """下载图片并上传到 Coze,返回 file_id""" + # 计算哈希实现缓存 + cache_key = self._generate_cache_key(image_url) if session_id else None + + if session_id and cache_key: + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + + if cache_key in self.file_id_cache[session_id]: + file_id = self.file_id_cache[session_id][cache_key] + return file_id + + try: + image_data = await self.api_client.download_image(image_url) + + file_id = await self._upload_file(image_data, session_id, cache_key) + + if session_id and cache_key: + self.file_id_cache[session_id][cache_key] = file_id + + return file_id + + except Exception as e: + logger.error(f"处理图片失败 {image_url}: {str(e)}") + raise Exception(f"处理图片失败: {str(e)}") + + async def _process_context_images( + self, content: str | list, session_id: str + ) -> str: + """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id""" + + try: + if isinstance(content, str): + return content + + processed_content = [] + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + + for item in content: + if not isinstance(item, dict): + processed_content.append(item) + continue + if item.get("type") == "text": + processed_content.append(item) + elif item.get("type") == "image_url": + # 处理图片逻辑 + if "file_id" in item: + # 已经有 file_id + logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}") + processed_content.append(item) + else: + # 获取图片数据 + image_data = "" + if "image_url" in item and isinstance(item["image_url"], dict): + image_data = item["image_url"].get("url", "") + elif "data" in item: + image_data = item.get("data", "") + elif "url" in item: + image_data = item.get("url", "") + + if not image_data: + continue + # 计算哈希用于缓存 + cache_key = self._generate_cache_key( + image_data, is_base64=image_data.startswith("data:image/") + ) + + # 检查缓存 + if cache_key in self.file_id_cache[session_id]: + file_id = self.file_id_cache[session_id][cache_key] + processed_content.append( + {"type": "image", "file_id": file_id} + ) + else: + # 上传图片并缓存 + if image_data.startswith("data:image/"): + # base64 处理 + _, encoded = image_data.split(",", 1) + image_bytes = base64.b64decode(encoded) + file_id = await self._upload_file( + image_bytes, + session_id, + cache_key, + ) + elif image_data.startswith(("http://", "https://")): + # URL 图片 + file_id = await self._download_and_upload_image( + image_data, session_id + ) + # 为URL图片也添加缓存 + self.file_id_cache[session_id][cache_key] = file_id + elif os.path.exists(image_data): + # 本地文件 + with open(image_data, "rb") as f: + image_bytes = f.read() + file_id = await self._upload_file( + image_bytes, + session_id, + cache_key, + ) + else: + logger.warning( + f"无法处理的图片格式: {image_data[:50]}..." + ) + continue + + processed_content.append( + {"type": "image", "file_id": file_id} + ) + + result = json.dumps(processed_content, ensure_ascii=False) + return result + except Exception as e: + logger.error(f"处理上下文图片失败: {str(e)}") + if isinstance(content, str): + return content + else: + return json.dumps(content, ensure_ascii=False) + + async def text_chat( + self, + prompt: str, + session_id=None, + image_urls=None, + func_tool=None, + contexts=None, + system_prompt=None, + tool_calls_result=None, + model=None, + **kwargs, + ) -> LLMResponse: + """文本对话, 内部使用流式接口实现非流式 + + Args: + prompt (str): 用户提示词 + session_id (str): 会话ID + image_urls (List[str]): 图片URL列表 + func_tool (FuncCall): 函数调用工具(不支持) + contexts (List): 上下文列表 + system_prompt (str): 系统提示语 + tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持) + model (str): 模型名称(不支持) + Returns: + LLMResponse: LLM响应对象 + """ + accumulated_content = "" + final_response = None + + async for llm_response in self.text_chat_stream( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + model=model, + **kwargs, + ): + if llm_response.is_chunk: + if llm_response.completion_text: + accumulated_content += llm_response.completion_text + else: + final_response = llm_response + + if final_response: + return final_response + + if accumulated_content: + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) + return LLMResponse(role="assistant", result_chain=chain) + else: + return LLMResponse(role="assistant", completion_text="") + + async def text_chat_stream( + self, + prompt: str, + session_id=None, + image_urls=None, + func_tool=None, + contexts=None, + system_prompt=None, + tool_calls_result=None, + model=None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """流式对话接口""" + # 用户ID参数(参考文档, 可以自定义) + user_id = session_id or kwargs.get("user", "default_user") + + # 获取或创建会话ID + conversation_id = self.conversation_ids.get(user_id) + + # 构建消息 + additional_messages = [] + + if system_prompt: + if not self.auto_save_history or not conversation_id: + additional_messages.append( + {"role": "system", "content": system_prompt, "content_type": "text"} + ) + + if not self.auto_save_history and contexts: + # 如果关闭了自动保存历史,传入上下文 + for ctx in contexts: + if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: + content = ctx["content"] + content_type = ctx.get("content_type", "text") + + # 处理可能包含图片的上下文 + if ( + content_type == "object_string" + or (isinstance(content, str) and content.startswith("[")) + or ( + isinstance(content, list) + and any( + isinstance(item, dict) + and item.get("type") == "image_url" + for item in content + ) + ) + ): + processed_content = await self._process_context_images( + content, user_id + ) + additional_messages.append( + { + "role": ctx["role"], + "content": processed_content, + "content_type": "object_string", + } + ) + else: + # 纯文本 + additional_messages.append( + { + "role": ctx["role"], + "content": ( + content + if isinstance(content, str) + else json.dumps(content, ensure_ascii=False) + ), + "content_type": "text", + } + ) + else: + logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}") + + if prompt or image_urls: + if image_urls: + # 多模态 + object_string_content = [] + if prompt: + object_string_content.append({"type": "text", "text": prompt}) + + for url in image_urls: + try: + if url.startswith(("http://", "https://")): + # 网络图片 + file_id = await self._download_and_upload_image( + url, user_id + ) + else: + # 本地文件或 base64 + if url.startswith("data:image/"): + # base64 + _, encoded = url.split(",", 1) + image_data = base64.b64decode(encoded) + cache_key = self._generate_cache_key( + url, is_base64=True + ) + file_id = await self._upload_file( + image_data, user_id, cache_key + ) + else: + # 本地文件 + if os.path.exists(url): + with open(url, "rb") as f: + image_data = f.read() + # 用文件路径和修改时间来缓存 + file_stat = os.stat(url) + cache_key = self._generate_cache_key( + f"{url}_{file_stat.st_mtime}_{file_stat.st_size}", + is_base64=False, + ) + file_id = await self._upload_file( + image_data, user_id, cache_key + ) + else: + logger.warning(f"图片文件不存在: {url}") + continue + + object_string_content.append( + { + "type": "image", + "file_id": file_id, + } + ) + except Exception as e: + logger.error(f"处理图片失败 {url}: {str(e)}") + continue + + if object_string_content: + content = json.dumps(object_string_content, ensure_ascii=False) + additional_messages.append( + { + "role": "user", + "content": content, + "content_type": "object_string", + } + ) + else: + # 纯文本 + if prompt: + additional_messages.append( + { + "role": "user", + "content": prompt, + "content_type": "text", + } + ) + + try: + accumulated_content = "" + message_started = False + + async for chunk in self.api_client.chat_messages( + bot_id=self.bot_id, + user_id=user_id, + additional_messages=additional_messages, + conversation_id=conversation_id, + auto_save_history=self.auto_save_history, + stream=True, + timeout=self.timeout, + ): + event_type = chunk.get("event") + data = chunk.get("data", {}) + + if event_type == "conversation.chat.created": + if isinstance(data, dict) and "conversation_id" in data: + self.conversation_ids[user_id] = data["conversation_id"] + + elif event_type == "conversation.message.delta": + if isinstance(data, dict): + content = data.get("content", "") + if not content and "delta" in data: + content = data["delta"].get("content", "") + if not content and "text" in data: + content = data.get("text", "") + + if content: + message_started = True + accumulated_content += content + yield LLMResponse( + role="assistant", + completion_text=content, + is_chunk=True, + ) + + elif event_type == "conversation.message.completed": + if isinstance(data, dict): + msg_type = data.get("type") + if msg_type == "answer" and data.get("role") == "assistant": + final_content = data.get("content", "") + if not accumulated_content and final_content: + chain = MessageChain(chain=[Comp.Plain(final_content)]) + yield LLMResponse( + role="assistant", + result_chain=chain, + is_chunk=False, + ) + + elif event_type == "conversation.chat.completed": + if accumulated_content: + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) + yield LLMResponse( + role="assistant", + result_chain=chain, + is_chunk=False, + ) + break + + elif event_type == "done": + break + + elif event_type == "error": + error_msg = ( + data.get("message", "未知错误") + if isinstance(data, dict) + else str(data) + ) + logger.error(f"Coze 流式响应错误: {error_msg}") + yield LLMResponse( + role="err", + completion_text=f"Coze 错误: {error_msg}", + is_chunk=False, + ) + break + + if not message_started and not accumulated_content: + yield LLMResponse( + role="assistant", + completion_text="LLM 未响应任何内容。", + is_chunk=False, + ) + elif message_started and accumulated_content: + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) + yield LLMResponse( + role="assistant", + result_chain=chain, + is_chunk=False, + ) + + except Exception as e: + logger.error(f"Coze 流式请求失败: {str(e)}") + yield LLMResponse( + role="err", + completion_text=f"Coze 流式请求失败: {str(e)}", + is_chunk=False, + ) + + async def forget(self, session_id: str): + """清空指定会话的上下文""" + user_id = session_id + conversation_id = self.conversation_ids.get(user_id) + + if user_id in self.file_id_cache: + self.file_id_cache.pop(user_id, None) + + if not conversation_id: + return True + + try: + response = await self.api_client.clear_context(conversation_id) + + if "code" in response and response["code"] == 0: + self.conversation_ids.pop(user_id, None) + return True + else: + logger.warning(f"清空 Coze 会话上下文失败: {response}") + return False + + except Exception as e: + logger.error(f"清空 Coze 会话失败: {str(e)}") + return False + + async def get_current_key(self): + """获取当前API Key""" + return self.api_key + + async def set_key(self, key: str): + """设置新的API Key""" + raise NotImplementedError("Coze 适配器不支持设置 API Key。") + + async def get_models(self): + """获取可用模型列表""" + return [f"bot_{self.bot_id}"] + + def get_model(self): + """获取当前模型""" + return f"bot_{self.bot_id}" + + def set_model(self, model: str): + """设置模型(在Coze中是Bot ID)""" + if model.startswith("bot_"): + self.bot_id = model[4:] + else: + self.bot_id = model + + async def get_human_readable_context( + self, session_id: str, page: int = 1, page_size: int = 10 + ): + """获取人类可读的上下文历史""" + user_id = session_id + conversation_id = self.conversation_ids.get(user_id) + + if not conversation_id: + return [] + + try: + data = await self.api_client.get_message_list( + conversation_id=conversation_id, + order="desc", + limit=page_size, + offset=(page - 1) * page_size, + ) + + if data.get("code") != 0: + logger.warning(f"获取 Coze 消息历史失败: {data}") + return [] + + messages = data.get("data", {}).get("messages", []) + + readable_history = [] + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + msg_type = msg.get("type", "") + + if role == "user": + readable_history.append(f"用户: {content}") + elif role == "assistant" and msg_type == "answer": + readable_history.append(f"助手: {content}") + + return readable_history + + except Exception as e: + logger.error(f"获取 Coze 消息历史失败: {str(e)}") + return [] + + async def terminate(self): + """清理资源""" + await self.api_client.close() diff --git a/dashboard/src/utils/providerUtils.js b/dashboard/src/utils/providerUtils.js index 1dbc66874..4fdc7a929 100644 --- a/dashboard/src/utils/providerUtils.js +++ b/dashboard/src/utils/providerUtils.js @@ -22,6 +22,7 @@ export function getProviderIcon(type) { 'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg', 'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg', 'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg', + "coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg", 'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg', 'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg', 'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg', diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 5da5bd47b..4aa012cf1 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -303,6 +303,7 @@ export default { "googlegenai_chat_completion": "chat_completion", "zhipu_chat_completion": "chat_completion", "dify": "chat_completion", + "coze": "chat_completion", "dashscope": "chat_completion", "openai_whisper_api": "speech_to_text", "openai_whisper_selfhost": "speech_to_text", diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 943363e07..5d33512b4 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -527,12 +527,11 @@ UID: {user_id} 此 ID 可用于设置管理员。 return provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) + if provider and provider.meta().type in ["dify", "coze"]: await provider.forget(message.unified_msg_origin) message.set_result( MessageEventResult().message( - "已重置当前 Dify 会话,新聊天将更换到新的会话。" + "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。" ) ) return @@ -755,8 +754,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 创建新对话 """ provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) + if provider and provider.meta().type in ["dify", "coze"]: await provider.forget(message.unified_msg_origin) message.set_result( MessageEventResult().message("成功,下次聊天将是新对话。") @@ -783,8 +781,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 async def groupnew_conv(self, message: AstrMessageEvent, sid: str): """创建新群聊对话""" provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) + if provider and provider.meta().type in ["dify", "coze"]: await provider.forget(message.unified_msg_origin) message.set_result( MessageEventResult().message("成功,下次聊天将是新对话。") @@ -823,7 +820,6 @@ UID: {user_id} 此 ID 可用于设置管理员。 provider = self.context.get_using_provider(message.unified_msg_origin) if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) data = await provider.api_client.get_chat_convs(message.unified_msg_origin) if not data["data"]: message.set_result(MessageEventResult().message("未找到任何对话。")) diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 4e817564c..b799e5a17 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -375,5 +375,3 @@ class Main(star.Star): tool_set.add_tool(tavily_extract_web_page) tool_set.remove_tool("web_search") tool_set.remove_tool("fetch_url") - - print(req.func_tool)