Backtrader量化&回测3——读取SQL数据(以SQL lite数据库为例)
  TEZNKK3IfmPf 2023年11月14日 27 0

SQL读取类

官方暂时没有一键读取的功能类,因此需要自己写,一个简单的例子如下,运用请参考最下面的示例代码:

class SQLiteData(DataBase):
    """自定义的SQL lite数据格式"""
    params = (
        ('dataname', None),  # 策略中读取数据库是用到的名称
        ('name', ''),  # 绘图时用到的名称
        ('timeframe', TimeFrame.Days),  # 每条K线代表的时间长短
        ('fromdate', None),  # 从什么时候开始
        ('todate', None),  # 到什么时候截止
    )

    def __init__(self):
        self.engine = create_engine('sqlite:///local_sql_lite.db')  # 初始化数据库连接
        self.tabel_name = "my_stock_code"  # 数据表名称

    def start(self):  # 只会在加载数据前执行一次,常用于初始化参数
        self.conn = self.engine.connect()
        sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
            .format(self.tabel_name)
        self.result = self.conn.execute(sql_query)

    def stop(self):  # 结束数据加载程序之后执行一次,常用于关闭数据库链接
        self.engine.dispose()

    def _load(self):  # 类似于策略的 next(),预期执行几次 next(),就会执行几次 _load()
        one_row = self.result.fetchone()
        if one_row is None:
            return False
        self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S'))  # date parsing
        self.lines.open[0] = float(one_row[1])
        self.lines.high[0] = float(one_row[2])
        self.lines.low[0] = float(one_row[3])
        self.lines.close[0] = float(one_row[4])
        self.lines.volume[0] = int(one_row[5])
        self.lines.turnover[0] = float(one_row[6])
        self.lines.openinterest[0] = -1
        return True

其中有几个比较重要的函数:

  • def start():在加载数据前执行一次,常用于初始化参数
  • def stop():结束数据加载程序之后执行一次,常用于关闭数据库链接
  • def _load():在策略中的next()拿到的数据其实就是这里传过去的数据,会循环执行多次这个函数,直到取得全部数据或接收到FalseNone的返回值
  • params:这是定义数据集本身的一些参数,重要参数已在示例代码中解释,更多参数请参考官网

注意:def _load() 函数中:

  • self.lines:表示第一个数据集的列,等同于 self.datas[0].lines,是一种简写形式
  • self.lines.open:指代数据中的开盘价那一列
  • self.lines.open[0]:特指当天的开盘价,如果是self.lines.open[-1],就是昨天的,如果是self.lines.open[1] 就是明天的

示例代码

import backtrader
import efinance
import pandas as pd
from datetime import datetime
import sqlite3
import datetime as dt
from backtrader import TimeFrame
from backtrader.feed import DataBase
from backtrader import date2num
from sqlalchemy import create_engine


def get_k_data(stock_code, begin: datetime, end: datetime) -> pd.DataFrame:
    """根据efinance工具包获取股票数据 :param stock_code:股票代码 :param begin: 开始日期 :param end: 结束日期 """
    # stock_code = '600519' # 股票代码,茅台
    k_dataframe: pd.DataFrame = efinance.stock.get_quote_history(
        stock_code, beg=begin.strftime("%Y%m%d"), end=end.strftime("%Y%m%d"))
    k_dataframe = k_dataframe.iloc[:, :9]
    k_dataframe.columns = ['name', 'code', 'date', 'open', 'close', 'high', 'low', 'volume', 'turnover']
    k_dataframe.index = pd.to_datetime(k_dataframe.date)
    k_dataframe.drop(['name', 'code', "date"], axis=1, inplace=True)
    return k_dataframe


def write_sql_lite_from_pandas(stock_code, begin: datetime, end: datetime):
    """获取K线数据,并保存到SQL lite数据库"""
    conn = sqlite3.connect('local_sql_lite.db')
    dataframe = get_k_data(stock_code, begin=begin, end=end)
    dataframe.to_sql("my_stock_code", conn, if_exists="replace")  # 保存数据到数据库中,表名称:my_stock_code


class SQLiteData(DataBase):
    """自定义的SQL lite数据格式"""
    params = (
        ('dataname', None),  # 策略中读取数据库是用到的名称
        ('name', ''),  # 绘图时用到的名称
        ('timeframe', TimeFrame.Days),  # 每条K线代表的时间长短
        ('fromdate', None),  # 从什么时候开始
        ('todate', None),  # 到什么时候截止
        # 下面是除了默认的open,close,high,low,volume外,新添加的维度
        ('turnover', -1),
    )

    # 新添加数据列用法相同
    lines = ('turnover',)

    def __init__(self):
        self.engine = create_engine('sqlite:///local_sql_lite.db')
        self._timeframe = self.p.timeframe
        self._compression = self.p.compression
        self._dataname = "my_stock_code"

    def start(self):
        self.conn = self.engine.connect()
        sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
            .format(self._dataname)
        self.result = self.conn.execute(sql_query)

    def stop(self):
        self.engine.dispose()

    def _load(self):
        # 会全部循环完毕,然后再读取
        one_row = self.result.fetchone()
        if one_row is None:
            return False
        self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S'))  # date parsing
        self.lines.open[0] = float(one_row[1])
        self.lines.high[0] = float(one_row[2])
        self.lines.low[0] = float(one_row[3])
        self.lines.close[0] = float(one_row[4])
        self.lines.volume[0] = int(one_row[5])
        self.lines.turnover[0] = float(one_row[6])
        self.lines.openinterest[0] = -1
        return True


class MyStrategy1(backtrader.Strategy):  # 策略
    def __init__(self):
        # 初始化交易指令、买卖价格和手续费
        self.close_price = self.datas[0].close  # 这里加一个数据引用,方便后续操作
        this_data = self.getdatabyname("stock_600519")  # 获取传入的 name = stock_600519 的数据
        print("全部列名:", this_data.getlinealiases())  # 全部的列名称

    def next(self):  # 框架执行过程中会不断循环next(),过一个K线,执行一次next()
        print('=======================')
        print("今天是:", self.datetime.date())
        print("当前的值:", dict(zip(self.datas[0].getlinealiases(), [i[0] for i in list(self.datas[0].lines)])))


def main():
    # 获取数据
    start_time = datetime(2015, 1, 1)
    end_time = datetime(2015, 1, 10)
    # 先保存数据到 SQL lite 数据库,用于后续的读取
    write_sql_lite_from_pandas("600519", start_time, end_time)
    # 从SQL lite数据库读取数据
    data = SQLiteData()
    # =============== 为系统注入数据 =================
    # 加载数据
    # data = PandasDataPlus(dataname=dataframe, fromdate=start_time, todate=end_time)
    # 初始化cerebro回测系统
    cerebral_system = backtrader.Cerebro()  # Cerebro引擎在后台创建了broker(经纪人)实例,系统默认每个broker的初始资金量为10000
    # 将数据传入回测系统
    cerebral_system.adddata(data, name="stock_600519")  # 导入数据,在策略中使用 self.datas 来获取数据源
    # 将交易策略加载到回测系统中
    cerebral_system.addstrategy(MyStrategy1)
    # =============== 系统设置 ==================
    # 运行回测系统
    cerebral_system.run()


if __name__ == '__main__':
    main()
【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月14日 0

暂无评论

推荐阅读
  TEZNKK3IfmPf   2024年05月31日   26   0   0 sqlite数据库
  TEZNKK3IfmPf   2024年05月31日   31   0   0 数据库mysql
  TEZNKK3IfmPf   2024年05月17日   38   0   0 sqlcube
  TEZNKK3IfmPf   2024年05月31日   27   0   0 数据库mysql
TEZNKK3IfmPf