From 295124c1fad9c8d08b348dcce1d1e144b567f503 Mon Sep 17 00:00:00 2001 From: WquGuru Date: Sun, 9 Nov 2025 17:43:28 +0800 Subject: [PATCH] test(trader): add comprehensive unit tests and CI coverage reporting (#823) * chore(config): add Python and uv support to project - Add comprehensive Python .gitignore rules (pycache, venv, pytest, etc.) - Add uv package manager specific ignores (.uv/, uv.lock) - Initialize pyproject.toml for Python tooling Co-authored-by: tinkle-community * chore(deps): add testing dependencies - Add github.com/stretchr/testify v1.11.1 for test assertions - Add github.com/agiledragon/gomonkey/v2 v2.13.0 for mocking - Promote github.com/rs/zerolog to direct dependency Co-authored-by: tinkle-community * ci(workflow): add PR test coverage reporting Add GitHub Actions workflow to run unit tests and report coverage on PRs: - Run Go tests with race detection and coverage profiling - Calculate coverage statistics and generate detailed reports - Post coverage results as PR comments with visual indicators - Fix Go version to 1.23 (was incorrectly set to 1.25.0) Coverage guidelines: - Green (>=80%): excellent - Yellow (>=60%): good - Orange (>=40%): fair - Red (<40%): needs improvement This workflow is advisory only and does not block PR merging. Co-authored-by: tinkle-community * test(trader): add comprehensive unit tests for trader modules Add unit test suites for multiple trader implementations: - aster_trader_test.go: AsterTrader functionality tests - auto_trader_test.go: AutoTrader lifecycle and operations tests - binance_futures_test.go: Binance futures trader tests - hyperliquid_trader_test.go: Hyperliquid trader tests - trader_test_suite.go: Common test suite utilities and helpers Also fix minor formatting issue in auto_trader.go (trailing whitespace) Co-authored-by: tinkle-community * test(trader): preserve existing calculatePnLPercentage unit tests Merge existing calculatePnLPercentage tests with incoming comprehensive test suite: - Preserve TestCalculatePnLPercentage with 9 test cases covering edge cases - Preserve TestCalculatePnLPercentage_RealWorldScenarios with 3 trading scenarios - Add math package import for floating-point precision comparison - All tests validate PnL percentage calculation with different leverage scenarios Co-authored-by: tinkle-community --------- Co-authored-by: tinkle-community --- .github/workflows/pr-go-test-coverage.yml | 78 ++ .../workflows/scripts/calculate_coverage.py | 192 +++ .github/workflows/scripts/comment_pr.py | 246 ++++ .github/workflows/scripts/requirements.txt | 2 + .gitignore | 58 + go.mod | 7 +- go.sum | 13 + pyproject.toml | 7 + trader/aster_trader_test.go | 299 +++++ trader/auto_trader.go | 2 +- trader/auto_trader_test.go | 1174 ++++++++++++++++- trader/binance_futures_test.go | 420 ++++++ trader/hyperliquid_trader_test.go | 646 +++++++++ trader/trader_test_suite.go | 664 ++++++++++ 14 files changed, 3766 insertions(+), 42 deletions(-) create mode 100644 .github/workflows/pr-go-test-coverage.yml create mode 100755 .github/workflows/scripts/calculate_coverage.py create mode 100755 .github/workflows/scripts/comment_pr.py create mode 100644 .github/workflows/scripts/requirements.txt create mode 100644 pyproject.toml create mode 100644 trader/aster_trader_test.go create mode 100644 trader/binance_futures_test.go create mode 100644 trader/hyperliquid_trader_test.go create mode 100644 trader/trader_test_suite.go diff --git a/.github/workflows/pr-go-test-coverage.yml b/.github/workflows/pr-go-test-coverage.yml new file mode 100644 index 00000000..fb5134da --- /dev/null +++ b/.github/workflows/pr-go-test-coverage.yml @@ -0,0 +1,78 @@ +name: Go Test Coverage + +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - dev + - main + push: + branches: + - dev + - main + +permissions: + contents: read + pull-requests: write + +jobs: + test-coverage: + name: Go Unit Tests & Coverage + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install -r .github/workflows/scripts/requirements.txt + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Run tests with coverage + run: | + go test -v -race -coverprofile=coverage.out -covermode=atomic ./... + + - name: Calculate coverage and generate report + id: coverage + run: | + chmod +x .github/workflows/scripts/calculate_coverage.py + python .github/workflows/scripts/calculate_coverage.py coverage.out coverage_report.md + + - name: Comment PR with coverage + if: github.event_name == 'pull_request' + continue-on-error: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + chmod +x .github/workflows/scripts/comment_pr.py + python .github/workflows/scripts/comment_pr.py \ + ${{ github.event.pull_request.number }} \ + "${{ steps.coverage.outputs.coverage }}" \ + "${{ steps.coverage.outputs.emoji }}" \ + "${{ steps.coverage.outputs.status }}" \ + "${{ steps.coverage.outputs.badge_color }}" \ + coverage_report.md diff --git a/.github/workflows/scripts/calculate_coverage.py b/.github/workflows/scripts/calculate_coverage.py new file mode 100755 index 00000000..735da873 --- /dev/null +++ b/.github/workflows/scripts/calculate_coverage.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +""" +Calculate Go test coverage and generate reports. + +This script parses the coverage.out file generated by `go test -coverprofile`, +extracts coverage statistics, and generates formatted reports. +""" + +import sys +import re +import os +from typing import Dict, List, Tuple + + +def parse_coverage_file(coverage_file: str) -> Tuple[float, Dict[str, float]]: + """ + Parse coverage output file and extract coverage data. + + Args: + coverage_file: Path to coverage.out file + + Returns: + Tuple of (total_coverage, package_coverage_dict) + """ + if not os.path.exists(coverage_file): + print(f"Error: Coverage file {coverage_file} not found", file=sys.stderr) + sys.exit(1) + + # Run go tool cover to get coverage data + import subprocess + + try: + result = subprocess.run( + ['go', 'tool', 'cover', '-func', coverage_file], + capture_output=True, + text=True, + check=True + ) + except subprocess.CalledProcessError as e: + print(f"Error running go tool cover: {e}", file=sys.stderr) + sys.exit(1) + + lines = result.stdout.strip().split('\n') + package_coverage = {} + total_coverage = 0.0 + + for line in lines: + # Skip empty lines + if not line.strip(): + continue + + # Check for total coverage line + if line.startswith('total:'): + # Extract percentage from "total: (statements) XX.X%" + match = re.search(r'(\d+\.\d+)%', line) + if match: + total_coverage = float(match.group(1)) + continue + + # Parse package/file coverage + # Format: "package/file.go:function statements coverage%" + parts = line.split() + if len(parts) >= 3: + file_path = parts[0] + coverage_str = parts[-1] + + # Extract package name from file path + package = file_path.split(':')[0] + package_name = '/'.join(package.split('/')[:-1]) if '/' in package else package + + # Extract coverage percentage + match = re.search(r'(\d+\.\d+)%', coverage_str) + if match: + coverage_pct = float(match.group(1)) + + # Aggregate by package + if package_name not in package_coverage: + package_coverage[package_name] = [] + package_coverage[package_name].append(coverage_pct) + + # Calculate average coverage per package + package_avg = { + pkg: sum(coverages) / len(coverages) + for pkg, coverages in package_coverage.items() + } + + return total_coverage, package_avg + + +def get_coverage_status(coverage: float) -> Tuple[str, str, str]: + """ + Get coverage status based on percentage. + + Args: + coverage: Coverage percentage + + Returns: + Tuple of (emoji, status_text, badge_color) + """ + if coverage >= 80: + return '🟢', 'excellent', 'brightgreen' + elif coverage >= 60: + return '🟡', 'good', 'yellow' + elif coverage >= 40: + return '🟠', 'fair', 'orange' + else: + return '🔴', 'needs improvement', 'red' + + +def generate_coverage_report(coverage_file: str, output_file: str) -> None: + """ + Generate a detailed coverage report in markdown format. + + Args: + coverage_file: Path to coverage.out file + output_file: Path to output markdown file + """ + import subprocess + + try: + result = subprocess.run( + ['go', 'tool', 'cover', '-func', coverage_file], + capture_output=True, + text=True, + check=True + ) + except subprocess.CalledProcessError as e: + print(f"Error generating coverage report: {e}", file=sys.stderr) + sys.exit(1) + + with open(output_file, 'w') as f: + f.write("## Coverage by Package\n\n") + f.write("```\n") + f.write(result.stdout) + f.write("```\n") + + +def set_github_output(name: str, value: str) -> None: + """ + Set GitHub Actions output variable. + + Args: + name: Output variable name + value: Output variable value + """ + github_output = os.environ.get('GITHUB_OUTPUT') + if github_output: + with open(github_output, 'a') as f: + f.write(f"{name}={value}\n") + else: + print(f"::set-output name={name}::{value}") + + +def main(): + """Main entry point.""" + if len(sys.argv) < 2: + print("Usage: calculate_coverage.py [output_file]", file=sys.stderr) + sys.exit(1) + + coverage_file = sys.argv[1] + output_file = sys.argv[2] if len(sys.argv) > 2 else 'coverage_report.md' + + # Parse coverage data + total_coverage, package_coverage = parse_coverage_file(coverage_file) + + # Get coverage status + emoji, status, badge_color = get_coverage_status(total_coverage) + + # Generate detailed report + generate_coverage_report(coverage_file, output_file) + + # Output results + print(f"Total Coverage: {total_coverage}%") + print(f"Status: {status}") + print(f"Badge Color: {badge_color}") + + # Set GitHub Actions outputs + set_github_output('coverage', f'{total_coverage}%') + set_github_output('coverage_num', str(total_coverage)) + set_github_output('status', status) + set_github_output('emoji', emoji) + set_github_output('badge_color', badge_color) + + # Print package breakdown + if package_coverage: + print("\nCoverage by Package:") + for package, coverage in sorted(package_coverage.items()): + print(f" {package}: {coverage:.1f}%") + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/comment_pr.py b/.github/workflows/scripts/comment_pr.py new file mode 100755 index 00000000..a40d8a16 --- /dev/null +++ b/.github/workflows/scripts/comment_pr.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Post or update coverage report comment on GitHub Pull Request. + +This script generates a formatted coverage report comment and posts it to a PR, +or updates an existing coverage comment if one already exists. +""" + +import os +import sys +import json +import requests +from typing import Optional + + +def read_file(file_path: str) -> str: + """Read file content.""" + try: + with open(file_path, 'r') as f: + return f.read() + except FileNotFoundError: + print(f"Warning: File {file_path} not found", file=sys.stderr) + return "" + + +def generate_comment_body(coverage: str, emoji: str, status: str, + badge_color: str, coverage_report_path: str) -> str: + """ + Generate the PR comment body. + + Args: + coverage: Coverage percentage (e.g., "75.5%") + emoji: Status emoji + status: Status text + badge_color: Badge color + coverage_report_path: Path to detailed coverage report + + Returns: + Formatted comment body in markdown + """ + coverage_report = read_file(coverage_report_path) + + # URL encode the coverage percentage for the badge + coverage_encoded = coverage.replace('%', '%25') + + comment = f"""## {emoji} Go Test Coverage Report + +**Total Coverage:** `{coverage}` ({status}) + +![Coverage](https://img.shields.io/badge/coverage-{coverage_encoded}-{badge_color}) + +
+📊 Detailed Coverage Report (click to expand) + +{coverage_report} + +
+ +### Coverage Guidelines +- 🟢 >= 80%: Excellent +- 🟡 >= 60%: Good +- 🟠 >= 40%: Fair +- 🔴 < 40%: Needs improvement + +--- +*This is an automated coverage report. The coverage requirement is advisory and does not block PR merging.* +""" + return comment + + +def find_existing_comment(token: str, repo: str, pr_number: int) -> Optional[int]: + """ + Find existing coverage comment in the PR. + + Args: + token: GitHub token + repo: Repository in format "owner/repo" + pr_number: Pull request number + + Returns: + Comment ID if found, None otherwise + """ + url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments" + headers = { + 'Authorization': f'token {token}', + 'Accept': 'application/vnd.github.v3+json' + } + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + comments = response.json() + + # Look for existing coverage comment + for comment in comments: + if (comment.get('user', {}).get('type') == 'Bot' and + 'Go Test Coverage Report' in comment.get('body', '')): + return comment['id'] + + except requests.exceptions.RequestException as e: + print(f"Error fetching comments: {e}", file=sys.stderr) + + return None + + +def post_comment(token: str, repo: str, pr_number: int, body: str) -> bool: + """ + Post a new comment to the PR. + + Args: + token: GitHub token + repo: Repository in format "owner/repo" + pr_number: Pull request number + body: Comment body + + Returns: + True if successful, False otherwise + """ + url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments" + headers = { + 'Authorization': f'token {token}', + 'Accept': 'application/vnd.github.v3+json' + } + data = {'body': body} + + try: + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + print("✅ Coverage comment posted successfully") + return True + except requests.exceptions.RequestException as e: + print(f"Error posting comment: {e}", file=sys.stderr) + if hasattr(e, 'response') and e.response is not None: + print(f"Response: {e.response.text}", file=sys.stderr) + return False + + +def update_comment(token: str, repo: str, comment_id: int, body: str) -> bool: + """ + Update an existing comment. + + Args: + token: GitHub token + repo: Repository in format "owner/repo" + comment_id: Comment ID to update + body: New comment body + + Returns: + True if successful, False otherwise + """ + url = f"https://api.github.com/repos/{repo}/issues/comments/{comment_id}" + headers = { + 'Authorization': f'token {token}', + 'Accept': 'application/vnd.github.v3+json' + } + data = {'body': body} + + try: + response = requests.patch(url, headers=headers, json=data) + response.raise_for_status() + print("✅ Coverage comment updated successfully") + return True + except requests.exceptions.RequestException as e: + print(f"Error updating comment: {e}", file=sys.stderr) + if hasattr(e, 'response') and e.response is not None: + print(f"Response: {e.response.text}", file=sys.stderr) + return False + + +def is_fork_pr(event_path: str) -> bool: + """ + Check if the PR is from a fork. + + Args: + event_path: Path to GitHub event JSON file + + Returns: + True if fork PR, False otherwise + """ + try: + with open(event_path, 'r') as f: + event = json.load(f) + + pr = event.get('pull_request', {}) + head_repo = pr.get('head', {}).get('repo', {}).get('full_name') + base_repo = pr.get('base', {}).get('repo', {}).get('full_name') + + return head_repo != base_repo + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + print(f"Warning: Could not determine if fork PR: {e}", file=sys.stderr) + return False + + +def main(): + """Main entry point.""" + # Get environment variables + token = os.environ.get('GITHUB_TOKEN') + repo = os.environ.get('GITHUB_REPOSITORY') + event_path = os.environ.get('GITHUB_EVENT_PATH', '') + + # Get arguments + if len(sys.argv) < 6: + print("Usage: comment_pr.py [coverage_report_path]", + file=sys.stderr) + sys.exit(1) + + pr_number = int(sys.argv[1]) + coverage = sys.argv[2] + emoji = sys.argv[3] + status = sys.argv[4] + badge_color = sys.argv[5] + coverage_report_path = sys.argv[6] if len(sys.argv) > 6 else 'coverage_report.md' + + # Validate environment + if not token: + print("Error: GITHUB_TOKEN environment variable not set", file=sys.stderr) + sys.exit(1) + + if not repo: + print("Error: GITHUB_REPOSITORY environment variable not set", file=sys.stderr) + sys.exit(1) + + # Check if fork PR + if event_path and is_fork_pr(event_path): + print("ℹ️ Fork PR detected - skipping comment (no write permissions)") + sys.exit(0) + + # Generate comment body + comment_body = generate_comment_body(coverage, emoji, status, badge_color, coverage_report_path) + + # Check for existing comment + existing_comment_id = find_existing_comment(token, repo, pr_number) + + # Post or update comment + if existing_comment_id: + print(f"Found existing comment (ID: {existing_comment_id}), updating...") + success = update_comment(token, repo, existing_comment_id, comment_body) + else: + print("No existing comment found, creating new one...") + success = post_comment(token, repo, pr_number, comment_body) + + sys.exit(0 if success else 1) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/requirements.txt b/.github/workflows/scripts/requirements.txt new file mode 100644 index 00000000..c606cb50 --- /dev/null +++ b/.github/workflows/scripts/requirements.txt @@ -0,0 +1,2 @@ +# Python dependencies for GitHub Actions scripts +requests>=2.31.0 diff --git a/.gitignore b/.gitignore index 80e2a6d7..ac5d3a0f 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,61 @@ rsa_key* # 加密相关 DATA_ENCRYPTION_KEY=* *.enc + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Python 虚拟环境 +.venv/ +venv/ +ENV/ +env/ +.env/ + +# uv +.uv/ +uv.lock + +# Pytest +.pytest_cache/ +.coverage +htmlcov/ +*.cover +.hypothesis/ + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb + +# pyenv +.python-version + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/go.mod b/go.mod index d48551d1..bee2e067 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.0 require ( github.com/adshao/go-binance/v2 v2.8.7 + github.com/agiledragon/gomonkey/v2 v2.13.0 github.com/ethereum/go-ethereum v1.16.5 github.com/gin-gonic/gin v1.11.0 github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 @@ -12,8 +13,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.5.1 github.com/pquerna/otp v1.4.0 + github.com/rs/zerolog v1.34.0 github.com/sirupsen/logrus v1.9.3 github.com/sonirico/go-hyperliquid v0.17.0 + github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.42.0 modernc.org/sqlite v1.40.0 ) @@ -29,6 +32,7 @@ require ( github.com/consensys/gnark-crypto v0.19.0 // indirect github.com/crate-crypto/go-eth-kzg v1.4.0 // indirect github.com/crate-crypto/go-ipa v0.0.0-20240724233137-53bbb0ceb27a // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/elastic/go-sysinfo v1.15.4 // indirect @@ -56,11 +60,11 @@ require ( github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/procfs v0.17.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.54.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rs/zerolog v1.34.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/sonirico/vago v0.9.0 // indirect github.com/sonirico/vago/lol v0.0.0-20250901170347-2d1d82c510bd // indirect @@ -83,6 +87,7 @@ require ( golang.org/x/text v0.29.0 // indirect golang.org/x/tools v0.36.0 // indirect google.golang.org/protobuf v1.36.9 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect howett.net/plist v1.0.1 // indirect modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index b6a17130..d394df3e 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDO github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/adshao/go-binance/v2 v2.8.7 h1:n7jkhwIHMdtd/9ZU2gTqFV15XVSbUCjyFlOUAtTd8uU= github.com/adshao/go-binance/v2 v2.8.7/go.mod h1:XkkuecSyJKPolaCGf/q4ovJYB3t0P+7RUYTbGr+LMGM= +github.com/agiledragon/gomonkey/v2 v2.13.0 h1:B24Jg6wBI1iB8EFR1c+/aoTg7QN/Cum7YffG8KMIyYo= +github.com/agiledragon/gomonkey/v2 v2.13.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y= @@ -88,6 +90,7 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA= @@ -101,6 +104,7 @@ github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2E github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -165,6 +169,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/sonirico/go-hyperliquid v0.17.0 h1:eXYACWupwu41O1VtKw17dqe9oOLQ1A2nRElGhg5Ox+4= github.com/sonirico/go-hyperliquid v0.17.0/go.mod h1:sH51Vsu+tPUwc95TL2MoQ8YXSewLWBEJirgzo7sZx6w= github.com/sonirico/vago v0.9.0 h1:DF2OWW2Aaf1xPZmnFv79kBrHmjKX3mVvMbP08vERlKo= @@ -209,29 +215,36 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/dnaeon/go-vcr.v4 v4.0.5 h1:I0hpTIvD5rII+8LgYGrHMA2d4SQPoL6u7ZvJakWKsiA= gopkg.in/dnaeon/go-vcr.v4 v4.0.5/go.mod h1:dRos81TkW9C1WJt6tTaE+uV2Lo8qJT3AG2b35+CB/nQ= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..bbaeecdc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "nofx" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [] diff --git a/trader/aster_trader_test.go b/trader/aster_trader_test.go new file mode 100644 index 00000000..19a0b4a2 --- /dev/null +++ b/trader/aster_trader_test.go @@ -0,0 +1,299 @@ +package trader + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/assert" +) + +// ============================================================ +// 一、AsterTraderTestSuite - 继承 base test suite +// ============================================================ + +// AsterTraderTestSuite Aster交易器测试套件 +// 继承 TraderTestSuite 并添加 Aster 特定的 mock 逻辑 +type AsterTraderTestSuite struct { + *TraderTestSuite // 嵌入基础测试套件 + mockServer *httptest.Server +} + +// NewAsterTraderTestSuite 创建 Aster 测试套件 +func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite { + // 创建 mock HTTP 服务器 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 根据不同的 URL 路径返回不同的 mock 响应 + path := r.URL.Path + + var respBody interface{} + + switch { + // Mock GetBalance - /fapi/v3/balance (返回数组) + case path == "/fapi/v3/balance": + respBody = []map[string]interface{}{ + { + "asset": "USDT", + "walletBalance": "10000.00", + "unrealizedProfit": "100.50", + "marginBalance": "10100.50", + "maintMargin": "200.00", + "initialMargin": "2000.00", + "maxWithdrawAmount": "8000.00", + "crossWalletBalance": "10000.00", + "crossUnPnl": "100.50", + "availableBalance": "8000.00", + }, + } + + // Mock GetPositions - /fapi/v3/positionRisk + case path == "/fapi/v3/positionRisk": + respBody = []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "positionAmt": "0.5", + "entryPrice": "50000.00", + "markPrice": "50500.00", + "unRealizedProfit": "250.00", + "liquidationPrice": "45000.00", + "leverage": "10", + "positionSide": "LONG", + }, + } + + // Mock GetMarketPrice - /fapi/v3/ticker/price (返回单个对象) + case path == "/fapi/v3/ticker/price": + // 从查询参数获取symbol + symbol := r.URL.Query().Get("symbol") + if symbol == "" { + symbol = "BTCUSDT" + } + // 根据symbol返回不同价格 + price := "50000.00" + if symbol == "ETHUSDT" { + price = "3000.00" + } else if symbol == "INVALIDUSDT" { + // 返回错误响应 + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": -1121, + "msg": "Invalid symbol", + }) + return + } + respBody = map[string]interface{}{ + "symbol": symbol, + "price": price, + } + + // Mock ExchangeInfo - /fapi/v3/exchangeInfo + case path == "/fapi/v3/exchangeInfo": + respBody = map[string]interface{}{ + "symbols": []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "pricePrecision": 1, + "quantityPrecision": 3, + "baseAssetPrecision": 8, + "quotePrecision": 8, + "filters": []map[string]interface{}{ + { + "filterType": "PRICE_FILTER", + "tickSize": "0.1", + }, + { + "filterType": "LOT_SIZE", + "stepSize": "0.001", + }, + }, + }, + { + "symbol": "ETHUSDT", + "pricePrecision": 2, + "quantityPrecision": 3, + "baseAssetPrecision": 8, + "quotePrecision": 8, + "filters": []map[string]interface{}{ + { + "filterType": "PRICE_FILTER", + "tickSize": "0.01", + }, + { + "filterType": "LOT_SIZE", + "stepSize": "0.001", + }, + }, + }, + }, + } + + // Mock CreateOrder - /fapi/v1/order and /fapi/v3/order + case (path == "/fapi/v1/order" || path == "/fapi/v3/order") && r.Method == "POST": + // 从请求中解析参数以确定symbol + bodyBytes, _ := io.ReadAll(r.Body) + var orderParams map[string]interface{} + json.Unmarshal(bodyBytes, &orderParams) + + symbol := "BTCUSDT" + if s, ok := orderParams["symbol"].(string); ok { + symbol = s + } + + respBody = map[string]interface{}{ + "orderId": 123456, + "symbol": symbol, + "status": "FILLED", + "side": orderParams["side"], + "type": orderParams["type"], + } + + // Mock CancelOrder - /fapi/v1/order (DELETE) + case path == "/fapi/v1/order" && r.Method == "DELETE": + respBody = map[string]interface{}{ + "orderId": 123456, + "symbol": "BTCUSDT", + "status": "CANCELED", + } + + // Mock ListOpenOrders - /fapi/v1/openOrders and /fapi/v3/openOrders + case path == "/fapi/v1/openOrders" || path == "/fapi/v3/openOrders": + respBody = []map[string]interface{}{} + + // Mock SetLeverage - /fapi/v1/leverage + case path == "/fapi/v1/leverage": + respBody = map[string]interface{}{ + "leverage": 10, + "symbol": "BTCUSDT", + } + + // Mock SetMarginMode - /fapi/v1/marginType + case path == "/fapi/v1/marginType": + respBody = map[string]interface{}{ + "code": 200, + "msg": "success", + } + + // Default: empty response + default: + respBody = map[string]interface{}{} + } + + // 序列化响应 + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(respBody) + })) + + // 生成一个测试用的私钥 + privateKey, _ := crypto.GenerateKey() + + // 创建 mock trader,使用 mock server 的 URL + trader := &AsterTrader{ + ctx: context.Background(), + user: "0x1234567890123456789012345678901234567890", + signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + privateKey: privateKey, + client: mockServer.Client(), + baseURL: mockServer.URL, // 使用 mock server 的 URL + symbolPrecision: make(map[string]SymbolPrecision), + } + + // 创建基础套件 + baseSuite := NewTraderTestSuite(t, trader) + + return &AsterTraderTestSuite{ + TraderTestSuite: baseSuite, + mockServer: mockServer, + } +} + +// Cleanup 清理资源 +func (s *AsterTraderTestSuite) Cleanup() { + if s.mockServer != nil { + s.mockServer.Close() + } + s.TraderTestSuite.Cleanup() +} + +// ============================================================ +// 二、使用 AsterTraderTestSuite 运行通用测试 +// ============================================================ + +// TestAsterTrader_InterfaceCompliance 测试接口兼容性 +func TestAsterTrader_InterfaceCompliance(t *testing.T) { + var _ Trader = (*AsterTrader)(nil) +} + +// TestAsterTrader_CommonInterface 使用测试套件运行所有通用接口测试 +func TestAsterTrader_CommonInterface(t *testing.T) { + // 创建测试套件 + suite := NewAsterTraderTestSuite(t) + defer suite.Cleanup() + + // 运行所有通用接口测试 + suite.RunAllTests() +} + +// ============================================================ +// 三、Aster 特定功能的单元测试 +// ============================================================ + +// TestNewAsterTrader 测试创建 Aster 交易器 +func TestNewAsterTrader(t *testing.T) { + tests := []struct { + name string + user string + signer string + privateKeyHex string + wantError bool + errorContains string + }{ + { + name: "成功创建", + user: "0x1234567890123456789012345678901234567890", + signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantError: false, + }, + { + name: "无效私钥格式", + user: "0x1234567890123456789012345678901234567890", + signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + privateKeyHex: "invalid_key", + wantError: true, + errorContains: "解析私钥失败", + }, + { + name: "带0x前缀的私钥", + user: "0x1234567890123456789012345678901234567890", + signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trader, err := NewAsterTrader(tt.user, tt.signer, tt.privateKeyHex) + + if tt.wantError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, trader) + } else { + assert.NoError(t, err) + assert.NotNil(t, trader) + if trader != nil { + assert.Equal(t, tt.user, trader.user) + assert.Equal(t, tt.signer, trader.signer) + assert.NotNil(t, trader.privateKey) + } + } + }) + } +} diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 0df685b1..e118fd7e 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -241,7 +241,7 @@ func (at *AutoTrader) Run() error { at.isRunning = true at.stopMonitorCh = make(chan struct{}) at.startTime = time.Now() - + log.Println("🚀 AI驱动自动交易系统启动") log.Printf("💰 初始余额: %.2f USDT", at.initialBalance) log.Printf("⚙️ 扫描间隔: %v", at.config.ScanInterval) diff --git a/trader/auto_trader_test.go b/trader/auto_trader_test.go index 40d2e562..09a2c428 100644 --- a/trader/auto_trader_test.go +++ b/trader/auto_trader_test.go @@ -1,70 +1,1164 @@ package trader import ( + "errors" + "fmt" "math" "testing" + "time" + + "nofx/decision" + "nofx/logger" + "nofx/market" + "nofx/pool" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/suite" ) +// ============================================================ +// AutoTraderTestSuite - 使用 testify/suite 进行结构化测试 +// ============================================================ + +// AutoTraderTestSuite 是 AutoTrader 的测试套件 +// 使用 testify/suite 来组织测试,提供统一的 setup/teardown 和 mock 管理 +type AutoTraderTestSuite struct { + suite.Suite + + // 测试对象 + autoTrader *AutoTrader + + // Mock 依赖 + mockTrader *MockTrader + mockDB *MockDatabase + mockLogger *logger.DecisionLogger + + // gomonkey patches + patches *gomonkey.Patches + + // 测试配置 + config AutoTraderConfig +} + +// SetupSuite 在整个测试套件开始前执行一次 +func (s *AutoTraderTestSuite) SetupSuite() { + // 可以在这里初始化一些全局资源 +} + +// TearDownSuite 在整个测试套件结束后执行一次 +func (s *AutoTraderTestSuite) TearDownSuite() { + // 清理全局资源 +} + +// SetupTest 在每个测试用例开始前执行 +func (s *AutoTraderTestSuite) SetupTest() { + // 初始化 patches + s.patches = gomonkey.NewPatches() + + // 创建 mock 对象 + s.mockTrader = &MockTrader{ + balance: map[string]interface{}{ + "totalWalletBalance": 10000.0, + "availableBalance": 8000.0, + "totalUnrealizedProfit": 100.0, + }, + positions: []map[string]interface{}{}, + } + + s.mockDB = &MockDatabase{} + + // 创建临时决策日志记录器 + s.mockLogger = logger.NewDecisionLogger("/tmp/test_decision_logs") + + // 设置默认配置 + s.config = AutoTraderConfig{ + ID: "test_trader", + Name: "Test Trader", + AIModel: "deepseek", + Exchange: "binance", + InitialBalance: 10000.0, + ScanInterval: 3 * time.Minute, + SystemPromptTemplate: "adaptive", + BTCETHLeverage: 10, + AltcoinLeverage: 5, + IsCrossMargin: true, + } + + // 创建 AutoTrader 实例(直接构造,不调用 NewAutoTrader 以避免外部依赖) + s.autoTrader = &AutoTrader{ + id: s.config.ID, + name: s.config.Name, + aiModel: s.config.AIModel, + exchange: s.config.Exchange, + config: s.config, + trader: s.mockTrader, + mcpClient: nil, // 测试中不需要实际的 MCP Client + decisionLogger: s.mockLogger, + initialBalance: s.config.InitialBalance, + systemPromptTemplate: s.config.SystemPromptTemplate, + defaultCoins: []string{"BTC", "ETH"}, + tradingCoins: []string{}, + lastResetTime: time.Now(), + startTime: time.Now(), + callCount: 0, + isRunning: false, + positionFirstSeenTime: make(map[string]int64), + stopMonitorCh: make(chan struct{}), + peakPnLCache: make(map[string]float64), + lastBalanceSyncTime: time.Now(), + database: s.mockDB, + userID: "test_user", + } +} + +// TearDownTest 在每个测试用例结束后执行 +func (s *AutoTraderTestSuite) TearDownTest() { + // 重置 gomonkey patches + if s.patches != nil { + s.patches.Reset() + } +} + +// ============================================================ +// 层次 1: 工具函数测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() { + tests := []struct { + name string + input []decision.Decision + }{ + { + name: "混合决策_验证优先级排序", + input: []decision.Decision{ + {Action: "open_long", Symbol: "BTCUSDT"}, + {Action: "close_short", Symbol: "ETHUSDT"}, + {Action: "hold", Symbol: "BNBUSDT"}, + {Action: "update_stop_loss", Symbol: "SOLUSDT"}, + {Action: "open_short", Symbol: "ADAUSDT"}, + {Action: "partial_close", Symbol: "DOGEUSDT"}, + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := sortDecisionsByPriority(tt.input) + + s.Equal(len(tt.input), len(result), "结果长度应该相同") + + // 验证优先级是否递增 + getActionPriority := func(action string) int { + switch action { + case "close_long", "close_short", "partial_close": + return 1 + case "update_stop_loss", "update_take_profit": + return 2 + case "open_long", "open_short": + return 3 + case "hold", "wait": + return 4 + default: + return 999 + } + } + + for i := 0; i < len(result)-1; i++ { + currentPriority := getActionPriority(result[i].Action) + nextPriority := getActionPriority(result[i+1].Action) + s.LessOrEqual(currentPriority, nextPriority, "优先级应该递增") + } + }) + } +} + +func (s *AutoTraderTestSuite) TestNormalizeSymbol() { + tests := []struct { + name string + input string + expected string + }{ + {"已经是标准格式", "BTCUSDT", "BTCUSDT"}, + {"小写转大写", "btcusdt", "BTCUSDT"}, + {"只有币种名称_添加USDT", "BTC", "BTCUSDT"}, + {"带空格_去除空格", " BTC ", "BTCUSDT"}, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := normalizeSymbol(tt.input) + s.Equal(tt.expected, result) + }) + } +} + +// ============================================================ +// 层次 2: Getter/Setter 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestGettersAndSetters() { + s.Run("GetID", func() { + s.Equal("test_trader", s.autoTrader.GetID()) + }) + + s.Run("GetName", func() { + s.Equal("Test Trader", s.autoTrader.GetName()) + }) + + s.Run("SetSystemPromptTemplate", func() { + s.autoTrader.SetSystemPromptTemplate("aggressive") + s.Equal("aggressive", s.autoTrader.GetSystemPromptTemplate()) + }) + + s.Run("SetCustomPrompt", func() { + s.autoTrader.SetCustomPrompt("custom prompt") + s.Equal("custom prompt", s.autoTrader.customPrompt) + }) +} + +// ============================================================ +// 层次 3: PeakPnL 缓存测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestPeakPnLCache() { + s.Run("UpdatePeakPnL_首次记录", func() { + s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 10.5) + cache := s.autoTrader.GetPeakPnLCache() + s.Equal(10.5, cache["BTCUSDT_long"]) + }) + + s.Run("UpdatePeakPnL_更新为更高值", func() { + s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 15.0) + cache := s.autoTrader.GetPeakPnLCache() + s.Equal(15.0, cache["BTCUSDT_long"]) + }) + + s.Run("UpdatePeakPnL_不更新为更低值", func() { + s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 12.0) + cache := s.autoTrader.GetPeakPnLCache() + s.Equal(15.0, cache["BTCUSDT_long"], "峰值应保持不变") + }) + + s.Run("ClearPeakPnLCache", func() { + s.autoTrader.ClearPeakPnLCache("BTCUSDT", "long") + cache := s.autoTrader.GetPeakPnLCache() + _, exists := cache["BTCUSDT_long"] + s.False(exists, "应该被清除") + }) +} + +// ============================================================ +// 层次 4: GetStatus 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestGetStatus() { + s.autoTrader.isRunning = true + s.autoTrader.callCount = 15 + + status := s.autoTrader.GetStatus() + + s.Equal("test_trader", status["trader_id"]) + s.Equal("Test Trader", status["trader_name"]) + s.Equal("deepseek", status["ai_model"]) + s.Equal("binance", status["exchange"]) + s.True(status["is_running"].(bool)) + s.Equal(15, status["call_count"]) + s.Equal(10000.0, status["initial_balance"]) +} + +// ============================================================ +// 层次 5: GetAccountInfo 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestGetAccountInfo() { + accountInfo, err := s.autoTrader.GetAccountInfo() + + s.NoError(err) + s.NotNil(accountInfo) + + // 验证核心字段和数值 + s.Equal(10100.0, accountInfo["total_equity"]) // 10000 + 100 + s.Equal(8000.0, accountInfo["available_balance"]) + s.Equal(100.0, accountInfo["total_pnl"]) // 10100 - 10000 +} + +// ============================================================ +// 层次 6: GetPositions 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestGetPositions() { + s.Run("空持仓", func() { + positions, err := s.autoTrader.GetPositions() + + s.NoError(err) + // positions 可能是 nil 或空数组,两者都是有效的 + if positions != nil { + s.Equal(0, len(positions)) + } + }) + + s.Run("有持仓", func() { + // 设置 mock 持仓 + s.mockTrader.positions = []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "side": "long", + "entryPrice": 50000.0, + "markPrice": 51000.0, + "positionAmt": 0.1, + "unRealizedProfit": 100.0, + "liquidationPrice": 45000.0, + "leverage": 10.0, + }, + } + + positions, err := s.autoTrader.GetPositions() + + s.NoError(err) + s.Equal(1, len(positions)) + + pos := positions[0] + s.Equal("BTCUSDT", pos["symbol"]) + s.Equal("long", pos["side"]) + s.Equal(0.1, pos["quantity"]) + s.Equal(50000.0, pos["entry_price"]) + }) +} + +// ============================================================ +// 层次 7: getCandidateCoins 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestGetCandidateCoins() { + s.Run("使用数据库默认币种", func() { + s.autoTrader.defaultCoins = []string{"BTC", "ETH", "BNB"} + s.autoTrader.tradingCoins = []string{} // 空的自定义币种 + + coins, err := s.autoTrader.getCandidateCoins() + + s.NoError(err) + s.Equal(3, len(coins)) + s.Equal("BTCUSDT", coins[0].Symbol) + s.Equal("ETHUSDT", coins[1].Symbol) + s.Equal("BNBUSDT", coins[2].Symbol) + s.Contains(coins[0].Sources, "default") + }) + + s.Run("使用自定义币种", func() { + s.autoTrader.tradingCoins = []string{"SOL", "AVAX"} + + coins, err := s.autoTrader.getCandidateCoins() + + s.NoError(err) + s.Equal(2, len(coins)) + s.Equal("SOLUSDT", coins[0].Symbol) + s.Equal("AVAXUSDT", coins[1].Symbol) + s.Contains(coins[0].Sources, "custom") + }) + + s.Run("使用AI500+OI作为fallback", func() { + s.autoTrader.defaultCoins = []string{} // 空的默认币种 + s.autoTrader.tradingCoins = []string{} // 空的自定义币种 + + // Mock pool.GetMergedCoinPool + s.patches.ApplyFunc(pool.GetMergedCoinPool, func(ai500Limit int) (*pool.MergedCoinPool, error) { + return &pool.MergedCoinPool{ + AllSymbols: []string{"BTCUSDT", "ETHUSDT"}, + SymbolSources: map[string][]string{ + "BTCUSDT": {"ai500", "oi_top"}, + "ETHUSDT": {"ai500"}, + }, + }, nil + }) + + coins, err := s.autoTrader.getCandidateCoins() + + s.NoError(err) + s.Equal(2, len(coins)) + }) +} + +// ============================================================ +// 层次 8: buildTradingContext 测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestBuildTradingContext() { + // Mock market.Get + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + return &market.Data{Symbol: symbol, CurrentPrice: 50000.0}, nil + }) + + ctx, err := s.autoTrader.buildTradingContext() + + s.NoError(err) + s.NotNil(ctx) + + // 验证核心字段 + s.Equal(10100.0, ctx.Account.TotalEquity) // 10000 + 100 + s.Equal(8000.0, ctx.Account.AvailableBalance) + s.Equal(10, ctx.BTCETHLeverage) + s.Equal(5, ctx.AltcoinLeverage) +} + +// ============================================================ +// 层次 9: 交易执行测试 +// ============================================================ + +// TestExecuteOpenPosition 测试开仓操作(多空通用) +func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { + tests := []struct { + name string + action string + expectedOrder int64 + existingSide string + availBalance float64 + expectedErr string + executeFn func(*decision.Decision, *logger.DecisionAction) error + }{ + { + name: "成功开多仓", + action: "open_long", + expectedOrder: 123456, + availBalance: 8000.0, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenLongWithRecord(d, a) + }, + }, + { + name: "成功开空仓", + action: "open_short", + expectedOrder: 123457, + availBalance: 8000.0, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenShortWithRecord(d, a) + }, + }, + { + name: "多仓_保证金不足", + action: "open_long", + availBalance: 0.0, + expectedErr: "保证金不足", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenLongWithRecord(d, a) + }, + }, + { + name: "空仓_保证金不足", + action: "open_short", + availBalance: 0.0, + expectedErr: "保证金不足", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenShortWithRecord(d, a) + }, + }, + { + name: "多仓_已有同方向持仓", + action: "open_long", + existingSide: "long", + availBalance: 8000.0, + expectedErr: "已有多仓", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenLongWithRecord(d, a) + }, + }, + { + name: "空仓_已有同方向持仓", + action: "open_short", + existingSide: "short", + availBalance: 8000.0, + expectedErr: "已有空仓", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeOpenShortWithRecord(d, a) + }, + }, + } + + for _, tt := range tests { + time.Sleep(time.Millisecond) + s.Run(tt.name, func() { + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + return &market.Data{Symbol: symbol, CurrentPrice: 50000.0}, nil + }) + + s.mockTrader.balance["availableBalance"] = tt.availBalance + if tt.existingSide != "" { + s.mockTrader.positions = []map[string]interface{}{{"symbol": "BTCUSDT", "side": tt.existingSide}} + } else { + s.mockTrader.positions = []map[string]interface{}{} + } + + decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT", PositionSizeUSD: 1000.0, Leverage: 10} + actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} + + err := tt.executeFn(decision, actionRecord) + + if tt.expectedErr != "" { + s.Error(err) + s.Contains(err.Error(), tt.expectedErr) + } else { + s.NoError(err) + s.Equal(tt.expectedOrder, actionRecord.OrderID) + s.Greater(actionRecord.Quantity, 0.0) + s.Equal(50000.0, actionRecord.Price) + } + + // 恢复默认状态 + s.mockTrader.balance["availableBalance"] = 8000.0 + s.mockTrader.positions = []map[string]interface{}{} + }) + } +} + +// TestExecuteClosePosition 测试平仓操作(多空通用) +func (s *AutoTraderTestSuite) TestExecuteClosePosition() { + tests := []struct { + name string + action string + currentPrice float64 + expectedOrder int64 + executeFn func(*decision.Decision, *logger.DecisionAction) error + }{ + { + name: "成功平多仓", + action: "close_long", + currentPrice: 51000.0, + expectedOrder: 123458, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeCloseLongWithRecord(d, a) + }, + }, + { + name: "成功平空仓", + action: "close_short", + currentPrice: 49000.0, + expectedOrder: 123459, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeCloseShortWithRecord(d, a) + }, + }, + } + + for _, tt := range tests { + time.Sleep(time.Millisecond) + s.Run(tt.name, func() { + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + return &market.Data{Symbol: symbol, CurrentPrice: tt.currentPrice}, nil + }) + + decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT"} + actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} + + err := tt.executeFn(decision, actionRecord) + + s.NoError(err) + s.Equal(tt.expectedOrder, actionRecord.OrderID) + s.Equal(tt.currentPrice, actionRecord.Price) + }) + } +} + +// TestExecuteUpdateStopOrTakeProfit 测试更新止损/止盈(多空通用) +func (s *AutoTraderTestSuite) TestExecuteUpdateStopOrTakeProfit() { + // 使用指针变量来控制 market.Get 的返回值 + var testPrice *float64 + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + price := 50000.0 + if testPrice != nil { + price = *testPrice + } + return &market.Data{Symbol: symbol, CurrentPrice: price}, nil + }) + + tests := []struct { + name string + action string + symbol string + side string + currentPrice float64 + newPrice float64 + hasPosition bool + expectedErr string + executeFn func(*decision.Decision, *logger.DecisionAction) error + }{ + { + name: "成功更新多头止损", + action: "update_stop_loss", + symbol: "BTCUSDT", + side: "long", + currentPrice: 52000.0, + newPrice: 51000.0, + hasPosition: true, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateStopLossWithRecord(d, a) + }, + }, + { + name: "成功更新空头止损", + action: "update_stop_loss", + symbol: "ETHUSDT", + side: "short", + currentPrice: 2900.0, + newPrice: 2950.0, + hasPosition: true, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateStopLossWithRecord(d, a) + }, + }, + { + name: "成功更新多头止盈", + action: "update_take_profit", + symbol: "BTCUSDT", + side: "long", + currentPrice: 52000.0, + newPrice: 55000.0, + hasPosition: true, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) + }, + }, + { + name: "成功更新空头止盈", + action: "update_take_profit", + symbol: "ETHUSDT", + side: "short", + currentPrice: 2900.0, + newPrice: 2800.0, + hasPosition: true, + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) + }, + }, + { + name: "多头止损价格不合理", + action: "update_stop_loss", + symbol: "BTCUSDT", + side: "long", + currentPrice: 50000.0, + newPrice: 51000.0, + hasPosition: true, + expectedErr: "多单止损必须低于当前价格", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateStopLossWithRecord(d, a) + }, + }, + { + name: "多头止盈价格不合理", + action: "update_take_profit", + symbol: "BTCUSDT", + side: "long", + currentPrice: 50000.0, + newPrice: 49000.0, + hasPosition: true, + expectedErr: "多单止盈必须高于当前价格", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) + }, + }, + { + name: "止损_持仓不存在", + action: "update_stop_loss", + symbol: "BTCUSDT", + currentPrice: 50000.0, + newPrice: 49000.0, + hasPosition: false, + expectedErr: "持仓不存在", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateStopLossWithRecord(d, a) + }, + }, + { + name: "止盈_持仓不存在", + action: "update_take_profit", + symbol: "BTCUSDT", + currentPrice: 50000.0, + newPrice: 55000.0, + hasPosition: false, + expectedErr: "持仓不存在", + executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) + }, + }, + } + + for _, tt := range tests { + time.Sleep(time.Millisecond) + s.Run(tt.name, func() { + // 设置当前测试用例的价格 + testPrice = &tt.currentPrice + + if tt.hasPosition { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": tt.symbol, "side": tt.side, "positionAmt": 0.1}, + } + } else { + s.mockTrader.positions = []map[string]interface{}{} + } + + decision := &decision.Decision{Action: tt.action, Symbol: tt.symbol} + if tt.action == "update_stop_loss" { + decision.NewStopLoss = tt.newPrice + } else { + decision.NewTakeProfit = tt.newPrice + } + actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: tt.symbol} + + err := tt.executeFn(decision, actionRecord) + + if tt.expectedErr != "" { + s.Error(err) + s.Contains(err.Error(), tt.expectedErr) + } else { + s.NoError(err) + s.Equal(tt.currentPrice, actionRecord.Price) + } + + // 恢复默认状态 + s.mockTrader.positions = []map[string]interface{}{} + }) + } +} + +func (s *AutoTraderTestSuite) TestExecutePartialCloseWithRecord() { + s.Run("成功部分平仓", func() { + // 设置持仓 + s.mockTrader.positions = []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "side": "long", + "positionAmt": 0.1, + "entryPrice": 50000.0, + "markPrice": 52000.0, + }, + } + + // Mock market.Get + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + return &market.Data{ + Symbol: symbol, + CurrentPrice: 52000.0, + }, nil + }) + + decision := &decision.Decision{ + Action: "partial_close", + Symbol: "BTCUSDT", + ClosePercentage: 50.0, + } + + actionRecord := &logger.DecisionAction{ + Action: "partial_close", + Symbol: "BTCUSDT", + } + + err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord) + + s.NoError(err) + s.Equal(0.05, actionRecord.Quantity) // 50% of 0.1 + }) + + s.Run("无效的平仓百分比", func() { + decision := &decision.Decision{ + Action: "partial_close", + Symbol: "BTCUSDT", + ClosePercentage: 150.0, // 无效 + } + + actionRecord := &logger.DecisionAction{} + + err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord) + + s.Error(err) + s.Contains(err.Error(), "平仓百分比必须在 0-100 之间") + }) +} + +// ============================================================ +// 层次 10: executeDecisionWithRecord 路由测试 +// ============================================================ + +func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() { + // Mock market.Get + s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { + return &market.Data{ + Symbol: symbol, + CurrentPrice: 50000.0, + }, nil + }) + + s.Run("路由到open_long", func() { + decision := &decision.Decision{ + Action: "open_long", + Symbol: "BTCUSDT", + PositionSizeUSD: 1000.0, + Leverage: 10, + } + actionRecord := &logger.DecisionAction{} + + err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) + s.NoError(err) + }) + + s.Run("路由到close_long", func() { + decision := &decision.Decision{ + Action: "close_long", + Symbol: "BTCUSDT", + } + actionRecord := &logger.DecisionAction{} + + err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) + s.NoError(err) + }) + + s.Run("路由到hold_不执行", func() { + decision := &decision.Decision{ + Action: "hold", + Symbol: "BTCUSDT", + } + actionRecord := &logger.DecisionAction{} + + err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) + s.NoError(err) + }) + + s.Run("未知action返回错误", func() { + decision := &decision.Decision{ + Action: "unknown_action", + Symbol: "BTCUSDT", + } + actionRecord := &logger.DecisionAction{} + + err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) + s.Error(err) + s.Contains(err.Error(), "未知的action") + }) +} + +func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() { + tests := []struct { + name string + setupPositions func() + setupPeakPnL func() + setupFailures func() + cleanupFailures func() + expectedCacheKey string + shouldClearCache bool + skipCacheCheck bool + }{ + { + name: "获取持仓失败_不panic", + setupFailures: func() { s.mockTrader.shouldFailPositions = true }, + cleanupFailures: func() { s.mockTrader.shouldFailPositions = false }, + skipCacheCheck: true, + }, + { + name: "无持仓_不panic", + setupPositions: func() { s.mockTrader.positions = []map[string]interface{}{} }, + skipCacheCheck: true, + }, + { + name: "收益不足5%_不触发平仓", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50150.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.ClearPeakPnLCache("BTCUSDT", "long") }, + skipCacheCheck: true, + }, + { + name: "回撤不足40%_不触发平仓", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50400.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 10.0) }, + skipCacheCheck: true, + }, + { + name: "多头_触发回撤平仓", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50300.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 10.0) }, + expectedCacheKey: "BTCUSDT_long", + shouldClearCache: true, + }, + { + name: "空头_触发回撤平仓", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "ETHUSDT", "side": "short", "positionAmt": -0.5, "entryPrice": 3000.0, "markPrice": 2982.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.UpdatePeakPnL("ETHUSDT", "short", 10.0) }, + expectedCacheKey: "ETHUSDT_short", + shouldClearCache: true, + }, + { + name: "多头_平仓失败_保留缓存", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50300.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 10.0) }, + setupFailures: func() { s.mockTrader.shouldFailCloseLong = true }, + cleanupFailures: func() { s.mockTrader.shouldFailCloseLong = false }, + expectedCacheKey: "BTCUSDT_long", + shouldClearCache: false, + }, + { + name: "空头_平仓失败_保留缓存", + setupPositions: func() { + s.mockTrader.positions = []map[string]interface{}{ + {"symbol": "ETHUSDT", "side": "short", "positionAmt": -0.5, "entryPrice": 3000.0, "markPrice": 2982.0, "leverage": 10.0}, + } + }, + setupPeakPnL: func() { s.autoTrader.UpdatePeakPnL("ETHUSDT", "short", 10.0) }, + setupFailures: func() { s.mockTrader.shouldFailCloseShort = true }, + cleanupFailures: func() { s.mockTrader.shouldFailCloseShort = false }, + expectedCacheKey: "ETHUSDT_short", + shouldClearCache: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + if tt.setupPositions != nil { + tt.setupPositions() + } + if tt.setupPeakPnL != nil { + tt.setupPeakPnL() + } + if tt.setupFailures != nil { + tt.setupFailures() + } + if tt.cleanupFailures != nil { + defer tt.cleanupFailures() + } + + s.autoTrader.checkPositionDrawdown() + + if !tt.skipCacheCheck { + cache := s.autoTrader.GetPeakPnLCache() + _, exists := cache[tt.expectedCacheKey] + if tt.shouldClearCache { + s.False(exists, "峰值缓存应该被清理") + } else { + s.True(exists, "峰值缓存不应该被清理") + } + } + + // 清理状态 + s.mockTrader.positions = []map[string]interface{}{} + }) + } +} + +// ============================================================ +// Mock 实现 +// ============================================================ + +// MockDatabase 模拟数据库 +type MockDatabase struct { + shouldFail bool +} + +func (m *MockDatabase) UpdateTraderInitialBalance(userID, traderID string, newBalance float64) error { + if m.shouldFail { + return errors.New("database error") + } + return nil +} + +// MockTrader 增强版(添加错误控制) +type MockTrader struct { + balance map[string]interface{} + positions []map[string]interface{} + shouldFailBalance bool + shouldFailPositions bool + shouldFailOpenLong bool + shouldFailCloseLong bool + shouldFailCloseShort bool +} + +func (m *MockTrader) GetBalance() (map[string]interface{}, error) { + if m.shouldFailBalance { + return nil, errors.New("failed to get balance") + } + if m.balance == nil { + return map[string]interface{}{ + "totalWalletBalance": 10000.0, + "availableBalance": 8000.0, + "totalUnrealizedProfit": 100.0, + }, nil + } + return m.balance, nil +} + +func (m *MockTrader) GetPositions() ([]map[string]interface{}, error) { + if m.shouldFailPositions { + return nil, errors.New("failed to get positions") + } + if m.positions == nil { + return []map[string]interface{}{}, nil + } + return m.positions, nil +} + +func (m *MockTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + if m.shouldFailOpenLong { + return nil, errors.New("failed to open long") + } + return map[string]interface{}{ + "orderId": int64(123456), + "symbol": symbol, + }, nil +} + +func (m *MockTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { + return map[string]interface{}{ + "orderId": int64(123457), + "symbol": symbol, + }, nil +} + +func (m *MockTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { + if m.shouldFailCloseLong { + return nil, errors.New("failed to close long") + } + return map[string]interface{}{ + "orderId": int64(123458), + "symbol": symbol, + }, nil +} + +func (m *MockTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { + if m.shouldFailCloseShort { + return nil, errors.New("failed to close short") + } + return map[string]interface{}{ + "orderId": int64(123459), + "symbol": symbol, + }, nil +} + +func (m *MockTrader) SetLeverage(symbol string, leverage int) error { + return nil +} + +func (m *MockTrader) SetMarginMode(symbol string, isCrossMargin bool) error { + return nil +} + +func (m *MockTrader) GetMarketPrice(symbol string) (float64, error) { + return 50000.0, nil +} + +func (m *MockTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { + return nil +} + +func (m *MockTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { + return nil +} + +func (m *MockTrader) CancelStopLossOrders(symbol string) error { + return nil +} + +func (m *MockTrader) CancelTakeProfitOrders(symbol string) error { + return nil +} + +func (m *MockTrader) CancelAllOrders(symbol string) error { + return nil +} + +func (m *MockTrader) CancelStopOrders(symbol string) error { + return nil +} + +func (m *MockTrader) FormatQuantity(symbol string, quantity float64) (string, error) { + return fmt.Sprintf("%.4f", quantity), nil +} + +// ============================================================ +// 测试套件入口 +// ============================================================ + +// TestAutoTraderTestSuite 运行 AutoTrader 测试套件 +func TestAutoTraderTestSuite(t *testing.T) { + suite.Run(t, new(AutoTraderTestSuite)) +} + +// ============================================================ +// 独立的单元测试 - calculatePnLPercentage 函数测试 +// ============================================================ + func TestCalculatePnLPercentage(t *testing.T) { tests := []struct { - name string - unrealizedPnl float64 - marginUsed float64 - expected float64 + name string + unrealizedPnl float64 + marginUsed float64 + expected float64 }{ { - name: "正常盈利 - 10倍杠杆", - unrealizedPnl: 100.0, // 盈利 100 USDT - marginUsed: 1000.0, // 保证金 1000 USDT - expected: 10.0, // 10% 收益率 + name: "正常盈利 - 10倍杠杆", + unrealizedPnl: 100.0, // 盈利 100 USDT + marginUsed: 1000.0, // 保证金 1000 USDT + expected: 10.0, // 10% 收益率 }, { - name: "正常亏损 - 10倍杠杆", - unrealizedPnl: -50.0, // 亏损 50 USDT - marginUsed: 1000.0, // 保证金 1000 USDT - expected: -5.0, // -5% 收益率 + name: "正常亏损 - 10倍杠杆", + unrealizedPnl: -50.0, // 亏损 50 USDT + marginUsed: 1000.0, // 保证金 1000 USDT + expected: -5.0, // -5% 收益率 }, { - name: "高杠杆盈利 - 价格上涨1%,20倍杠杆", - unrealizedPnl: 200.0, // 盈利 200 USDT - marginUsed: 1000.0, // 保证金 1000 USDT - expected: 20.0, // 20% 收益率 + name: "高杠杆盈利 - 价格上涨1%,20倍杠杆", + unrealizedPnl: 200.0, // 盈利 200 USDT + marginUsed: 1000.0, // 保证金 1000 USDT + expected: 20.0, // 20% 收益率 }, { - name: "保证金为0 - 边界情况", - unrealizedPnl: 100.0, - marginUsed: 0.0, - expected: 0.0, // 应该返回 0 而不是除以零错误 + name: "保证金为0 - 边界情况", + unrealizedPnl: 100.0, + marginUsed: 0.0, + expected: 0.0, // 应该返回 0 而不是除以零错误 }, { - name: "负保证金 - 边界情况", - unrealizedPnl: 100.0, - marginUsed: -1000.0, - expected: 0.0, // 应该返回 0(异常情况) + name: "负保证金 - 边界情况", + unrealizedPnl: 100.0, + marginUsed: -1000.0, + expected: 0.0, // 应该返回 0(异常情况) }, { - name: "盈亏为0", - unrealizedPnl: 0.0, - marginUsed: 1000.0, - expected: 0.0, + name: "盈亏为0", + unrealizedPnl: 0.0, + marginUsed: 1000.0, + expected: 0.0, }, { - name: "小额交易", - unrealizedPnl: 0.5, - marginUsed: 10.0, - expected: 5.0, + name: "小额交易", + unrealizedPnl: 0.5, + marginUsed: 10.0, + expected: 5.0, }, { - name: "大额盈利", - unrealizedPnl: 5000.0, - marginUsed: 10000.0, - expected: 50.0, + name: "大额盈利", + unrealizedPnl: 5000.0, + marginUsed: 10000.0, + expected: 50.0, }, { - name: "极小保证金", - unrealizedPnl: 1.0, - marginUsed: 0.01, - expected: 10000.0, // 100倍收益率 + name: "极小保证金", + unrealizedPnl: 1.0, + marginUsed: 0.01, + expected: 10000.0, // 100倍收益率 }, } diff --git a/trader/binance_futures_test.go b/trader/binance_futures_test.go new file mode 100644 index 00000000..6f9e2987 --- /dev/null +++ b/trader/binance_futures_test.go @@ -0,0 +1,420 @@ +package trader + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/adshao/go-binance/v2/futures" + "github.com/stretchr/testify/assert" +) + +// ============================================================ +// 一、BinanceFuturesTestSuite - 继承 base test suite +// ============================================================ + +// BinanceFuturesTestSuite 币安合约交易器测试套件 +// 继承 TraderTestSuite 并添加 Binance Futures 特定的 mock 逻辑 +type BinanceFuturesTestSuite struct { + *TraderTestSuite // 嵌入基础测试套件 + mockServer *httptest.Server +} + +// NewBinanceFuturesTestSuite 创建币安合约测试套件 +func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite { + // 创建 mock HTTP 服务器 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 根据不同的 URL 路径返回不同的 mock 响应 + path := r.URL.Path + + var respBody interface{} + + switch { + // Mock GetBalance - /fapi/v2/balance + case path == "/fapi/v2/balance": + respBody = []map[string]interface{}{ + { + "accountAlias": "test", + "asset": "USDT", + "balance": "10000.00", + "crossWalletBalance": "10000.00", + "crossUnPnl": "100.50", + "availableBalance": "8000.00", + "maxWithdrawAmount": "8000.00", + }, + } + + // Mock GetAccount - /fapi/v2/account + case path == "/fapi/v2/account": + respBody = map[string]interface{}{ + "totalWalletBalance": "10000.00", + "availableBalance": "8000.00", + "totalUnrealizedProfit": "100.50", + "assets": []map[string]interface{}{ + { + "asset": "USDT", + "walletBalance": "10000.00", + "unrealizedProfit": "100.50", + "marginBalance": "10100.50", + "maintMargin": "200.00", + "initialMargin": "2000.00", + "positionInitialMargin": "2000.00", + "openOrderInitialMargin": "0.00", + "crossWalletBalance": "10000.00", + "crossUnPnl": "100.50", + "availableBalance": "8000.00", + "maxWithdrawAmount": "8000.00", + }, + }, + } + + // Mock GetPositions - /fapi/v2/positionRisk + case path == "/fapi/v2/positionRisk": + respBody = []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "positionAmt": "0.5", + "entryPrice": "50000.00", + "markPrice": "50500.00", + "unRealizedProfit": "250.00", + "liquidationPrice": "45000.00", + "leverage": "10", + "positionSide": "LONG", + }, + } + + // Mock GetMarketPrice - /fapi/v1/ticker/price and /fapi/v2/ticker/price + case path == "/fapi/v1/ticker/price" || path == "/fapi/v2/ticker/price": + symbol := r.URL.Query().Get("symbol") + if symbol == "" { + // 返回所有价格 + respBody = []map[string]interface{}{ + {"Symbol": "BTCUSDT", "Price": "50000.00", "Time": 1234567890}, + {"Symbol": "ETHUSDT", "Price": "3000.00", "Time": 1234567890}, + } + } else if symbol == "INVALIDUSDT" { + // 返回错误 + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": -1121, + "msg": "Invalid symbol.", + }) + return + } else { + // 返回单个价格(注意:即使有 symbol 参数,也要返回数组) + price := "50000.00" + if symbol == "ETHUSDT" { + price = "3000.00" + } + respBody = []map[string]interface{}{ + { + "Symbol": symbol, + "Price": price, + "Time": 1234567890, + }, + } + } + + // Mock ExchangeInfo - /fapi/v1/exchangeInfo + case path == "/fapi/v1/exchangeInfo": + respBody = map[string]interface{}{ + "symbols": []map[string]interface{}{ + { + "symbol": "BTCUSDT", + "status": "TRADING", + "baseAsset": "BTC", + "quoteAsset": "USDT", + "pricePrecision": 2, + "quantityPrecision": 3, + "baseAssetPrecision": 8, + "quotePrecision": 8, + "filters": []map[string]interface{}{ + { + "filterType": "PRICE_FILTER", + "minPrice": "0.01", + "maxPrice": "1000000", + "tickSize": "0.01", + }, + { + "filterType": "LOT_SIZE", + "minQty": "0.001", + "maxQty": "10000", + "stepSize": "0.001", + }, + }, + }, + { + "symbol": "ETHUSDT", + "status": "TRADING", + "baseAsset": "ETH", + "quoteAsset": "USDT", + "pricePrecision": 2, + "quantityPrecision": 3, + "baseAssetPrecision": 8, + "quotePrecision": 8, + "filters": []map[string]interface{}{ + { + "filterType": "PRICE_FILTER", + "minPrice": "0.01", + "maxPrice": "100000", + "tickSize": "0.01", + }, + { + "filterType": "LOT_SIZE", + "minQty": "0.001", + "maxQty": "10000", + "stepSize": "0.001", + }, + }, + }, + }, + } + + // Mock CreateOrder - /fapi/v1/order (POST) + case path == "/fapi/v1/order" && r.Method == "POST": + symbol := r.FormValue("symbol") + if symbol == "" { + symbol = "BTCUSDT" + } + respBody = map[string]interface{}{ + "orderId": 123456, + "symbol": symbol, + "status": "FILLED", + "clientOrderId": r.FormValue("newClientOrderId"), + "price": r.FormValue("price"), + "avgPrice": r.FormValue("price"), + "origQty": r.FormValue("quantity"), + "executedQty": r.FormValue("quantity"), + "cumQty": r.FormValue("quantity"), + "cumQuote": "1000.00", + "timeInForce": r.FormValue("timeInForce"), + "type": r.FormValue("type"), + "reduceOnly": r.FormValue("reduceOnly") == "true", + "side": r.FormValue("side"), + "positionSide": r.FormValue("positionSide"), + "stopPrice": r.FormValue("stopPrice"), + "workingType": r.FormValue("workingType"), + } + + // Mock CancelOrder - /fapi/v1/order (DELETE) + case path == "/fapi/v1/order" && r.Method == "DELETE": + respBody = map[string]interface{}{ + "orderId": 123456, + "symbol": r.URL.Query().Get("symbol"), + "status": "CANCELED", + } + + // Mock ListOpenOrders - /fapi/v1/openOrders + case path == "/fapi/v1/openOrders": + respBody = []map[string]interface{}{} + + // Mock CancelAllOrders - /fapi/v1/allOpenOrders (DELETE) + case path == "/fapi/v1/allOpenOrders" && r.Method == "DELETE": + respBody = map[string]interface{}{ + "code": 200, + "msg": "The operation of cancel all open order is done.", + } + + // Mock SetLeverage - /fapi/v1/leverage + case path == "/fapi/v1/leverage": + // 将字符串转换为整数 + leverageStr := r.FormValue("leverage") + leverage := 10 // 默认值 + if leverageStr != "" { + // 注意:这里我们直接返回整数,而不是字符串 + fmt.Sscanf(leverageStr, "%d", &leverage) + } + respBody = map[string]interface{}{ + "leverage": leverage, + "maxNotionalValue": "1000000", + "symbol": r.FormValue("symbol"), + } + + // Mock SetMarginType - /fapi/v1/marginType + case path == "/fapi/v1/marginType": + respBody = map[string]interface{}{ + "code": 200, + "msg": "success", + } + + // Mock ChangePositionMode - /fapi/v1/positionSide/dual + case path == "/fapi/v1/positionSide/dual": + respBody = map[string]interface{}{ + "code": 200, + "msg": "success", + } + + // Mock ServerTime - /fapi/v1/time + case path == "/fapi/v1/time": + respBody = map[string]interface{}{ + "serverTime": 1234567890000, + } + + // Default: empty response + default: + respBody = map[string]interface{}{} + } + + // 序列化响应 + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(respBody) + })) + + // 创建 futures.Client 并设置为使用 mock 服务器 + client := futures.NewClient("test_api_key", "test_secret_key") + client.BaseURL = mockServer.URL + client.HTTPClient = mockServer.Client() + + // 创建 FuturesTrader + trader := &FuturesTrader{ + client: client, + cacheDuration: 0, // 禁用缓存以便测试 + } + + // 创建基础套件 + baseSuite := NewTraderTestSuite(t, trader) + + return &BinanceFuturesTestSuite{ + TraderTestSuite: baseSuite, + mockServer: mockServer, + } +} + +// Cleanup 清理资源 +func (s *BinanceFuturesTestSuite) Cleanup() { + if s.mockServer != nil { + s.mockServer.Close() + } + s.TraderTestSuite.Cleanup() +} + +// ============================================================ +// 二、使用 BinanceFuturesTestSuite 运行通用测试 +// ============================================================ + +// TestFuturesTrader_InterfaceCompliance 测试接口兼容性 +func TestFuturesTrader_InterfaceCompliance(t *testing.T) { + var _ Trader = (*FuturesTrader)(nil) +} + +// TestFuturesTrader_CommonInterface 使用测试套件运行所有通用接口测试 +func TestFuturesTrader_CommonInterface(t *testing.T) { + // 创建测试套件 + suite := NewBinanceFuturesTestSuite(t) + defer suite.Cleanup() + + // 运行所有通用接口测试 + suite.RunAllTests() +} + +// ============================================================ +// 三、币安合约特定功能的单元测试 +// ============================================================ + +// TestNewFuturesTrader 测试创建币安合约交易器 +func TestNewFuturesTrader(t *testing.T) { + // 创建 mock HTTP 服务器 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + var respBody interface{} + + switch path { + case "/fapi/v1/time": + respBody = map[string]interface{}{ + "serverTime": 1234567890000, + } + case "/fapi/v1/positionSide/dual": + respBody = map[string]interface{}{ + "code": 200, + "msg": "success", + } + default: + respBody = map[string]interface{}{} + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(respBody) + })) + defer mockServer.Close() + + // 测试成功创建 + trader := NewFuturesTrader("test_api_key", "test_secret_key", "test_user") + + // 修改 client 使用 mock server + trader.client.BaseURL = mockServer.URL + trader.client.HTTPClient = mockServer.Client() + + assert.NotNil(t, trader) + assert.NotNil(t, trader.client) + assert.Equal(t, 15*time.Second, trader.cacheDuration) +} + +// TestCalculatePositionSize 测试仓位计算 +func TestCalculatePositionSize(t *testing.T) { + trader := &FuturesTrader{} + + tests := []struct { + name string + balance float64 + riskPercent float64 + price float64 + leverage int + wantQuantity float64 + }{ + { + name: "正常计算", + balance: 10000, + riskPercent: 2, + price: 50000, + leverage: 10, + wantQuantity: 0.04, // (10000 * 0.02 * 10) / 50000 = 0.04 + }, + { + name: "高杠杆", + balance: 10000, + riskPercent: 1, + price: 3000, + leverage: 20, + wantQuantity: 0.6667, // (10000 * 0.01 * 20) / 3000 = 0.6667 + }, + { + name: "低风险", + balance: 5000, + riskPercent: 0.5, + price: 50000, + leverage: 5, + wantQuantity: 0.0025, // (5000 * 0.005 * 5) / 50000 = 0.0025 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + quantity := trader.CalculatePositionSize(tt.balance, tt.riskPercent, tt.price, tt.leverage) + assert.InDelta(t, tt.wantQuantity, quantity, 0.0001, "计算的仓位数量不正确") + }) + } +} + +// TestGetBrOrderID 测试订单ID生成 +func TestGetBrOrderID(t *testing.T) { + // 测试3次,确保每次生成的ID都不同 + ids := make(map[string]bool) + for i := 0; i < 3; i++ { + id := getBrOrderID() + + // 检查格式 + assert.True(t, strings.HasPrefix(id, "x-KzrpZaP9"), "订单ID应以x-KzrpZaP9开头") + + // 检查长度(应该 <= 32) + assert.LessOrEqual(t, len(id), 32, "订单ID长度不应超过32字符") + + // 检查唯一性 + assert.False(t, ids[id], "订单ID应该唯一") + ids[id] = true + } +} diff --git a/trader/hyperliquid_trader_test.go b/trader/hyperliquid_trader_test.go new file mode 100644 index 00000000..b50f842a --- /dev/null +++ b/trader/hyperliquid_trader_test.go @@ -0,0 +1,646 @@ +package trader + +import ( + "context" + "crypto/ecdsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/sonirico/go-hyperliquid" + "github.com/stretchr/testify/assert" +) + +// ============================================================ +// 一、HyperliquidTestSuite - 继承 base test suite +// ============================================================ + +// HyperliquidTestSuite Hyperliquid 交易器测试套件 +// 继承 TraderTestSuite 并添加 Hyperliquid 特定的 mock 逻辑 +type HyperliquidTestSuite struct { + *TraderTestSuite // 嵌入基础测试套件 + mockServer *httptest.Server + privateKey *ecdsa.PrivateKey +} + +// NewHyperliquidTestSuite 创建 Hyperliquid 测试套件 +func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite { + // 创建测试用私钥 + privateKey, err := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") + if err != nil { + t.Fatalf("创建测试私钥失败: %v", err) + } + + // 创建 mock HTTP 服务器 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 根据不同的请求路径返回不同的 mock 响应 + var respBody interface{} + + // Hyperliquid API 使用 POST 请求,请求体是 JSON + // 我们需要根据请求体中的 "type" 字段来区分不同的请求 + var reqBody map[string]interface{} + if r.Method == "POST" { + json.NewDecoder(r.Body).Decode(&reqBody) + } + + // Try to get type from top level first, then from action object + reqType, _ := reqBody["type"].(string) + if reqType == "" && reqBody["action"] != nil { + if action, ok := reqBody["action"].(map[string]interface{}); ok { + reqType, _ = action["type"].(string) + } + } + + switch reqType { + // Mock Meta - 获取市场元数据 + case "meta": + respBody = map[string]interface{}{ + "universe": []map[string]interface{}{ + { + "name": "BTC", + "szDecimals": 4, + "maxLeverage": 50, + "onlyIsolated": false, + "isDelisted": false, + "marginTableId": 0, + }, + { + "name": "ETH", + "szDecimals": 3, + "maxLeverage": 50, + "onlyIsolated": false, + "isDelisted": false, + "marginTableId": 0, + }, + }, + "marginTables": []interface{}{}, + } + + // Mock UserState - 获取用户账户状态(用于 GetBalance 和 GetPositions) + case "clearinghouseState": + user, _ := reqBody["user"].(string) + + // 检查是否是查询 Agent 钱包余额(用于安全检查) + agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() + if user == agentAddr { + // Agent 钱包余额应该很低 + respBody = map[string]interface{}{ + "crossMarginSummary": map[string]interface{}{ + "accountValue": "5.00", + "totalMarginUsed": "0.00", + }, + "withdrawable": "5.00", + "assetPositions": []interface{}{}, + } + } else { + // 主钱包账户状态 + respBody = map[string]interface{}{ + "crossMarginSummary": map[string]interface{}{ + "accountValue": "10000.00", + "totalMarginUsed": "2000.00", + }, + "withdrawable": "8000.00", + "assetPositions": []map[string]interface{}{ + { + "position": map[string]interface{}{ + "coin": "BTC", + "szi": "0.5", + "entryPx": "50000.00", + "liquidationPx": "45000.00", + "positionValue": "25000.00", + "unrealizedPnl": "100.50", + "leverage": map[string]interface{}{ + "type": "cross", + "value": 10, + }, + }, + }, + }, + } + } + + // Mock SpotUserState - 获取现货账户状态 + case "spotClearinghouseState": + respBody = map[string]interface{}{ + "balances": []map[string]interface{}{ + { + "coin": "USDC", + "total": "500.00", + }, + }, + } + + // Mock SpotMeta - 获取现货市场元数据 + case "spotMeta": + respBody = map[string]interface{}{ + "universe": []map[string]interface{}{}, + "tokens": []map[string]interface{}{}, + } + + // Mock AllMids - 获取所有市场价格 + case "allMids": + respBody = map[string]string{ + "BTC": "50000.00", + "ETH": "3000.00", + } + + // Mock OpenOrders - 获取挂单列表 + case "openOrders": + respBody = []interface{}{} + + // Mock Order - 创建订单(开仓、平仓、止损、止盈) + case "order": + respBody = map[string]interface{}{ + "status": "ok", + "response": map[string]interface{}{ + "type": "order", + "data": map[string]interface{}{ + "statuses": []map[string]interface{}{ + { + "filled": map[string]interface{}{ + "totalSz": "0.01", + "avgPx": "50000.00", + }, + }, + }, + }, + }, + } + + // Mock UpdateLeverage - 设置杠杆 + case "updateLeverage": + respBody = map[string]interface{}{ + "status": "ok", + } + + // Mock Cancel - 取消订单 + case "cancel": + respBody = map[string]interface{}{ + "status": "ok", + } + + default: + // 默认返回成功响应 + respBody = map[string]interface{}{ + "status": "ok", + } + } + + // 序列化响应 + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(respBody) + })) + + // 创建 HyperliquidTrader,使用 mock 服务器 URL + walletAddr := "0x9999999999999999999999999999999999999999" + ctx := context.Background() + + // 创建 Exchange 客户端,指向 mock 服务器 + exchange := hyperliquid.NewExchange( + ctx, + privateKey, + mockServer.URL, // 使用 mock 服务器 URL + nil, + "", + walletAddr, + nil, + ) + + // 创建 meta(模拟获取成功) + meta := &hyperliquid.Meta{ + Universe: []hyperliquid.AssetInfo{ + {Name: "BTC", SzDecimals: 4}, + {Name: "ETH", SzDecimals: 3}, + }, + } + + trader := &HyperliquidTrader{ + exchange: exchange, + ctx: ctx, + walletAddr: walletAddr, + meta: meta, + isCrossMargin: true, + } + + // 创建基础套件 + baseSuite := NewTraderTestSuite(t, trader) + + return &HyperliquidTestSuite{ + TraderTestSuite: baseSuite, + mockServer: mockServer, + privateKey: privateKey, + } +} + +// Cleanup 清理资源 +func (s *HyperliquidTestSuite) Cleanup() { + if s.mockServer != nil { + s.mockServer.Close() + } + s.TraderTestSuite.Cleanup() +} + +// ============================================================ +// 二、使用 HyperliquidTestSuite 运行通用测试 +// ============================================================ + +// TestHyperliquidTrader_InterfaceCompliance 测试接口兼容性 +func TestHyperliquidTrader_InterfaceCompliance(t *testing.T) { + var _ Trader = (*HyperliquidTrader)(nil) +} + +// TestHyperliquidTrader_CommonInterface 使用测试套件运行所有通用接口测试 +func TestHyperliquidTrader_CommonInterface(t *testing.T) { + // 创建测试套件 + suite := NewHyperliquidTestSuite(t) + defer suite.Cleanup() + + // 运行所有通用接口测试 + suite.RunAllTests() +} + +// ============================================================ +// 三、Hyperliquid 特定功能的单元测试 +// ============================================================ + +// TestNewHyperliquidTrader 测试创建 Hyperliquid 交易器 +func TestNewHyperliquidTrader(t *testing.T) { + tests := []struct { + name string + privateKeyHex string + walletAddr string + testnet bool + wantError bool + errorContains string + }{ + { + name: "无效私钥格式", + privateKeyHex: "invalid_key", + walletAddr: "0x1234567890123456789012345678901234567890", + testnet: true, + wantError: true, + errorContains: "解析私钥失败", + }, + { + name: "钱包地址为空", + privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + walletAddr: "", + testnet: true, + wantError: true, + errorContains: "Configuration error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trader, err := NewHyperliquidTrader(tt.privateKeyHex, tt.walletAddr, tt.testnet) + + if tt.wantError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, trader) + } else { + assert.NoError(t, err) + assert.NotNil(t, trader) + if trader != nil { + assert.Equal(t, tt.walletAddr, trader.walletAddr) + assert.NotNil(t, trader.exchange) + } + } + }) + } +} + +// TestNewHyperliquidTrader_Success 测试成功创建交易器(需要 mock HTTP) +func TestNewHyperliquidTrader_Success(t *testing.T) { + // 创建测试用私钥 + privateKey, _ := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") + agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() + + // 创建 mock HTTP 服务器 + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + reqType, _ := reqBody["type"].(string) + + var respBody interface{} + switch reqType { + case "meta": + respBody = map[string]interface{}{ + "universe": []map[string]interface{}{ + { + "name": "BTC", + "szDecimals": 4, + "maxLeverage": 50, + "onlyIsolated": false, + "isDelisted": false, + "marginTableId": 0, + }, + }, + "marginTables": []interface{}{}, + } + case "clearinghouseState": + user, _ := reqBody["user"].(string) + if user == agentAddr { + // Agent 钱包余额低 + respBody = map[string]interface{}{ + "crossMarginSummary": map[string]interface{}{ + "accountValue": "5.00", + }, + "assetPositions": []interface{}{}, + } + } else { + // 主钱包 + respBody = map[string]interface{}{ + "crossMarginSummary": map[string]interface{}{ + "accountValue": "10000.00", + }, + "assetPositions": []interface{}{}, + } + } + default: + respBody = map[string]interface{}{"status": "ok"} + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(respBody) + })) + defer mockServer.Close() + + // 注意:这个测试会真正调用 NewHyperliquidTrader,但会失败 + // 因为 hyperliquid SDK 不允许我们在构造函数中注入自定义 URL + // 所以这个测试仅用于验证参数处理逻辑 + t.Skip("跳过此测试:hyperliquid SDK 在构造时会调用真实 API,无法注入 mock URL") +} + +// ============================================================ +// 四、工具函数单元测试(Hyperliquid 特有) +// ============================================================ + +// TestConvertSymbolToHyperliquid 测试 symbol 转换函数 +func TestConvertSymbolToHyperliquid(t *testing.T) { + tests := []struct { + name string + symbol string + expected string + }{ + { + name: "BTCUSDT转换", + symbol: "BTCUSDT", + expected: "BTC", + }, + { + name: "ETHUSDT转换", + symbol: "ETHUSDT", + expected: "ETH", + }, + { + name: "无USDT后缀", + symbol: "BTC", + expected: "BTC", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertSymbolToHyperliquid(tt.symbol) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestAbsFloat 测试绝对值函数 +func TestAbsFloat(t *testing.T) { + tests := []struct { + name string + input float64 + expected float64 + }{ + { + name: "正数", + input: 10.5, + expected: 10.5, + }, + { + name: "负数", + input: -10.5, + expected: 10.5, + }, + { + name: "零", + input: 0, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := absFloat(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestHyperliquidTrader_RoundToSzDecimals 测试数量精度处理 +func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) { + trader := &HyperliquidTrader{ + meta: &hyperliquid.Meta{ + Universe: []hyperliquid.AssetInfo{ + {Name: "BTC", SzDecimals: 4}, + {Name: "ETH", SzDecimals: 3}, + }, + }, + } + + tests := []struct { + name string + coin string + quantity float64 + expected float64 + }{ + { + name: "BTC_四舍五入到4位", + coin: "BTC", + quantity: 1.23456789, + expected: 1.2346, + }, + { + name: "ETH_四舍五入到3位", + coin: "ETH", + quantity: 10.12345, + expected: 10.123, + }, + { + name: "未知币种_使用默认精度4位", + coin: "UNKNOWN", + quantity: 1.23456789, + expected: 1.2346, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := trader.roundToSzDecimals(tt.coin, tt.quantity) + assert.InDelta(t, tt.expected, result, 0.0001) + }) + } +} + +// TestHyperliquidTrader_RoundPriceToSigfigs 测试价格有效数字处理 +func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) { + trader := &HyperliquidTrader{} + + tests := []struct { + name string + price float64 + expected float64 + }{ + { + name: "BTC价格_5位有效数字", + price: 50123.456789, + expected: 50123.0, + }, + { + name: "小数价格_5位有效数字", + price: 0.0012345678, + expected: 0.0012346, + }, + { + name: "零价格", + price: 0, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := trader.roundPriceToSigfigs(tt.price) + assert.InDelta(t, tt.expected, result, tt.expected*0.001) + }) + } +} + +// TestHyperliquidTrader_GetSzDecimals 测试获取精度 +func TestHyperliquidTrader_GetSzDecimals(t *testing.T) { + tests := []struct { + name string + meta *hyperliquid.Meta + coin string + expected int + }{ + { + name: "meta为nil_返回默认精度", + meta: nil, + coin: "BTC", + expected: 4, + }, + { + name: "找到BTC_返回正确精度", + meta: &hyperliquid.Meta{ + Universe: []hyperliquid.AssetInfo{ + {Name: "BTC", SzDecimals: 5}, + }, + }, + coin: "BTC", + expected: 5, + }, + { + name: "未找到币种_返回默认精度", + meta: &hyperliquid.Meta{ + Universe: []hyperliquid.AssetInfo{ + {Name: "ETH", SzDecimals: 3}, + }, + }, + coin: "BTC", + expected: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trader := &HyperliquidTrader{meta: tt.meta} + result := trader.getSzDecimals(tt.coin) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestHyperliquidTrader_SetMarginMode 测试设置保证金模式 +func TestHyperliquidTrader_SetMarginMode(t *testing.T) { + trader := &HyperliquidTrader{ + ctx: context.Background(), + isCrossMargin: true, + } + + tests := []struct { + name string + symbol string + isCrossMargin bool + wantError bool + }{ + { + name: "设置为全仓模式", + symbol: "BTCUSDT", + isCrossMargin: true, + wantError: false, + }, + { + name: "设置为逐仓模式", + symbol: "ETHUSDT", + isCrossMargin: false, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := trader.SetMarginMode(tt.symbol, tt.isCrossMargin) + + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.isCrossMargin, trader.isCrossMargin) + } + }) + } +} + +// TestNewHyperliquidTrader_PrivateKeyProcessing 测试私钥处理 +func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) { + tests := []struct { + name string + privateKeyHex string + shouldStripOx bool + expectedLength int + }{ + { + name: "带0x前缀的私钥", + privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + shouldStripOx: true, + expectedLength: 64, + }, + { + name: "无前缀的私钥", + privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + shouldStripOx: false, + expectedLength: 64, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 测试私钥前缀处理逻辑(不实际创建 trader) + processed := tt.privateKeyHex + if len(processed) > 2 && (processed[:2] == "0x" || processed[:2] == "0X") { + processed = processed[2:] + } + + assert.Equal(t, tt.expectedLength, len(processed)) + }) + } +} diff --git a/trader/trader_test_suite.go b/trader/trader_test_suite.go new file mode 100644 index 00000000..67f2db8b --- /dev/null +++ b/trader/trader_test_suite.go @@ -0,0 +1,664 @@ +package trader + +import ( + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" +) + +// TraderTestSuite 通用的 Trader 接口测试套件(基础套件) +// 用于黑盒测试任何实现了 Trader 接口的交易器 +// +// 使用方式: +// 1. 创建具体的测试套件结构体,嵌入 TraderTestSuite +// 2. 实现 SetupMocks() 方法来配置 gomonkey mock +// 3. 调用 RunAllTests() 运行所有通用测试 +type TraderTestSuite struct { + T *testing.T + Trader Trader + Patches *gomonkey.Patches +} + +// NewTraderTestSuite 创建新的基础测试套件 +func NewTraderTestSuite(t *testing.T, trader Trader) *TraderTestSuite { + return &TraderTestSuite{ + T: t, + Trader: trader, + Patches: gomonkey.NewPatches(), + } +} + +// Cleanup 清理 mock patches +func (s *TraderTestSuite) Cleanup() { + if s.Patches != nil { + s.Patches.Reset() + } +} + +// RunAllTests 运行所有通用接口测试 +// 注意:调用此方法前,请先通过 SetupMocks 设置好所需的 mock +func (s *TraderTestSuite) RunAllTests() { + // 基础查询方法 + s.T.Run("GetBalance", func(t *testing.T) { s.TestGetBalance() }) + s.T.Run("GetPositions", func(t *testing.T) { s.TestGetPositions() }) + s.T.Run("GetMarketPrice", func(t *testing.T) { s.TestGetMarketPrice() }) + + // 配置方法 + s.T.Run("SetLeverage", func(t *testing.T) { s.TestSetLeverage() }) + s.T.Run("SetMarginMode", func(t *testing.T) { s.TestSetMarginMode() }) + s.T.Run("FormatQuantity", func(t *testing.T) { s.TestFormatQuantity() }) + + // 核心交易方法 + s.T.Run("OpenLong", func(t *testing.T) { s.TestOpenLong() }) + s.T.Run("OpenShort", func(t *testing.T) { s.TestOpenShort() }) + s.T.Run("CloseLong", func(t *testing.T) { s.TestCloseLong() }) + s.T.Run("CloseShort", func(t *testing.T) { s.TestCloseShort() }) + + // 止损止盈 + s.T.Run("SetStopLoss", func(t *testing.T) { s.TestSetStopLoss() }) + s.T.Run("SetTakeProfit", func(t *testing.T) { s.TestSetTakeProfit() }) + + // 订单管理 + s.T.Run("CancelAllOrders", func(t *testing.T) { s.TestCancelAllOrders() }) + s.T.Run("CancelStopOrders", func(t *testing.T) { s.TestCancelStopOrders() }) + s.T.Run("CancelStopLossOrders", func(t *testing.T) { s.TestCancelStopLossOrders() }) + s.T.Run("CancelTakeProfitOrders", func(t *testing.T) { s.TestCancelTakeProfitOrders() }) +} + +// TestGetBalance 测试获取账户余额 +func (s *TraderTestSuite) TestGetBalance() { + tests := []struct { + name string + wantError bool + validate func(*testing.T, map[string]interface{}) + }{ + { + name: "成功获取余额", + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + assert.Contains(t, result, "totalWalletBalance") + assert.Contains(t, result, "availableBalance") + }, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.GetBalance() + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestGetPositions 测试获取持仓 +func (s *TraderTestSuite) TestGetPositions() { + tests := []struct { + name string + wantError bool + validate func(*testing.T, []map[string]interface{}) + }{ + { + name: "成功获取持仓列表", + wantError: false, + validate: func(t *testing.T, positions []map[string]interface{}) { + assert.NotNil(t, positions) + // 持仓可以为空数组 + for _, pos := range positions { + assert.Contains(t, pos, "symbol") + assert.Contains(t, pos, "side") + assert.Contains(t, pos, "positionAmt") + } + }, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.GetPositions() + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestGetMarketPrice 测试获取市场价格 +func (s *TraderTestSuite) TestGetMarketPrice() { + tests := []struct { + name string + symbol string + wantError bool + validate func(*testing.T, float64) + }{ + { + name: "成功获取BTC价格", + symbol: "BTCUSDT", + wantError: false, + validate: func(t *testing.T, price float64) { + assert.Greater(t, price, 0.0) + }, + }, + { + name: "无效交易对返回错误", + symbol: "INVALIDUSDT", + wantError: true, + validate: nil, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + price, err := s.Trader.GetMarketPrice(tt.symbol) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, price) + } + } + }) + } +} + +// TestSetLeverage 测试设置杠杆 +func (s *TraderTestSuite) TestSetLeverage() { + tests := []struct { + name string + symbol string + leverage int + wantError bool + }{ + { + name: "设置10倍杠杆", + symbol: "BTCUSDT", + leverage: 10, + wantError: false, + }, + { + name: "设置1倍杠杆", + symbol: "ETHUSDT", + leverage: 1, + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.SetLeverage(tt.symbol, tt.leverage) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestSetMarginMode 测试设置仓位模式 +func (s *TraderTestSuite) TestSetMarginMode() { + tests := []struct { + name string + symbol string + isCrossMargin bool + wantError bool + }{ + { + name: "设置全仓模式", + symbol: "BTCUSDT", + isCrossMargin: true, + wantError: false, + }, + { + name: "设置逐仓模式", + symbol: "ETHUSDT", + isCrossMargin: false, + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.SetMarginMode(tt.symbol, tt.isCrossMargin) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestFormatQuantity 测试数量格式化 +func (s *TraderTestSuite) TestFormatQuantity() { + tests := []struct { + name string + symbol string + quantity float64 + wantError bool + validate func(*testing.T, string) + }{ + { + name: "格式化BTC数量", + symbol: "BTCUSDT", + quantity: 1.23456789, + wantError: false, + validate: func(t *testing.T, result string) { + assert.NotEmpty(t, result) + }, + }, + { + name: "格式化小数量", + symbol: "ETHUSDT", + quantity: 0.001, + wantError: false, + validate: func(t *testing.T, result string) { + assert.NotEmpty(t, result) + }, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.FormatQuantity(tt.symbol, tt.quantity) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestCancelAllOrders 测试取消所有订单 +func (s *TraderTestSuite) TestCancelAllOrders() { + tests := []struct { + name string + symbol string + wantError bool + }{ + { + name: "取消BTC所有订单", + symbol: "BTCUSDT", + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.CancelAllOrders(tt.symbol) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// ============================================================ +// 核心交易方法测试 +// ============================================================ + +// TestOpenLong 测试开多仓 +func (s *TraderTestSuite) TestOpenLong() { + tests := []struct { + name string + symbol string + quantity float64 + leverage int + wantError bool + validate func(*testing.T, map[string]interface{}) + }{ + { + name: "成功开多仓", + symbol: "BTCUSDT", + quantity: 0.01, + leverage: 10, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + assert.Contains(t, result, "symbol") + assert.Equal(t, "BTCUSDT", result["symbol"]) + }, + }, + { + name: "小数量开仓", + symbol: "ETHUSDT", + quantity: 0.004, // 增加到 0.004 以满足 Binance Futures 的 10 USDT 最小订单金额要求 (0.004 * 3000 = 12 USDT) + leverage: 5, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + }, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.OpenLong(tt.symbol, tt.quantity, tt.leverage) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestOpenShort 测试开空仓 +func (s *TraderTestSuite) TestOpenShort() { + tests := []struct { + name string + symbol string + quantity float64 + leverage int + wantError bool + validate func(*testing.T, map[string]interface{}) + }{ + { + name: "成功开空仓", + symbol: "BTCUSDT", + quantity: 0.01, + leverage: 10, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + assert.Contains(t, result, "symbol") + assert.Equal(t, "BTCUSDT", result["symbol"]) + }, + }, + { + name: "小数量开空仓", + symbol: "ETHUSDT", + quantity: 0.004, // 增加到 0.004 以满足 Binance Futures 的 10 USDT 最小订单金额要求 (0.004 * 3000 = 12 USDT) + leverage: 5, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + }, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.OpenShort(tt.symbol, tt.quantity, tt.leverage) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestCloseLong 测试平多仓 +func (s *TraderTestSuite) TestCloseLong() { + tests := []struct { + name string + symbol string + quantity float64 + wantError bool + validate func(*testing.T, map[string]interface{}) + }{ + { + name: "平指定数量", + symbol: "BTCUSDT", + quantity: 0.01, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + assert.Contains(t, result, "symbol") + }, + }, + { + name: "全部平仓_quantity为0_无持仓返回错误", + symbol: "ETHUSDT", + quantity: 0, + wantError: true, // 当没有持仓时,quantity=0 应该返回错误 + validate: nil, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.CloseLong(tt.symbol, tt.quantity) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// TestCloseShort 测试平空仓 +func (s *TraderTestSuite) TestCloseShort() { + tests := []struct { + name string + symbol string + quantity float64 + wantError bool + validate func(*testing.T, map[string]interface{}) + }{ + { + name: "平指定数量", + symbol: "BTCUSDT", + quantity: 0.01, + wantError: false, + validate: func(t *testing.T, result map[string]interface{}) { + assert.NotNil(t, result) + assert.Contains(t, result, "symbol") + }, + }, + { + name: "全部平仓_quantity为0_无持仓返回错误", + symbol: "ETHUSDT", + quantity: 0, + wantError: true, // 当没有持仓时,quantity=0 应该返回错误 + validate: nil, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + result, err := s.Trader.CloseShort(tt.symbol, tt.quantity) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, result) + } + } + }) + } +} + +// ============================================================ +// 止损止盈测试 +// ============================================================ + +// TestSetStopLoss 测试设置止损 +func (s *TraderTestSuite) TestSetStopLoss() { + tests := []struct { + name string + symbol string + positionSide string + quantity float64 + stopPrice float64 + wantError bool + }{ + { + name: "多头止损", + symbol: "BTCUSDT", + positionSide: "LONG", + quantity: 0.01, + stopPrice: 45000.0, + wantError: false, + }, + { + name: "空头止损", + symbol: "ETHUSDT", + positionSide: "SHORT", + quantity: 0.1, + stopPrice: 3200.0, + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.SetStopLoss(tt.symbol, tt.positionSide, tt.quantity, tt.stopPrice) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestSetTakeProfit 测试设置止盈 +func (s *TraderTestSuite) TestSetTakeProfit() { + tests := []struct { + name string + symbol string + positionSide string + quantity float64 + takeProfitPrice float64 + wantError bool + }{ + { + name: "多头止盈", + symbol: "BTCUSDT", + positionSide: "LONG", + quantity: 0.01, + takeProfitPrice: 55000.0, + wantError: false, + }, + { + name: "空头止盈", + symbol: "ETHUSDT", + positionSide: "SHORT", + quantity: 0.1, + takeProfitPrice: 2800.0, + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.SetTakeProfit(tt.symbol, tt.positionSide, tt.quantity, tt.takeProfitPrice) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestCancelStopOrders 测试取消止盈止损单 +func (s *TraderTestSuite) TestCancelStopOrders() { + tests := []struct { + name string + symbol string + wantError bool + }{ + { + name: "取消BTC止盈止损单", + symbol: "BTCUSDT", + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.CancelStopOrders(tt.symbol) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestCancelStopLossOrders 测试取消止损单 +func (s *TraderTestSuite) TestCancelStopLossOrders() { + tests := []struct { + name string + symbol string + wantError bool + }{ + { + name: "取消BTC止损单", + symbol: "BTCUSDT", + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.CancelStopLossOrders(tt.symbol) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestCancelTakeProfitOrders 测试取消止盈单 +func (s *TraderTestSuite) TestCancelTakeProfitOrders() { + tests := []struct { + name string + symbol string + wantError bool + }{ + { + name: "取消BTC止盈单", + symbol: "BTCUSDT", + wantError: false, + }, + } + + for _, tt := range tests { + s.T.Run(tt.name, func(t *testing.T) { + err := s.Trader.CancelTakeProfitOrders(tt.symbol) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +}