feat(decision): auto-reload prompt templates when starting trader (#833)

* feat: 启动交易员时自动重新加载系统提示词模板
## 改动内容
- 在 handleStartTrader 中调用 decision.ReloadPromptTemplates()
- 每次启动交易员时从硬盘重新加载 prompts/ 目录下的所有 .txt 模板文件
- 添加完整的单元测试和端到端集成测试
## 测试覆盖
- 单元测试:模板加载、获取、重新加载功能
- 集成测试:文件修改 → 重新加载 → 决策引擎使用新内容的完整流程
- 并发测试:验证多 goroutine 场景下的线程安全性
- Race detector 测试通过
## 用户体验改进
- 修改 prompt 文件后无需重启服务
- 只需停止交易员再启动即可应用新的 prompt
- 控制台会输出重新加载成功的日志提示
* feat: 在重新加载日志中显示当前使用的模板名称
* feat: fallback 到 default 模板时明确显示原因
* fix: correct GetTraderConfig return type to get SystemPromptTemplate
* refactor: extract reloadPromptTemplatesWithLog as reusable method
This commit is contained in:
Lawrence Liu
2025-11-11 10:37:46 +08:00
committed by GitHub
parent 57e31b2ace
commit 6efe733127
3 changed files with 549 additions and 1 deletions
+21 -1
View File
@@ -792,12 +792,15 @@ func (s *Server) handleStartTrader(c *gin.Context) {
traderID := c.Param("id")
// 校验交易员是否属于当前用户
_, _, _, err := s.database.GetTraderConfig(userID, traderID)
traderRecord, _, _, err := s.database.GetTraderConfig(userID, traderID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"})
return
}
// 获取模板名称
templateName := traderRecord.SystemPromptTemplate
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在"})
@@ -811,6 +814,9 @@ func (s *Server) handleStartTrader(c *gin.Context) {
return
}
// 重新加载系统提示词模板(确保使用最新的硬盘文件)
s.reloadPromptTemplatesWithLog(templateName)
// 启动交易员
go func() {
log.Printf("▶️ 启动交易员 %s (%s)", traderID, trader.GetName())
@@ -2318,3 +2324,17 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
c.JSON(http.StatusOK, result)
}
// reloadPromptTemplatesWithLog 重新加载提示词模板并记录日志
func (s *Server) reloadPromptTemplatesWithLog(templateName string) {
if err := decision.ReloadPromptTemplates(); err != nil {
log.Printf("⚠️ 重新加载提示词模板失败: %v", err)
return
}
if templateName == "" {
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: default (未指定,使用默认)]")
} else {
log.Printf("✓ 已重新加载系统提示词模板 [当前使用: %s]", templateName)
}
}
+285
View File
@@ -0,0 +1,285 @@
package decision
import (
"os"
"path/filepath"
"testing"
)
func TestPromptManager_LoadTemplates(t *testing.T) {
// 创建临时目录用于测试
tempDir := t.TempDir()
tests := []struct {
name string
setupFiles map[string]string // 文件名 -> 内容
expectedCount int
expectedNames []string
shouldError bool
}{
{
name: "加载单个模板文件",
setupFiles: map[string]string{
"default.txt": "你是专业的加密货币交易AI。",
},
expectedCount: 1,
expectedNames: []string{"default"},
shouldError: false,
},
{
name: "加载多个模板文件",
setupFiles: map[string]string{
"default.txt": "默认策略",
"conservative.txt": "保守策略",
"aggressive.txt": "激进策略",
},
expectedCount: 3,
expectedNames: []string{"default", "conservative", "aggressive"},
shouldError: false,
},
{
name: "空目录",
setupFiles: map[string]string{},
expectedCount: 0,
expectedNames: []string{},
shouldError: false,
},
{
name: "忽略非.txt文件",
setupFiles: map[string]string{
"default.txt": "正确的模板",
"readme.md": "应该被忽略",
"config.json": "应该被忽略",
},
expectedCount: 1,
expectedNames: []string{"default"},
shouldError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 为每个测试用例创建独立的子目录
testDir := filepath.Join(tempDir, tt.name)
if err := os.MkdirAll(testDir, 0755); err != nil {
t.Fatalf("创建测试目录失败: %v", err)
}
// 设置测试文件
for filename, content := range tt.setupFiles {
filePath := filepath.Join(testDir, filename)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
t.Fatalf("创建测试文件失败 %s: %v", filename, err)
}
}
// 创建新的 PromptManager
pm := NewPromptManager()
// 执行测试
err := pm.LoadTemplates(testDir)
// 检查错误
if (err != nil) != tt.shouldError {
t.Errorf("LoadTemplates() error = %v, shouldError %v", err, tt.shouldError)
return
}
// 检查加载的模板数量
if len(pm.templates) != tt.expectedCount {
t.Errorf("加载的模板数量 = %d, 期望 %d", len(pm.templates), tt.expectedCount)
}
// 检查模板名称
for _, expectedName := range tt.expectedNames {
if _, exists := pm.templates[expectedName]; !exists {
t.Errorf("缺少预期的模板: %s", expectedName)
}
}
// 验证模板内容
for filename, expectedContent := range tt.setupFiles {
if filepath.Ext(filename) != ".txt" {
continue
}
templateName := filename[:len(filename)-4] // 去掉 .txt
template, err := pm.GetTemplate(templateName)
if err != nil {
t.Errorf("获取模板 %s 失败: %v", templateName, err)
continue
}
if template.Content != expectedContent {
t.Errorf("模板内容不匹配\n期望: %s\n实际: %s", expectedContent, template.Content)
}
}
})
}
}
func TestPromptManager_GetTemplate(t *testing.T) {
pm := NewPromptManager()
pm.templates = map[string]*PromptTemplate{
"default": {
Name: "default",
Content: "默认策略内容",
},
"aggressive": {
Name: "aggressive",
Content: "激进策略内容",
},
}
tests := []struct {
name string
templateName string
expectError bool
expectedContent string
}{
{
name: "获取存在的模板",
templateName: "default",
expectError: false,
expectedContent: "默认策略内容",
},
{
name: "获取不存在的模板",
templateName: "nonexistent",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
template, err := pm.GetTemplate(tt.templateName)
if (err != nil) != tt.expectError {
t.Errorf("GetTemplate() error = %v, expectError %v", err, tt.expectError)
return
}
if !tt.expectError && template.Content != tt.expectedContent {
t.Errorf("模板内容 = %s, 期望 %s", template.Content, tt.expectedContent)
}
})
}
}
func TestPromptManager_ReloadTemplates(t *testing.T) {
tempDir := t.TempDir()
// 初始文件
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("初始内容"), 0644); err != nil {
t.Fatalf("创建初始文件失败: %v", err)
}
pm := NewPromptManager()
if err := pm.LoadTemplates(tempDir); err != nil {
t.Fatalf("初始加载失败: %v", err)
}
// 验证初始内容
template, _ := pm.GetTemplate("default")
if template.Content != "初始内容" {
t.Errorf("初始内容不正确: %s", template.Content)
}
// 修改文件内容
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("更新后内容"), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
}
// 添加新文件
if err := os.WriteFile(filepath.Join(tempDir, "new.txt"), []byte("新模板内容"), 0644); err != nil {
t.Fatalf("创建新文件失败: %v", err)
}
// 重新加载
if err := pm.ReloadTemplates(tempDir); err != nil {
t.Fatalf("重新加载失败: %v", err)
}
// 验证更新后的内容
template, err := pm.GetTemplate("default")
if err != nil {
t.Fatalf("获取 default 模板失败: %v", err)
}
if template.Content != "更新后内容" {
t.Errorf("重新加载后内容不正确: got %s, want '更新后内容'", template.Content)
}
// 验证新模板
newTemplate, err := pm.GetTemplate("new")
if err != nil {
t.Fatalf("获取 new 模板失败: %v", err)
}
if newTemplate.Content != "新模板内容" {
t.Errorf("新模板内容不正确: %s", newTemplate.Content)
}
// 验证模板数量
if len(pm.templates) != 2 {
t.Errorf("重新加载后模板数量 = %d, 期望 2", len(pm.templates))
}
}
func TestPromptManager_GetAllTemplateNames(t *testing.T) {
pm := NewPromptManager()
pm.templates = map[string]*PromptTemplate{
"default": {Name: "default", Content: "默认策略"},
"conservative": {Name: "conservative", Content: "保守策略"},
"aggressive": {Name: "aggressive", Content: "激进策略"},
}
names := pm.GetAllTemplateNames()
if len(names) != 3 {
t.Errorf("GetAllTemplateNames() 返回数量 = %d, 期望 3", len(names))
}
// 验证所有名称都存在
nameMap := make(map[string]bool)
for _, name := range names {
nameMap[name] = true
}
expectedNames := []string{"default", "conservative", "aggressive"}
for _, expectedName := range expectedNames {
if !nameMap[expectedName] {
t.Errorf("缺少预期的模板名称: %s", expectedName)
}
}
}
func TestReloadPromptTemplates_GlobalFunction(t *testing.T) {
// 保存原始的 promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
// 恢复原始模板
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
tempDir := t.TempDir()
promptsDir = tempDir
// 创建测试文件
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil {
t.Fatalf("创建测试文件失败: %v", err)
}
// 调用全局重新加载函数
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("ReloadPromptTemplates() 失败: %v", err)
}
// 验证全局管理器已更新
template, err := GetPromptTemplate("test")
if err != nil {
t.Fatalf("获取模板失败: %v", err)
}
if template.Content != "测试内容" {
t.Errorf("模板内容不正确: got %s, want '测试内容'", template.Content)
}
}
+243
View File
@@ -0,0 +1,243 @@
package decision
import (
"os"
"path/filepath"
"strings"
"testing"
)
// TestPromptReloadEndToEnd 端到端测试:验证从文件修改到决策引擎使用的完整流程
func TestPromptReloadEndToEnd(t *testing.T) {
// 保存原始的 promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
// 恢复原始模板
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录模拟 prompts/ 目录
tempDir := t.TempDir()
promptsDir = tempDir
// 步骤1: 创建初始 prompt 文件
initialContent := "# 初始交易策略\n你是一个保守的交易AI。"
if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(initialContent), 0644); err != nil {
t.Fatalf("创建初始文件失败: %v", err)
}
// 步骤2: 首次加载(模拟系统启动)
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("首次加载失败: %v", err)
}
// 步骤3: 验证初始内容
template, err := GetPromptTemplate("test_strategy")
if err != nil {
t.Fatalf("获取初始模板失败: %v", err)
}
if template.Content != initialContent {
t.Errorf("初始内容不匹配\n期望: %s\n实际: %s", initialContent, template.Content)
}
// 步骤4: 使用 buildSystemPrompt 验证模板被正确使用
systemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy")
if !strings.Contains(systemPrompt, initialContent) {
t.Errorf("buildSystemPrompt 未包含模板内容\n生成的 prompt:\n%s", systemPrompt)
}
// 步骤5: 模拟用户修改文件(这是用户在硬盘上修改 prompt)
updatedContent := "# 更新的交易策略\n你是一个激进的交易AI,追求高风险高收益。"
if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(updatedContent), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
}
// 步骤6: 模拟交易员启动时调用 ReloadPromptTemplates()
t.Log("模拟交易员启动,调用 ReloadPromptTemplates()...")
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("重新加载失败: %v", err)
}
// 步骤7: 验证新内容已生效
reloadedTemplate, err := GetPromptTemplate("test_strategy")
if err != nil {
t.Fatalf("获取重新加载的模板失败: %v", err)
}
if reloadedTemplate.Content != updatedContent {
t.Errorf("重新加载后内容不匹配\n期望: %s\n实际: %s", updatedContent, reloadedTemplate.Content)
}
// 步骤8: 验证 buildSystemPrompt 使用了新内容
newSystemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy")
if !strings.Contains(newSystemPrompt, updatedContent) {
t.Errorf("buildSystemPrompt 未包含更新后的模板内容\n生成的 prompt:\n%s", newSystemPrompt)
}
// 步骤9: 验证旧内容不再存在
if strings.Contains(newSystemPrompt, "保守的交易AI") {
t.Errorf("buildSystemPrompt 仍包含旧的模板内容")
}
t.Log("✅ 端到端测试通过:文件修改 -> 重新加载 -> 决策引擎使用新内容")
}
// TestPromptReloadWithCustomPrompt 测试自定义 prompt 与模板重新加载的交互
func TestPromptReloadWithCustomPrompt(t *testing.T) {
// 保存原始的 promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
tempDir := t.TempDir()
promptsDir = tempDir
// 创建基础模板
baseContent := "基础策略:稳健交易"
if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(baseContent), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
}
// 加载模板
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("加载失败: %v", err)
}
// 测试1: 基础模板 + 自定义 prompt(不覆盖)
customPrompt := "个性化规则:只交易 BTC"
result := buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base")
if !strings.Contains(result, baseContent) {
t.Errorf("未包含基础模板内容")
}
if !strings.Contains(result, customPrompt) {
t.Errorf("未包含自定义 prompt")
}
// 测试2: 覆盖基础 prompt
result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, true, "base")
if strings.Contains(result, baseContent) {
t.Errorf("覆盖模式下仍包含基础模板内容")
}
if !strings.Contains(result, customPrompt) {
t.Errorf("覆盖模式下未包含自定义 prompt")
}
// 测试3: 重新加载后效果
updatedBase := "更新的基础策略:激进交易"
if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(updatedBase), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("重新加载失败: %v", err)
}
result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base")
if !strings.Contains(result, updatedBase) {
t.Errorf("重新加载后未包含更新的基础模板内容")
}
if strings.Contains(result, baseContent) {
t.Errorf("重新加载后仍包含旧的基础模板内容")
}
}
// TestPromptReloadFallback 测试模板不存在时的降级机制
func TestPromptReloadFallback(t *testing.T) {
// 保存原始的 promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
tempDir := t.TempDir()
promptsDir = tempDir
// 只创建 default 模板
defaultContent := "默认策略"
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte(defaultContent), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("加载失败: %v", err)
}
// 测试1: 请求不存在的模板,应该降级到 default
result := buildSystemPrompt(10000.0, 10, 5, "nonexistent")
if !strings.Contains(result, defaultContent) {
t.Errorf("请求不存在的模板时,未降级到 default")
}
// 测试2: 空模板名,应该使用 default
result = buildSystemPrompt(10000.0, 10, 5, "")
if !strings.Contains(result, defaultContent) {
t.Errorf("空模板名时,未使用 default")
}
}
// TestConcurrentPromptReload 测试并发场景下的 prompt 重新加载
func TestConcurrentPromptReload(t *testing.T) {
// 保存原始的 promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
tempDir := t.TempDir()
promptsDir = tempDir
// 创建测试文件
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("初始加载失败: %v", err)
}
// 并发测试:同时读取和重新加载
done := make(chan bool)
// 启动多个读取 goroutine
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 100; j++ {
_, _ = GetPromptTemplate("test")
}
done <- true
}()
}
// 启动多个重新加载 goroutine
for i := 0; i < 3; i++ {
go func() {
for j := 0; j < 10; j++ {
_ = ReloadPromptTemplates()
}
done <- true
}()
}
// 等待所有 goroutine 完成
for i := 0; i < 13; i++ {
<-done
}
// 验证最终状态正确
template, err := GetPromptTemplate("test")
if err != nil {
t.Errorf("并发测试后获取模板失败: %v", err)
}
if template.Content != "测试内容" {
t.Errorf("并发测试后模板内容错误: %s", template.Content)
}
t.Log("✅ 并发测试通过:多个 goroutine 同时读取和重新加载模板,无数据竞争")
}