用于EagleEye3.0 规则集漏报和误报测试的示例项目,项目收集于github和gitee
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

412 lines
14 KiB

""""""
from datetime import datetime
from typing import List, Optional, Sequence, Type
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from source.common.datastruct import Exchange, Interval, BarData, TickData
from source.common.utility import get_file_path
from .database import BaseDatabaseManager, Driver
def init(driver: Driver, settings: dict):
    """
    Build a SqlManager for the requested database backend.

    The driver selects one of the ``init_*`` connection factories; the
    resulting database is handed to ``init_models`` and the two returned
    model classes are wrapped in a ``SqlManager``.

    :param driver: backend to use (SQLITE, MYSQL or POSTGRESQL).
    :param settings: connection settings forwarded to the chosen factory.
    :return: a ``SqlManager`` bound to the bar and tick models.
    :raises AssertionError: if ``driver`` has no registered factory
        (only while assertions are enabled).
    """
    factories = {
        Driver.SQLITE: init_sqlite,
        Driver.MYSQL: init_mysql,
        Driver.POSTGRESQL: init_postgresql,
    }
    assert driver in factories

    make_database = factories[driver]
    database = make_database(settings)
    bar_model, tick_model = init_models(database, driver)
    return SqlManager(bar_model, tick_model)
def init_sqlite(settings: dict):
    """
    Create a SQLite database handle from ``settings``.

    The ``"database"`` entry is passed through ``get_file_path`` and the
    result, converted to ``str``, is given to ``SqliteDatabase``.

    :param settings: mapping that must contain a ``"database"`` entry.
    :return: the ``SqliteDatabase`` instance.
    :raises KeyError: if ``settings`` has no ``"database"`` entry.
    """
    db_file = get_file_path(settings["database"])
    return SqliteDatabase(str(db_file))
def init_mysql(settings: dict):
    """
    Create a MySQL database handle from ``settings``.

    Only the connection entries (database, user, password, host, port) are
    forwarded to ``MySQLDatabase``; every other entry is ignored.

    :param settings: mapping of configuration values.
    :return: the ``MySQLDatabase`` instance.
    """
    allowed = {"database", "user", "password", "host", "port"}
    connect_kwargs = {
        name: value for name, value in settings.items() if name in allowed
    }
    return MySQLDatabase(**connect_kwargs)
def init_postgresql(settings: dict):
    """
    Create a PostgreSQL database handle from ``settings``.

    Only the connection entries (database, user, password, host, port) are
    forwarded to ``PostgresqlDatabase``; every other entry is ignored.

    :param settings: mapping of configuration values.
    :return: the ``PostgresqlDatabase`` instance.
    """
    allowed = {"database", "user", "password", "host", "port"}
    connect_kwargs = {
        name: value for name, value in settings.items() if name in allowed
    }
    return PostgresqlDatabase(**connect_kwargs)
class ModelBase(Model):
    """Common base for the storage models defined in this module."""

    def to_dict(self):
        """Return this row's field values as held in ``__data__``."""
        field_values = self.__data__
        return field_values
def init_models(db: Database, driver: Driver):
    """
    Define the bar and tick models bound to ``db`` and create their tables.

    The model classes are declared inside this function so that their
    ``Meta.database`` and their upsert strategy can close over ``db`` and
    ``driver``.

    :param db: peewee database the models are bound to; ``db.connect()``
        and ``db.create_tables(...)`` are called on it before returning.
    :param driver: backend in use; selects the upsert statement issued by
        ``save_all``.
    :return: tuple ``(DbBarData, DbTickData)`` of model classes.
    """
    # Tick attributes that are copied verbatim between TickData and
    # DbTickData (``exchange`` is handled separately because it is stored
    # as the enum's value).
    tick_fields = (
        "symbol",
        "datetime",
        "name",
        "volume",
        "open_interest",
        "last_price",
        "last_volume",
        "limit_up",
        "limit_down",
        "open_price",
        "high_price",
        "low_price",
        "pre_close",
        "bid_price_1",
        "ask_price_1",
        "bid_volume_1",
        "ask_volume_1",
    )
    # Order-book levels 2-5: bid/ask prices first, then bid/ask volumes.
    depth_fields = tuple(
        f"{side}_{kind}_{level}"
        for kind in ("price", "volume")
        for side in ("bid", "ask")
        for level in range(2, 6)
    )

    def upsert_all(model, objs, conflict_target):
        """
        Insert ``objs`` into ``model``'s table inside one transaction,
        overwriting rows that collide on the unique index.

        PostgreSQL gets one ``INSERT ... ON CONFLICT`` per row with an
        explicit conflict target; the other drivers use batched
        ``on_conflict_replace`` inserts of 50 rows.
        """
        rows = [obj.to_dict() for obj in objs]
        with db.atomic():
            if driver is Driver.POSTGRESQL:
                for row in rows:
                    model.insert(row).on_conflict(
                        update=row,
                        conflict_target=conflict_target,
                    ).execute()
            else:
                for batch in chunked(rows, 50):
                    model.insert_many(batch).on_conflict_replace().execute()

    class DbBarData(ModelBase):
        """
        Candlestick bar data for database storage.

        Rows are unique on (datetime, interval, symbol, exchange).
        """

        id = AutoField()
        symbol: str = CharField()
        exchange: str = CharField()
        datetime: datetime = DateTimeField()
        interval: str = CharField()
        volume: float = FloatField()
        open_interest: float = FloatField()
        open_price: float = FloatField()
        high_price: float = FloatField()
        low_price: float = FloatField()
        close_price: float = FloatField()

        class Meta:
            database = db
            indexes = ((("datetime", "interval", "symbol", "exchange"), True),)

        @staticmethod
        def from_bar(bar: BarData):
            """
            Generate DbBarData object from BarData.

            ``exchange`` and ``interval`` are stored as their enum values.
            """
            db_bar = DbBarData()
            db_bar.symbol = bar.symbol
            db_bar.exchange = bar.exchange.value
            db_bar.datetime = bar.datetime
            db_bar.interval = bar.interval.value
            db_bar.volume = bar.volume
            db_bar.open_interest = bar.open_interest
            db_bar.open_price = bar.open_price
            db_bar.high_price = bar.high_price
            db_bar.low_price = bar.low_price
            db_bar.close_price = bar.close_price
            return db_bar

        def to_bar(self):
            """
            Generate BarData object from DbBarData.

            The returned bar carries ``gateway_name="DB"``.
            """
            bar = BarData(
                symbol=self.symbol,
                exchange=Exchange(self.exchange),
                datetime=self.datetime,
                interval=Interval(self.interval),
                volume=self.volume,
                open_price=self.open_price,
                high_price=self.high_price,
                low_price=self.low_price,
                close_price=self.close_price,
                gateway_name="DB",
            )
            # from_bar persists open_interest, so restore it here as well;
            # otherwise a save/load round trip silently drops the value.
            bar.open_interest = self.open_interest
            return bar

        @staticmethod
        def save_all(objs: List["DbBarData"]):
            """
            Save a list of objects, update if exists.
            """
            upsert_all(
                DbBarData,
                objs,
                (
                    DbBarData.datetime,
                    DbBarData.interval,
                    DbBarData.symbol,
                    DbBarData.exchange,
                ),
            )

    class DbTickData(ModelBase):
        """
        Tick data for database storage.

        Rows are unique on (datetime, symbol, exchange). Order-book levels
        2-5 are nullable; level 1 is required.
        """

        id = AutoField()

        symbol: str = CharField()
        exchange: str = CharField()
        datetime: datetime = DateTimeField()

        name: str = CharField()
        volume: float = FloatField()
        open_interest: float = FloatField()
        last_price: float = FloatField()
        last_volume: float = FloatField()
        limit_up: float = FloatField()
        limit_down: float = FloatField()

        open_price: float = FloatField()
        high_price: float = FloatField()
        low_price: float = FloatField()
        pre_close: float = FloatField()

        bid_price_1: float = FloatField()
        bid_price_2: float = FloatField(null=True)
        bid_price_3: float = FloatField(null=True)
        bid_price_4: float = FloatField(null=True)
        bid_price_5: float = FloatField(null=True)

        ask_price_1: float = FloatField()
        ask_price_2: float = FloatField(null=True)
        ask_price_3: float = FloatField(null=True)
        ask_price_4: float = FloatField(null=True)
        ask_price_5: float = FloatField(null=True)

        bid_volume_1: float = FloatField()
        bid_volume_2: float = FloatField(null=True)
        bid_volume_3: float = FloatField(null=True)
        bid_volume_4: float = FloatField(null=True)
        bid_volume_5: float = FloatField(null=True)

        ask_volume_1: float = FloatField()
        ask_volume_2: float = FloatField(null=True)
        ask_volume_3: float = FloatField(null=True)
        ask_volume_4: float = FloatField(null=True)
        ask_volume_5: float = FloatField(null=True)

        class Meta:
            database = db
            indexes = ((("datetime", "symbol", "exchange"), True),)

        @staticmethod
        def from_tick(tick: TickData):
            """
            Generate DbTickData object from TickData.

            Levels 2-5 of the order book are copied only when
            ``tick.bid_price_2`` is truthy; otherwise they are left unset.
            """
            db_tick = DbTickData()
            db_tick.exchange = tick.exchange.value
            for name in tick_fields:
                setattr(db_tick, name, getattr(tick, name))

            if tick.bid_price_2:
                for name in depth_fields:
                    setattr(db_tick, name, getattr(tick, name))

            return db_tick

        def to_tick(self):
            """
            Generate TickData object from DbTickData.

            The returned tick carries ``gateway_name="DB"``. Levels 2-5 of
            the order book are copied only when ``self.bid_price_2`` is
            truthy.
            """
            tick = TickData(
                exchange=Exchange(self.exchange),
                gateway_name="DB",
                **{name: getattr(self, name) for name in tick_fields},
            )

            if self.bid_price_2:
                for name in depth_fields:
                    setattr(tick, name, getattr(self, name))

            return tick

        @staticmethod
        def save_all(objs: List["DbTickData"]):
            """
            Save a list of objects, update if exists.
            """
            upsert_all(
                DbTickData,
                objs,
                (
                    DbTickData.datetime,
                    DbTickData.symbol,
                    DbTickData.exchange,
                ),
            )

    db.connect()
    db.create_tables([DbBarData, DbTickData])
    return DbBarData, DbTickData
class SqlManager(BaseDatabaseManager):
    """Database manager that reads and writes bars and ticks via peewee models."""

    def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
        """Keep the bar and tick model classes used by every query."""
        self.class_tick = class_tick
        self.class_bar = class_bar

    def load_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval,
        start: datetime,
        end: datetime,
    ) -> Sequence[BarData]:
        """
        Load bars for one symbol/exchange/interval with ``start <= datetime
        <= end``, ordered by datetime ascending.
        """
        model = self.class_bar
        matches = (
            (model.symbol == symbol)
            & (model.exchange == exchange.value)
            & (model.interval == interval.value)
            & (model.datetime >= start)
            & (model.datetime <= end)
        )
        query = model.select().where(matches).order_by(model.datetime)
        return [row.to_bar() for row in query]

    def load_tick_data(
        self, symbol: str, exchange: Exchange, start: datetime, end: datetime
    ) -> Sequence[TickData]:
        """
        Load ticks for one symbol/exchange with ``start <= datetime <= end``,
        ordered by datetime ascending.
        """
        model = self.class_tick
        matches = (
            (model.symbol == symbol)
            & (model.exchange == exchange.value)
            & (model.datetime >= start)
            & (model.datetime <= end)
        )
        query = model.select().where(matches).order_by(model.datetime)
        return [row.to_tick() for row in query]

    def save_bar_data(self, datas: Sequence[BarData]):
        """Convert each bar to its storage model and save them all."""
        model = self.class_bar
        rows = [model.from_bar(bar) for bar in datas]
        model.save_all(rows)

    def save_tick_data(self, datas: Sequence[TickData]):
        """Convert each tick to its storage model and save them all."""
        model = self.class_tick
        rows = [model.from_tick(tick) for tick in datas]
        model.save_all(rows)

    def get_newest_bar_data(
        self, symbol: str, exchange: "Exchange", interval: "Interval"
    ) -> Optional["BarData"]:
        """
        Return the bar with the latest datetime for the given
        symbol/exchange/interval, or ``None`` when no row matches.
        """
        model = self.class_bar
        matches = (
            (model.symbol == symbol)
            & (model.exchange == exchange.value)
            & (model.interval == interval.value)
        )
        newest = (
            model.select()
            .where(matches)
            .order_by(model.datetime.desc())
            .first()
        )
        if not newest:
            return None
        return newest.to_bar()

    def get_newest_tick_data(
        self, symbol: str, exchange: "Exchange"
    ) -> Optional["TickData"]:
        """
        Return the tick with the latest datetime for the given
        symbol/exchange, or ``None`` when no row matches.
        """
        model = self.class_tick
        matches = (
            (model.symbol == symbol)
            & (model.exchange == exchange.value)
        )
        newest = (
            model.select()
            .where(matches)
            .order_by(model.datetime.desc())
            .first()
        )
        if not newest:
            return None
        return newest.to_tick()

    def clean(self, symbol: str):
        """
        Delete every bar row, then every tick row, whose symbol equals
        ``symbol`` (the exchange is not part of the filter).
        """
        for model in (self.class_bar, self.class_tick):
            model.delete().where(model.symbol == symbol).execute()