回测脚本

Posted by Gregorius Blog on May 28, 2025

import baostock as bs import pandas as pd import numpy as np import time import random import logging import sys import os import argparse import requests import numpy as np import pandas as pd import traceback # 必须添加此行 from datetime import datetime, timedelta from tqdm import tqdm

配置日志

logging.basicConfig(level=logging.INFO, format=’%(asctime)s - %(levelname)s - %(message)s’)

=========================

ATR(TV一致)

=========================

def atr(df, length=14): high = df[‘high’] low = df[‘low’] close = df[‘close’]

tr = pd.concat([
    high - low,
    (high - close.shift(1)).abs(),
    (low - close.shift(1)).abs()
], axis=1).max(axis=1)

return tr.ewm(alpha=1/length, adjust=False).mean()

=========================

主策略(PINE严格对齐版)

=========================

def calculate_brooks_strategy(df, daily_ema10, target_time_marker=None, is_debug=False):

df=df.copy()

df[['open','high','low','close']] = \
    df[['open','high','low','close']].astype(float)

df['body']=abs(df['close']-df['open'])

df['ema20']=df['close'].ewm(span=20,adjust=False).mean()
df['ema50']=df['close'].ewm(span=50,adjust=False).mean()

df['ma20']=df['close'].rolling(20).mean()
df['ma20_5']=df['ma20'].shift(5)

df['atr20']=atr(df,20)

df['h20']=df['high'].rolling(20).max()
df['l20']=df['low'].rolling(20).min()

state=1
bull_count=0
range_test_count=0
confirm_count=0

dyn_res=None
dyn_sup=None

last_signal=False

for i in range(len(df)):

    if i<20:
        continue

    close=df['close'].iloc[i]
    open_=df['open'].iloc[i]
    high=df['high'].iloc[i]
    low=df['low'].iloc[i]

    is_bull=close>open_

    ma20=df['ma20'].iloc[i]
    ma20_5=df['ma20_5'].iloc[i]
    atr20=df['atr20'].iloc[i]

    if pd.isna(ma20) or pd.isna(atr20):
        continue

    range_threshold=atr20*1.5

    is_in_range=(
        high<ma20+range_threshold
        and low>ma20-range_threshold
        and abs(close-ma20)<range_threshold*0.9
    )

    is_trending_up=(
        close>ma20+range_threshold
        and ma20>ma20_5
    )

    is_trending_down=(
        close<ma20-range_threshold
        and ma20<ma20_5
    )

    close_prev=df['close'].iloc[i-1]

    daily_pct=(close-close_prev)/close_prev*100

    is_daily_runup_too_high=(daily_pct>7)

    kline_change_pct=(close-open_)/open_*100

    is_strong_bar=(
        close>open_
        and abs(close-open_)/
        max(high-low,1e-6)>0.4
    )

    is_valid_growth=(
        1.01<kline_change_pct<6.99
    )

    is_breakout=(high>df['high'].iloc[i-1])

    common_entry_filters=(

        is_strong_bar
        and is_valid_growth
        and not is_daily_runup_too_high
        and is_breakout
    )

    if state==1 and close<close_prev:
        range_test_count+=1
    elif is_bull:
        range_test_count=0

    if dyn_res is None or i%50==0:

        dyn_res=df['h20'].iloc[i]
        dyn_sup=df['l20'].iloc[i]

    if is_in_range:

        dyn_res=max(dyn_res,high)
        dyn_sup=min(dyn_sup,low)

    else:

        if close>dyn_res:
            dyn_res=close

        if close<dyn_sup:
            dyn_sup=close

    body=abs(close-open_)

    lower_shadow=min(open_,close)-low
    upper_shadow=high-max(open_,close)

    is_hammer_bar=(

        is_bull
        and not is_daily_runup_too_high
        and lower_shadow>body*2
        and lower_shadow>upper_shadow
        and low<=dyn_sup*1.015
    )

    if state==2 and close<open_:
        bull_count+=1

    elif is_bull:
        bull_count=0

    consecutive_bull=(

        i>2
        and high>df['high'].iloc[i-1]
        and df['high'].iloc[i-1]>df['high'].iloc[i-2]
    )

    momentum_signal=(
        state==2
        and consecutive_bull
    )

    if is_in_range:

        state=1
        confirm_count=0

    else:

        if is_trending_up:

            state=2

        elif is_trending_down:

            state=3

        else:

            if common_entry_filters:

                confirm_count+=1

                if confirm_count>=2:
                    state=2

            else:

                confirm_count=0
                state=1

    is_near_support=(
        low<=dyn_sup*1.015
    )

    range_h1=(

        state==1
        and range_test_count==1
        and (is_hammer_bar or common_entry_filters)
        and is_near_support
    )

    range_h2=(

        state==1
        and range_test_count>=2
        and (is_hammer_bar or common_entry_filters)
        and is_near_support
    )

    is_range_entry=(
        range_h1 or range_h2
    )

    initial_breakout_entry=(

        state==2
        and not is_daily_runup_too_high
        and kline_change_pct<9.95
        and close>=dyn_res
    )

    is_daily_trend_bull=(
        close>daily_ema10
    )

    body_mean=df['body'].rolling(5).mean().iloc[i]

    is_strong_bull=(

        is_bull
        and body>body_mean*1.5
    )

    go_long=(

        state!=3
        and is_daily_trend_bull
        and (df['ema20'].iloc[i] >
             df['ema50'].iloc[i])

        and (close >
             df['ema20'].iloc[i])

        and (

            is_hammer_bar
            or momentum_signal
            or initial_breakout_entry
            or is_range_entry
        )
    )

    last_signal=go_long

return last_signal

— 工具:消息推送 (带去重逻辑) —

def send_to_phone(title, content, code, date_str): “”” 带详细日志监控的推送函数 “”” file_path = “last_sent.txt” current_hour_id = f”{code}_{datetime.now().strftime(‘%Y%m%d%H’)}”

# 1. 日志:检查去重逻辑
if os.path.exists(file_path):
    with open(file_path, "r") as f:
        if current_hour_id in f.read():
            tqdm.write(f"⏩ [推送日志] {code} 该小时已推送过,本次跳过。")
            return

# --- 【新增】支持多人推送 ---
# 将你的 token 和朋友的 token 放进列表
tokens = [
    '7e7cf8f4208a41f88876783ef54e01a0', # 你的 Token
    '536392585aab4c7d8dfad339cda90fea'            # 朋友的 Token
]
 
base_url = "http://www.pushplus.plus/send"
full_content = f"日期: {date_str}<br>股票: {content}"

for token in tokens:
    # 将参数放入字典,requests 会自动进行 URL 编码
    payload = {
        "token": token,
        "title": title,
        "content": full_content,
        "template": "html"
    }
    
    tqdm.write(f"🚀 [推送日志] 正在请求推送接口... (目标代码: {code})")
    
    try:
        # 使用 params 传入参数,不再手动拼接 URL
        response = requests.get(base_url, params=payload, timeout=10)
        
        # 打印返回的文本,看看具体的错误信息
        if response.status_code == 200:
            tqdm.write(f"✅ [推送日志] 接口请求成功")
            # (去重记录逻辑...)
        else:
            tqdm.write(f"⚠️ [推送日志] 接口返回状态码: {response.status_code}")
            tqdm.write(f"⚠️ [推送日志] 服务器响应内容: {response.text}") # 查看服务器详细报错
            
    except Exception as e:
        tqdm.write(f"🚨 [推送日志] 推送发生未知异常: {str(e)}")

— 增加一个工具函数来提取 BaoStock 数据 —

def fetch_baostock_data(rs): data_list = [] while (rs.error_code == ‘0’) & rs.next(): data_list.append(rs.get_row_data()) return pd.DataFrame(data_list, columns=rs.fields)

— 环境检测与交互逻辑 —

def get_config(): parser = argparse.ArgumentParser() parser.add_argument(‘–manual’, action=’store_true’, help=”强制开启手动输入模式”) parser.add_argument(“–mute”, action=”store_true”, help=”关闭推送功能”) parser.add_argument(“–debug”, action=’store_true’, help=”开启详细调试日志”) # 新增开关 args, unknown = parser.parse_known_args()

target_code, target_date, target_time_marker = None, None, None

if args.manual:
    print("-" * 30)
    target_code = input("请输入股票代码 (回车全市场扫描): ").strip()
    time_input = input("请输入回测时间 (格式: 2026-05-26 10:30, 回车跳过): ").strip()
    
    if time_input:
        parts = time_input.split()
        target_date = parts[0]
        if len(parts) > 1:
            target_time_marker = parts[1].replace(":", "")
else:
    print(">>> [自动模式] 未检测到 --manual 参数,使用最新数据自动扫描。")

return target_code, target_date, target_time_marker, args.mute, args.debug

def get_stock_list_with_fallback(date_str): “”” 自动回溯获取股票列表:如果当天查不到,自动尝试前一天 “”” current_date = datetime.strptime(date_str, “%Y-%m-%d”)

# 最多尝试回溯 5 天 (防止因为长期停牌导致无限循环)
for i in range(5):
    query_date = (current_date - timedelta(days=i)).strftime("%Y-%m-%d")
    print(f">>> 正在尝试查询日期: {query_date} 的股票列表...")
    
    rs = bs.query_all_stock(day=query_date)
    stocks = fetch_baostock_data(rs)
    
    if not stocks.empty:
        print(f">>> 成功获取到 {len(stocks)} 只股票数据 (日期: {query_date})")
        return stocks
    else:
        print(f">>> {query_date} 无数据,继续回溯...")
        
return pd.DataFrame() # 如果回溯 5 天仍无数据,返回空

定义一个辅助函数,方便到处调用

def log_debug(message): if is_debug: tqdm.write(f”🔍 [DEBUG] {message}”)

— 主程序 —

def main(): if not bs.login().error_code == ‘0’: print(“登录失败,请检查网络。”) return

target_code_in, target_date, target_time_marker, is_muted, is_debug= get_config()
is_muted = True

# 2. 处理股票代码 (兼容6位,补全为9位)
target_code = None
if target_code_in:
    # 去掉可能存在的空格
    code_str = target_code_in.strip()
    
    # 判断前缀:60, 68, 58 开头加 sh.,其他加 sz.
    if code_str.startswith(('60', '68', '58')):
        target_code = f"sh.{code_str}"
    else:
        target_code = f"sz.{code_str}"
    
    # 强制校验长度
    if len(target_code) != 9:
        print(f"错误:转换后的代码 {target_code} 长度不为 9,请检查输入!")
        return
   
# 时间基准
today = datetime.now()
today_str = target_date if target_date else today.strftime("%Y-%m-%d")
today_dt = datetime.strptime(today_str, "%Y-%m-%d")

index_start = (today_dt - timedelta(days=60)).strftime("%Y-%m-%d")
stock_start = (today_dt - timedelta(days=30)).strftime("%Y-%m-%d")

# 4. 获取股票列表 if target_code: # — 动态获取股票名称 — # 这里的 code_str 就是你输入的纯数字代码,例如 ‘603823’ rs_basic = bs.query_stock_basic(code=target_code) basic_df = fetch_baostock_data(rs_basic)

    if not basic_df.empty:
        code_name = basic_df.iloc[0]['code_name']
    else:
        code_name = "未知股票"
        
    stocks = pd.DataFrame({'code': [target_code], 'code_name': [code_name]})
    print(f">>> 已定位目标: {code_name} ({target_code})")
else:
    print(">>> 正在加载全市场股票...")
    stocks = get_stock_list_with_fallback(today_str)
    
    if stocks.empty:
        print("🚨 致命错误:连续 5 天均无法获取股票列表,程序退出。")
        bs.logout()
        return
    
    # --- 增加这一行,看看原始下载到了多少只 ---
    print(f">>> 原始下载到 {len(stocks)} 只股票")
    
    # --- 过滤逻辑 ---
    # 1. 除去科创板(688开头)
    stocks = stocks[~stocks['code'].str.contains('688')]
    
    # 2. 除去 ETF (以 51 或 15 开头)
    # 注意:sh.51xxxx, sz.15xxxx 是常见的 ETF 代码
    # 更加稳健的写法:判断代码是否以特定前缀开头
    # 这样完全避开了正则表达式,不会有任何警告
    is_etf = stocks['code'].str.startswith(('sh.51', 'sh.52', 'sh.53', 'sh.55', 'sh.56', 'sh.58', 'sz.15'))
    stocks = stocks[~is_etf]
    
    # 3. 除去所有包含 'ST' 的股票名称
    stocks = stocks[~stocks['code_name'].str.contains('ST')]
    
    # 4. 仅保留主板和创业板 (sh/sz)
    stocks = stocks[stocks['code'].str.contains('sh\.|sz\.')]
    # --- 增加这一行,看看过滤后剩多少 ---
    print(f">>> 过滤后剩余 {len(stocks)} 只")

# --- 扫描前环境信息 ---
print("-" * 50)
print(f"分析日期: {today_str}")
print(f"扫描模式: {'单只' if target_code else '全市场'} ({len(stocks)} 只)")
print("-" * 50)

# 初始化容器
signal_list = []

# 5. 执行分析
for _, row in tqdm(stocks.iterrows(), total=len(stocks), desc="Brooks 扫描中"):
    try:
        try:
            rs_hourly = bs.query_history_k_data_plus(
                row['code'],
                "date,time,open,high,low,close",
                start_date=stock_start,
                end_date=today_str,
                frequency="60",
                adjustflag="3"
            )

        except Exception:

            tqdm.write("BaoStock断连,自动重连...")

            bs.logout()
            time.sleep(2)

            bs.login()

            rs_hourly = bs.query_history_k_data_plus(
                row['code'],
                "date,time,open,high,low,close",
                start_date=stock_start,
                end_date=today_str,
                frequency="60",
                adjustflag="3"
            )
        df_hourly = fetch_baostock_data(rs_hourly)
        
        if len(df_hourly) < 50: continue
        
       # --- 调整后的逻辑 ---
        if target_time_marker:
            # 过滤出符合条件的行
            target_df = df_hourly[df_hourly['time'].str.contains(target_time_marker)]
            # 取最后一个符合条件的时间点 (即该小时的 K 线)
            target_bar = target_df.iloc[-1]
        else:
            # 默认扫描最新行情
            target_bar = df_hourly.iloc[-1]
        
        # 从目标行提取信息
        last_time_str = target_bar['time']
        formatted_time = f"{last_time_str[8:10]}:{last_time_str[10:12]}"
        last_price = target_bar['close']
        
        # 打印信息
        status_info = f"[{row.get('code_name', 'N/A')} {row['code']}] 时间: {formatted_time} 价格: {last_price}"
        tqdm.write(f"正在扫描: {status_info}")
        
        # --- 1. 获取并处理日线数据以计算基准 ---
        rs_daily = bs.query_history_k_data_plus(row['code'], "date,close", 
                                               start_date=(today_dt - timedelta(days=60)).strftime("%Y-%m-%d"), 
                                               end_date=today_str, frequency="d")
        df_daily = fetch_baostock_data(rs_daily)
        df_daily['close'] = df_daily['close'].astype(float)

        # 计算 EMA10
        individual_daily_ema10 = df_daily['close'].ewm(span=10, adjust=False).mean().iloc[-1]

        # 获取“昨日收盘价” (取最后一行前的一个交易日)
        # df_daily 已经是按日期排序的,最后一行是今天,倒数第二行是昨天
        prev_day_close = df_daily['close'].iloc[-2] 

        # --- 2. 将日线基准数据带入小时线 (df_hourly) ---
        # 计算小时线数据的每日涨幅
        df_hourly['close'] = df_hourly['close'].astype(float)
        # 这里直接使用昨日收盘价计算今日所有小时线的涨幅
        df_hourly['daily_pct'] = (df_hourly['close'] - prev_day_close) / prev_day_close * 100
        # 必须先从外部读取当日的开盘前数据作为 daily_pct 的计算基础。
        # 判断策略
        if calculate_brooks_strategy(df_hourly, individual_daily_ema10, target_time_marker, is_debug):
            # 将信号存入列表
            signal_info = {
                'name': row.get('code_name', 'N/A'),
                'code': row['code'],
                'time': formatted_time,
                'price': last_price
            }
            
            signal_list.append(signal_info)
            green_start = "\033[92m"
            reset = "\033[0m"
            tqdm.write(f"{green_start}✅ 发现信号: {row.get('code_name', 'N/A')} ({row['code']}){reset}")
            
            # --- 【新增】发送到手机 ---
            msg_body = f"{signal_info['name']} ({signal_info['code']})<br>触发时间: {formatted_time}<br>当前价格: {last_price}"
            
            # 判断推送开关
            if not is_muted:
                # 调用推送,传入日期
                send_to_phone("Brooks 策略信号提醒", msg_body, row['code'], today_str)
  
            if not is_muted:
                # 【新增】强制等待,确保推送完成
                tqdm.write("⏳ 等待推送响应确认...")
            time.sleep(2)
        
        time.sleep(0.05) 
    except Exception as e:
        # 捕获详细堆栈信息
        error_msg = traceback.format_exc()
        tqdm.write(f"❌ 处理代码 {row['code']} 时出错:")
        tqdm.write(error_msg) # 打印完整的错误轨迹
        continue
        
bs.logout()
# --- 扫描结束后的汇总列表 ---
print("\n" + "="*50)
print(f"扫描完成,共发现 {len(signal_list)} 只符合条件的股票:")
print(f"{'名称':<10} | {'代码':<12} | {'最新时间':<10} | {'价格':<8}")
print("-" * 50)

for sig in signal_list:
    print(f"{sig['name']:<10} | {sig['code']:<12} | {sig['time']:<10} | {sig['price']:<8}")

print("="*50)

if name == “main”: main()