diff --git a/api/server.go b/api/server.go index bc6f0986..c758da3d 100644 --- a/api/server.go +++ b/api/server.go @@ -792,12 +792,15 @@ func (s *Server) handleStartTrader(c *gin.Context) { traderID := c.Param("id") // 校验交易员是否属于当前用户 - _, _, _, err := s.database.GetTraderConfig(userID, traderID) + traderRecord, _, _, err := s.database.GetTraderConfig(userID, traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"}) return } + // 获取模板名称 + templateName := traderRecord.SystemPromptTemplate + trader, err := s.traderManager.GetTrader(traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在"}) @@ -811,6 +814,9 @@ func (s *Server) handleStartTrader(c *gin.Context) { return } + // 重新加载系统提示词模板(确保使用最新的硬盘文件) + s.reloadPromptTemplatesWithLog(templateName) + // 启动交易员 go func() { log.Printf("▶️ 启动交易员 %s (%s)", traderID, trader.GetName()) @@ -2318,3 +2324,17 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) { c.JSON(http.StatusOK, result) } + +// reloadPromptTemplatesWithLog 重新加载提示词模板并记录日志 +func (s *Server) reloadPromptTemplatesWithLog(templateName string) { + if err := decision.ReloadPromptTemplates(); err != nil { + log.Printf("⚠️ 重新加载提示词模板失败: %v", err) + return + } + + if templateName == "" { + log.Printf("✓ 已重新加载系统提示词模板 [当前使用: default (未指定,使用默认)]") + } else { + log.Printf("✓ 已重新加载系统提示词模板 [当前使用: %s]", templateName) + } +} diff --git a/decision/prompt_manager_test.go b/decision/prompt_manager_test.go new file mode 100644 index 00000000..56f905ba --- /dev/null +++ b/decision/prompt_manager_test.go @@ -0,0 +1,285 @@ +package decision + +import ( + "os" + "path/filepath" + "testing" +) + +func TestPromptManager_LoadTemplates(t *testing.T) { + // 创建临时目录用于测试 + tempDir := t.TempDir() + + tests := []struct { + name string + setupFiles map[string]string // 文件名 -> 内容 + expectedCount int + expectedNames []string + shouldError bool + }{ + { + name: "加载单个模板文件", + setupFiles: map[string]string{ + "default.txt": "你是专业的加密货币交易AI。", + }, + expectedCount: 1, + expectedNames: []string{"default"}, + shouldError: false, + }, + { + name: "加载多个模板文件", + setupFiles: map[string]string{ + "default.txt": "默认策略", + "conservative.txt": "保守策略", + "aggressive.txt": "激进策略", + }, + expectedCount: 3, + expectedNames: []string{"default", "conservative", "aggressive"}, + shouldError: false, + }, + { + name: "空目录", + setupFiles: map[string]string{}, + expectedCount: 0, + expectedNames: []string{}, + shouldError: false, + }, + { + name: "忽略非.txt文件", + setupFiles: map[string]string{ + "default.txt": "正确的模板", + "readme.md": "应该被忽略", + "config.json": "应该被忽略", + }, + expectedCount: 1, + expectedNames: []string{"default"}, + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 为每个测试用例创建独立的子目录 + testDir := filepath.Join(tempDir, tt.name) + if err := os.MkdirAll(testDir, 0755); err != nil { + t.Fatalf("创建测试目录失败: %v", err) + } + + // 设置测试文件 + for filename, content := range tt.setupFiles { + filePath := filepath.Join(testDir, filename) + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Fatalf("创建测试文件失败 %s: %v", filename, err) + } + } + + // 创建新的 PromptManager + pm := NewPromptManager() + + // 执行测试 + err := pm.LoadTemplates(testDir) + + // 检查错误 + if (err != nil) != tt.shouldError { + t.Errorf("LoadTemplates() error = %v, shouldError %v", err, tt.shouldError) + return + } + + // 检查加载的模板数量 + if len(pm.templates) != tt.expectedCount { + t.Errorf("加载的模板数量 = %d, 期望 %d", len(pm.templates), tt.expectedCount) + } + + // 检查模板名称 + for _, expectedName := range tt.expectedNames { + if _, exists := pm.templates[expectedName]; !exists { + t.Errorf("缺少预期的模板: %s", expectedName) + } + } + + // 验证模板内容 + for filename, expectedContent := range tt.setupFiles { + if filepath.Ext(filename) != ".txt" { + continue + } + templateName := filename[:len(filename)-4] // 去掉 .txt + template, err := pm.GetTemplate(templateName) + if err != nil { + t.Errorf("获取模板 %s 失败: %v", templateName, err) + continue + } + if template.Content != expectedContent { + t.Errorf("模板内容不匹配\n期望: %s\n实际: %s", expectedContent, template.Content) + } + } + }) + } +} + +func TestPromptManager_GetTemplate(t *testing.T) { + pm := NewPromptManager() + pm.templates = map[string]*PromptTemplate{ + "default": { + Name: "default", + Content: "默认策略内容", + }, + "aggressive": { + Name: "aggressive", + Content: "激进策略内容", + }, + } + + tests := []struct { + name string + templateName string + expectError bool + expectedContent string + }{ + { + name: "获取存在的模板", + templateName: "default", + expectError: false, + expectedContent: "默认策略内容", + }, + { + name: "获取不存在的模板", + templateName: "nonexistent", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template, err := pm.GetTemplate(tt.templateName) + + if (err != nil) != tt.expectError { + t.Errorf("GetTemplate() error = %v, expectError %v", err, tt.expectError) + return + } + + if !tt.expectError && template.Content != tt.expectedContent { + t.Errorf("模板内容 = %s, 期望 %s", template.Content, tt.expectedContent) + } + }) + } +} + +func TestPromptManager_ReloadTemplates(t *testing.T) { + tempDir := t.TempDir() + + // 初始文件 + if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("初始内容"), 0644); err != nil { + t.Fatalf("创建初始文件失败: %v", err) + } + + pm := NewPromptManager() + if err := pm.LoadTemplates(tempDir); err != nil { + t.Fatalf("初始加载失败: %v", err) + } + + // 验证初始内容 + template, _ := pm.GetTemplate("default") + if template.Content != "初始内容" { + t.Errorf("初始内容不正确: %s", template.Content) + } + + // 修改文件内容 + if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("更新后内容"), 0644); err != nil { + t.Fatalf("更新文件失败: %v", err) + } + + // 添加新文件 + if err := os.WriteFile(filepath.Join(tempDir, "new.txt"), []byte("新模板内容"), 0644); err != nil { + t.Fatalf("创建新文件失败: %v", err) + } + + // 重新加载 + if err := pm.ReloadTemplates(tempDir); err != nil { + t.Fatalf("重新加载失败: %v", err) + } + + // 验证更新后的内容 + template, err := pm.GetTemplate("default") + if err != nil { + t.Fatalf("获取 default 模板失败: %v", err) + } + if template.Content != "更新后内容" { + t.Errorf("重新加载后内容不正确: got %s, want '更新后内容'", template.Content) + } + + // 验证新模板 + newTemplate, err := pm.GetTemplate("new") + if err != nil { + t.Fatalf("获取 new 模板失败: %v", err) + } + if newTemplate.Content != "新模板内容" { + t.Errorf("新模板内容不正确: %s", newTemplate.Content) + } + + // 验证模板数量 + if len(pm.templates) != 2 { + t.Errorf("重新加载后模板数量 = %d, 期望 2", len(pm.templates)) + } +} + +func TestPromptManager_GetAllTemplateNames(t *testing.T) { + pm := NewPromptManager() + pm.templates = map[string]*PromptTemplate{ + "default": {Name: "default", Content: "默认策略"}, + "conservative": {Name: "conservative", Content: "保守策略"}, + "aggressive": {Name: "aggressive", Content: "激进策略"}, + } + + names := pm.GetAllTemplateNames() + + if len(names) != 3 { + t.Errorf("GetAllTemplateNames() 返回数量 = %d, 期望 3", len(names)) + } + + // 验证所有名称都存在 + nameMap := make(map[string]bool) + for _, name := range names { + nameMap[name] = true + } + + expectedNames := []string{"default", "conservative", "aggressive"} + for _, expectedName := range expectedNames { + if !nameMap[expectedName] { + t.Errorf("缺少预期的模板名称: %s", expectedName) + } + } +} + +func TestReloadPromptTemplates_GlobalFunction(t *testing.T) { + // 保存原始的 promptsDir + originalDir := promptsDir + defer func() { + promptsDir = originalDir + // 恢复原始模板 + globalPromptManager.ReloadTemplates(originalDir) + }() + + // 创建临时目录 + tempDir := t.TempDir() + promptsDir = tempDir + + // 创建测试文件 + if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil { + t.Fatalf("创建测试文件失败: %v", err) + } + + // 调用全局重新加载函数 + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("ReloadPromptTemplates() 失败: %v", err) + } + + // 验证全局管理器已更新 + template, err := GetPromptTemplate("test") + if err != nil { + t.Fatalf("获取模板失败: %v", err) + } + + if template.Content != "测试内容" { + t.Errorf("模板内容不正确: got %s, want '测试内容'", template.Content) + } +} diff --git a/decision/prompt_reload_integration_test.go b/decision/prompt_reload_integration_test.go new file mode 100644 index 00000000..909b3dbb --- /dev/null +++ b/decision/prompt_reload_integration_test.go @@ -0,0 +1,243 @@ +package decision + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// TestPromptReloadEndToEnd 端到端测试:验证从文件修改到决策引擎使用的完整流程 +func TestPromptReloadEndToEnd(t *testing.T) { + // 保存原始的 promptsDir + originalDir := promptsDir + defer func() { + promptsDir = originalDir + // 恢复原始模板 + globalPromptManager.ReloadTemplates(originalDir) + }() + + // 创建临时目录模拟 prompts/ 目录 + tempDir := t.TempDir() + promptsDir = tempDir + + // 步骤1: 创建初始 prompt 文件 + initialContent := "# 初始交易策略\n你是一个保守的交易AI。" + if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(initialContent), 0644); err != nil { + t.Fatalf("创建初始文件失败: %v", err) + } + + // 步骤2: 首次加载(模拟系统启动) + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("首次加载失败: %v", err) + } + + // 步骤3: 验证初始内容 + template, err := GetPromptTemplate("test_strategy") + if err != nil { + t.Fatalf("获取初始模板失败: %v", err) + } + if template.Content != initialContent { + t.Errorf("初始内容不匹配\n期望: %s\n实际: %s", initialContent, template.Content) + } + + // 步骤4: 使用 buildSystemPrompt 验证模板被正确使用 + systemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy") + if !strings.Contains(systemPrompt, initialContent) { + t.Errorf("buildSystemPrompt 未包含模板内容\n生成的 prompt:\n%s", systemPrompt) + } + + // 步骤5: 模拟用户修改文件(这是用户在硬盘上修改 prompt) + updatedContent := "# 更新的交易策略\n你是一个激进的交易AI,追求高风险高收益。" + if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(updatedContent), 0644); err != nil { + t.Fatalf("更新文件失败: %v", err) + } + + // 步骤6: 模拟交易员启动时调用 ReloadPromptTemplates() + t.Log("模拟交易员启动,调用 ReloadPromptTemplates()...") + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("重新加载失败: %v", err) + } + + // 步骤7: 验证新内容已生效 + reloadedTemplate, err := GetPromptTemplate("test_strategy") + if err != nil { + t.Fatalf("获取重新加载的模板失败: %v", err) + } + if reloadedTemplate.Content != updatedContent { + t.Errorf("重新加载后内容不匹配\n期望: %s\n实际: %s", updatedContent, reloadedTemplate.Content) + } + + // 步骤8: 验证 buildSystemPrompt 使用了新内容 + newSystemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy") + if !strings.Contains(newSystemPrompt, updatedContent) { + t.Errorf("buildSystemPrompt 未包含更新后的模板内容\n生成的 prompt:\n%s", newSystemPrompt) + } + + // 步骤9: 验证旧内容不再存在 + if strings.Contains(newSystemPrompt, "保守的交易AI") { + t.Errorf("buildSystemPrompt 仍包含旧的模板内容") + } + + t.Log("✅ 端到端测试通过:文件修改 -> 重新加载 -> 决策引擎使用新内容") +} + +// TestPromptReloadWithCustomPrompt 测试自定义 prompt 与模板重新加载的交互 +func TestPromptReloadWithCustomPrompt(t *testing.T) { + // 保存原始的 promptsDir + originalDir := promptsDir + defer func() { + promptsDir = originalDir + globalPromptManager.ReloadTemplates(originalDir) + }() + + // 创建临时目录 + tempDir := t.TempDir() + promptsDir = tempDir + + // 创建基础模板 + baseContent := "基础策略:稳健交易" + if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(baseContent), 0644); err != nil { + t.Fatalf("创建文件失败: %v", err) + } + + // 加载模板 + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("加载失败: %v", err) + } + + // 测试1: 基础模板 + 自定义 prompt(不覆盖) + customPrompt := "个性化规则:只交易 BTC" + result := buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base") + if !strings.Contains(result, baseContent) { + t.Errorf("未包含基础模板内容") + } + if !strings.Contains(result, customPrompt) { + t.Errorf("未包含自定义 prompt") + } + + // 测试2: 覆盖基础 prompt + result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, true, "base") + if strings.Contains(result, baseContent) { + t.Errorf("覆盖模式下仍包含基础模板内容") + } + if !strings.Contains(result, customPrompt) { + t.Errorf("覆盖模式下未包含自定义 prompt") + } + + // 测试3: 重新加载后效果 + updatedBase := "更新的基础策略:激进交易" + if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(updatedBase), 0644); err != nil { + t.Fatalf("更新文件失败: %v", err) + } + + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("重新加载失败: %v", err) + } + + result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base") + if !strings.Contains(result, updatedBase) { + t.Errorf("重新加载后未包含更新的基础模板内容") + } + if strings.Contains(result, baseContent) { + t.Errorf("重新加载后仍包含旧的基础模板内容") + } +} + +// TestPromptReloadFallback 测试模板不存在时的降级机制 +func TestPromptReloadFallback(t *testing.T) { + // 保存原始的 promptsDir + originalDir := promptsDir + defer func() { + promptsDir = originalDir + globalPromptManager.ReloadTemplates(originalDir) + }() + + // 创建临时目录 + tempDir := t.TempDir() + promptsDir = tempDir + + // 只创建 default 模板 + defaultContent := "默认策略" + if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte(defaultContent), 0644); err != nil { + t.Fatalf("创建文件失败: %v", err) + } + + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("加载失败: %v", err) + } + + // 测试1: 请求不存在的模板,应该降级到 default + result := buildSystemPrompt(10000.0, 10, 5, "nonexistent") + if !strings.Contains(result, defaultContent) { + t.Errorf("请求不存在的模板时,未降级到 default") + } + + // 测试2: 空模板名,应该使用 default + result = buildSystemPrompt(10000.0, 10, 5, "") + if !strings.Contains(result, defaultContent) { + t.Errorf("空模板名时,未使用 default") + } +} + +// TestConcurrentPromptReload 测试并发场景下的 prompt 重新加载 +func TestConcurrentPromptReload(t *testing.T) { + // 保存原始的 promptsDir + originalDir := promptsDir + defer func() { + promptsDir = originalDir + globalPromptManager.ReloadTemplates(originalDir) + }() + + // 创建临时目录 + tempDir := t.TempDir() + promptsDir = tempDir + + // 创建测试文件 + if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil { + t.Fatalf("创建文件失败: %v", err) + } + + if err := ReloadPromptTemplates(); err != nil { + t.Fatalf("初始加载失败: %v", err) + } + + // 并发测试:同时读取和重新加载 + done := make(chan bool) + + // 启动多个读取 goroutine + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 100; j++ { + _, _ = GetPromptTemplate("test") + } + done <- true + }() + } + + // 启动多个重新加载 goroutine + for i := 0; i < 3; i++ { + go func() { + for j := 0; j < 10; j++ { + _ = ReloadPromptTemplates() + } + done <- true + }() + } + + // 等待所有 goroutine 完成 + for i := 0; i < 13; i++ { + <-done + } + + // 验证最终状态正确 + template, err := GetPromptTemplate("test") + if err != nil { + t.Errorf("并发测试后获取模板失败: %v", err) + } + if template.Content != "测试内容" { + t.Errorf("并发测试后模板内容错误: %s", template.Content) + } + + t.Log("✅ 并发测试通过:多个 goroutine 同时读取和重新加载模板,无数据竞争") +}