import numpy as np
import pandas as pd
import akshare as ak
import matplotlib.pyplot as plt
from talib import abstract
from datetime import datetime, timedelta
import matplotlib.dates as mdates
from collections import defaultdict
# 設置matplotlib支持中文顯示
plt.rcParams['font.sans-serif'] = ['SimHei'] # 指定默認字體
plt.rcParams['axes.unicode_minus'] = False # 解決保存圖像是負號'-'顯示為方塊的問題
class TripleBottomStrategy:
"""
三重W底量化選股策略框架
功能:
1. 動態波動率調整突破閾值
2. MACD/RSI二次驗證
3. 均線系統多因子過濾
4. 形態完成后的漲幅統計
5. 模式失敗概率分析
6. 可視化形態識別結果
"""
def __init__(self, stock_code, start_date, end_date):
"""初始化策略實例"""
self.stock_code = stock_code
self.data = self._load_data(start_date, end_date)
self._preprocess_data()
self.patterns = None # 存儲檢測到的形態
self.performance_results = None # 存儲形態表現結果
def _load_data(self, start_date, end_date):
"""從akshare加載股票數據"""
df = ak.stock_zh_a_daily(symbol=self.stock_code, start_date=start_date, end_date=end_date)
df.index = pd.to_datetime(df['date'])
return df[['open', 'high', 'low', 'close', 'volume']]
def _preprocess_data(self):
"""數據預處理和技術指標計算"""
# 計算均線
self.data['ma20'] = self.data['close'].rolling(20).mean()
self.data['ma60'] = self.data['close'].rolling(60).mean()
# 計算ATR波動率
self.data['atr'] = abstract.ATR(self.data['high'], self.data['low'],
self.data['close'], timeperiod=14)
# 計算MACD指標
self.data['macd'], self.data['macd_signal'], self.data['macd_hist'] = abstract.MACD(
self.data['close'], fastperiod=12, slowperiod=26, signalperiod=9)
# 計算RSI指標
self.data['rsi'] = abstract.RSI(self.data['close'], timeperiod=14)
# 計算未來N天的漲幅,用于回測和分析
for n in [5, 10, 20, 30]:
self.data[f'future_return_{n}d'] = self.data['close'].pct_change(n).shift(-n)
def _find_extrema(self, window=5):
"""使用滑動窗口尋找局部極值點"""
# 局部最大值
peaks = (self.data['low'] == self.data['low'].rolling(window, center=True).min())
# 局部最小值
troughs = (self.data['low'] == self.data['low'].rolling(window, center=True).min())
# 合并并排序所有極值點
extrema = pd.Series(index=self.data.index)
extrema[peaks] = 1 # 1表示波峰
extrema[troughs] = -1 # -1表示波谷
extrema = extrema.dropna()
return extrema
def _dynamic_breakout_ratio(self, date_index):
"""基于波動率動態調整突破閾值"""
# 確保有足夠的數據計算ATR
if date_index < 30:
return 0.03
recent_atr = self.data['atr'].iloc[date_index-20:date_index].mean()
overall_atr = self.data['atr'].iloc[:date_index].mean()
if overall_atr == 0:
return 0.03
base_ratio = 0.03 # 基礎突破比例
# 根據近期波動率與歷史波動率的比率調整突破閾值
dynamic_ratio = base_ratio * (recent_atr / overall_atr)
# 限制在2%-5%之間,防止極端情況
return np.clip(dynamic_ratio, 0.02, 0.05)
def _valIDAte_with_indicators(self, date_index):
"""使用MACD和RSI指標進行二次驗證"""
if date_index < 60:
return False
current_data = self.data.iloc[date_index]
# MACD金叉或在零軸附近且動能增強
macd_ok = (current_data['macd'] > current_data['macd_signal']) and \
(current_data['macd_hist'] > self.data['macd_hist'].iloc[date_index-1])
# RSI不在超買區,最好在50以下
rsi_ok = current_data['rsi'] < 60
# 均線多頭排列或即將形成
ma_ok = current_data['ma20'] > current_data['ma60']
return macd_ok and rsi_ok and ma_ok
def detect_pattern(self, min_distance=21, price_tolerance=0.05):
"""
檢測三重底形態
參數:
min_distance: 波谷之間的最小交易日距離
price_tolerance: 底部價格的容忍度,允許的最大價格差異比例
"""
extrema = self._find_extrema()
troughs = extrema[extrema == -1]
patterns = []
# 遍歷所有可能的三重底組合
for i in range(len(troughs) - 2):
t1_date, t2_date, t3_date = troughs.index[i], troughs.index[i+1], troughs.index[i+2]
# 檢查波谷之間的時間距離
days_between_t1_t2 = (t2_date - t1_date).days
days_between_t2_t3 = (t3_date - t2_date).days
if days_between_t1_t2 < min_distance or days_between_t2_t3 < min_distance:
continue
# 獲取三個底部的價格
t1_price = self.data.loc[t1_date, 'low']
t2_price = self.data.loc[t2_date, 'low']
t3_price = self.data.loc[t3_date, 'low']
# 檢查三個底部價格是否相近
if not (abs(t1_price - t2_price) / t1_price < price_tolerance and
abs(t2_price - t3_price) / t2_price < price_tolerance and
abs(t1_price - t3_price) / t1_price < price_tolerance):
continue
# 找到兩個底部之間的波峰
peak1_idx = self.data.index.get_loc(t1_date)
peak2_idx = self.data.index.get_loc(t2_date)
peak_between = self.data['high'].iloc[peak1_idx:peak2_idx].idxmax()
# 確認W底形態(中間高,兩邊低)
if not (self.data.loc[peak_between, 'high'] > max(t1_price, t2_price)):
continue
# 找到頸線位(兩個波峰的連線)
neckline_price = max(self.data.loc[peak_between, 'high'],
self.data['high'].iloc[peak2_idx:self.data.index.get_loc(t3_date)].idxmax())
# 獲取形態完成日期的索引
completion_idx = self.data.index.get_loc(t3_date)
# 計算動態突破閾值
breakout_ratio = self._dynamic_breakout_ratio(completion_idx)
breakout_price = neckline_price * (1 + breakout_ratio)
# 檢查是否突破
# 我們檢查突破后的5個交易日內是否收盤價高于突破價
future_data = self.data.iloc[completion_idx:completion_idx+6]
breakout_date = None
for date, row in future_data.iterrows():
if row['close'] > breakout_price:
breakout_date = date
break
if breakout_date is None:
continue
# 指標二次驗證
if not self._validate_with_indicators(self.data.index.get_loc(breakout_date)):
continue
# 記錄形態信息
patterns.append({
't1_date': t1_date,
't2_date': t2_date,
't3_date': t3_date,
'neckline_price': neckline_price,
'breakout_price': breakout_price,
'breakout_date': breakout_date,
'breakout_ratio_used': breakout_ratio,
'success': None, # 成功與否將在回測中確定
'future_returns': {}
})
self.patterns = patterns
return patterns
def analyze_pattern_performance(self):
"""分析形態完成后的漲幅和失敗概率"""
if not self.patterns:
print("未檢測到任何三重底形態,無法進行性能分析。")
return
results = []
for pattern in self.patterns:
breakout_idx = self.data.index.get_loc(pattern['breakout_date'])
# 計算不同周期的未來收益
for n in [5, 10, 20, 30]:
if breakout_idx + n < len(self.data):
pattern['future_returns'][f'{n}d'] = self.data.iloc[breakout_idx + n]['close'] / \
self.data.loc[pattern['breakout_date'], 'close'] - 1
else:
pattern['future_returns'][f'{n}d'] = np.nan
results.append({
'breakout_date': pattern['breakout_date'],
'return_5d': pattern['future_returns'].get('5d'),
'return_10d': pattern['future_returns'].get('10d'),
'return_20d': pattern['future_returns'].get('20d'),
'return_30d': pattern['future_returns'].get('30d')
})
self.performance_results = pd.DataFrame(results).set_index('breakout_date')
# 計算失敗概率
success_threshold = 0.02 # 漲幅超過2%視為成功
success_counts = defaultdict(int)
total_counts = defaultdict(int)
for pattern in self.patterns:
for period, ret in pattern['future_returns'].items():
if pd.notna(ret):
total_counts[period] += 1
if ret >= success_threshold:
success_counts[period] += 1
print("\n--- 形態表現分析 ---")
print(f"共檢測到 {len(self.patterns)} 個有效三重底形態")
print("\n未來漲幅統計:")
print(self.performance_results.describe())
print("\n模式失敗概率分析:")
for period in ['5d', '10d', '20d', '30d']:
if total_counts[period] > 0:
success_rate = success_counts[period] / total_counts[period]
failure_rate = 1 - success_rate
print(f"{period} 失敗概率: {failure_rate:.2%} (成功: {success_counts[period]}, 總樣本: {total_counts[period]})")
else:
print(f"{period} 沒有足夠的數據進行分析")
def visualize(self):
"""可視化形態識別結果"""
if not self.patterns:
print("未檢測到任何三重底形態,無法進行可視化。")
return
fig, ax = plt.subplots(figsize=(16, 10))
# 繪制價格和均線
ax.plot(self.data.index, self.data['close'], label='收盤價', linewidth=2)
ax.plot(self.data.index, self.data['ma20'], label='20日均線', color='orange', alpha=0.7)
ax.plot(self.data.index, self.data['ma60'], label='60日均線', color='purple', alpha=0.7)
# 標記形態關鍵點
for pattern in self.patterns:
# 標記三個底部
ax.scatter([pattern['t1_date'], pattern['t2_date'], pattern['t3_date']],
[self.data.loc[pattern['t1_date'], 'low'],
self.data.loc[pattern['t2_date'], 'low'],
self.data.loc[pattern['t3_date'], 'low']],
color='green', s=100, marker='^', label='三重底底部')
# 標記突破點
ax.scatter(pattern['breakout_date'], self.data.loc[pattern['breakout_date'], 'close'],
color='red', s=120, marker='*', label='突破點')
# 繪制頸線
# 找到第一個波峰
peak1_idx = self.data.index.get_loc(pattern['t1_date'])
peak2_idx = self.data.index.get_loc(pattern['t2_date'])
peak_between_date = self.data['high'].iloc[peak1_idx:peak2_idx].idxmax()
# 找到第二個波峰
peak3_idx = self.data.index.get_loc(pattern['t3_date'])
peak_after_idx = self.data.index.get_loc(pattern['breakout_date'])
peak_after_date = self.data['high'].iloc[peak2_idx:peak3_idx].idxmax()
# 繪制頸線
neckline_points = [peak_between_date, peak_after_date, pattern['breakout_date']]
ax.plot(neckline_points, [pattern['neckline_price']]*3, 'r--', label='頸線')
# 添加標注
ax.annotate(f"突破: {pattern['breakout_price']:.2f}",
xy=(pattern['breakout_date'], pattern['breakout_price']),
xytext=(pattern['breakout_date'], pattern['breakout_price'] * 1.05),
arrowprops=dict(facecolor='black', shrink=0.05))
ax.set_title(f"{self.stock_code} 三重底形態識別結果")
ax.set_xlabel("日期")
ax.set_ylabel("價格")
ax.grid(True)
ax.legend()
# 優化x軸日期顯示
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
plt.gcf().autofmt_xdate()
plt.show()