You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
11 KiB
354 lines
11 KiB
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Sequence, Optional
|
|
|
|
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
|
|
|
from ..common.datastruct import Exchange, Interval, BarData, TickData
|
|
from .database import BaseDatabaseManager, Driver
|
|
|
|
|
|
def init(_: Driver, settings: dict):
    """Configure the mongoengine connection and hand back a MongoManager.

    Args:
        _: Driver selector; not used by this function.
        settings: Mapping that must provide the keys "database", "host",
            "port", "user", "password" and "authentication_source"; a missing
            key raises KeyError.

    Returns:
        A freshly constructed MongoManager.
    """
    keys = ("database", "host", "port", "user", "password", "authentication_source")
    database, host, port, username, password, authentication_source = (
        settings[key] for key in keys
    )

    credentials = {
        "username": username,
        "password": password,
        "authentication_source": authentication_source,
    }
    if not username:
        # An empty/None user name means no credentials at all: every
        # credential argument is passed to connect() as None.
        credentials = dict.fromkeys(credentials)

    connect(db=database, host=host, port=port, **credentials)
    return MongoManager()
|
|
|
|
|
|
class DbBarData(Document):
    """
    Candlestick bar data for database storage.

    A unique compound index is defined on
    (datetime, interval, symbol, exchange). The ``exchange`` and ``interval``
    fields hold the ``.value`` strings of the corresponding enums.
    """

    symbol: str = StringField()
    exchange: str = StringField()
    datetime: datetime = DateTimeField()
    interval: str = StringField()

    volume: float = FloatField()
    open_interest: float = FloatField()
    open_price: float = FloatField()
    high_price: float = FloatField()
    low_price: float = FloatField()
    close_price: float = FloatField()

    meta = {
        "indexes": [
            {
                "fields": ("datetime", "interval", "symbol", "exchange"),
                "unique": True,
            }
        ]
    }

    @staticmethod
    def from_bar(bar: BarData):
        """
        Generate a DbBarData document from a BarData object.

        ``bar.exchange`` and ``bar.interval`` are stored via their ``.value``.
        The returned document is not saved by this method.
        """
        db_bar = DbBarData()

        db_bar.symbol = bar.symbol
        db_bar.exchange = bar.exchange.value
        db_bar.datetime = bar.datetime
        db_bar.interval = bar.interval.value
        db_bar.volume = bar.volume
        db_bar.open_interest = bar.open_interest
        db_bar.open_price = bar.open_price
        db_bar.high_price = bar.high_price
        db_bar.low_price = bar.low_price
        db_bar.close_price = bar.close_price

        return db_bar

    def to_bar(self):
        """
        Generate a BarData object from this DbBarData document.

        The Exchange and Interval enums are rebuilt from the stored strings;
        ``gateway_name`` is always set to "DB".
        """
        bar = BarData(
            symbol=self.symbol,
            exchange=Exchange(self.exchange),
            datetime=self.datetime,
            interval=Interval(self.interval),
            volume=self.volume,
            open_interest=self.open_interest,
            open_price=self.open_price,
            high_price=self.high_price,
            low_price=self.low_price,
            close_price=self.close_price,
            gateway_name="DB",
        )
        return bar
|
|
|
|
|
|
class DbTickData(Document):
    """
    Tick data for database storage.

    A unique compound index is defined on (datetime, symbol, exchange).
    The ``exchange`` field holds the ``.value`` string of the Exchange enum.
    """

    symbol: str = StringField()
    exchange: str = StringField()
    datetime: datetime = DateTimeField()

    name: str = StringField()
    volume: float = FloatField()
    open_interest: float = FloatField()
    last_price: float = FloatField()
    last_volume: float = FloatField()
    limit_up: float = FloatField()
    limit_down: float = FloatField()

    open_price: float = FloatField()
    high_price: float = FloatField()
    low_price: float = FloatField()
    # NOTE(review): close_price is declared but neither from_tick nor to_tick
    # populates or reads it — confirm whether TickData carries a close price.
    close_price: float = FloatField()
    pre_close: float = FloatField()

    bid_price_1: float = FloatField()
    bid_price_2: float = FloatField()
    bid_price_3: float = FloatField()
    bid_price_4: float = FloatField()
    bid_price_5: float = FloatField()

    ask_price_1: float = FloatField()
    ask_price_2: float = FloatField()
    ask_price_3: float = FloatField()
    ask_price_4: float = FloatField()
    ask_price_5: float = FloatField()

    bid_volume_1: float = FloatField()
    bid_volume_2: float = FloatField()
    bid_volume_3: float = FloatField()
    bid_volume_4: float = FloatField()
    bid_volume_5: float = FloatField()

    ask_volume_1: float = FloatField()
    ask_volume_2: float = FloatField()
    ask_volume_3: float = FloatField()
    ask_volume_4: float = FloatField()
    ask_volume_5: float = FloatField()

    meta = {
        "indexes": [
            {
                "fields": ("datetime", "symbol", "exchange"),
                "unique": True,
            }
        ],
    }

    @staticmethod
    def from_tick(tick: TickData):
        """
        Generate a DbTickData document from a TickData object.

        Level-1 quotes are always copied. Levels 2-5 (bid/ask price and
        volume) are copied only when ``tick.bid_price_2`` is truthy. The
        returned document is not saved by this method.
        """
        db_tick = DbTickData()

        db_tick.symbol = tick.symbol
        db_tick.exchange = tick.exchange.value
        db_tick.datetime = tick.datetime
        db_tick.name = tick.name
        db_tick.volume = tick.volume
        db_tick.open_interest = tick.open_interest
        db_tick.last_price = tick.last_price
        db_tick.last_volume = tick.last_volume
        db_tick.limit_up = tick.limit_up
        db_tick.limit_down = tick.limit_down
        db_tick.open_price = tick.open_price
        db_tick.high_price = tick.high_price
        db_tick.low_price = tick.low_price
        db_tick.pre_close = tick.pre_close

        db_tick.bid_price_1 = tick.bid_price_1
        db_tick.ask_price_1 = tick.ask_price_1
        db_tick.bid_volume_1 = tick.bid_volume_1
        db_tick.ask_volume_1 = tick.ask_volume_1

        # Deeper book levels are gated on the level-2 bid price; presumably
        # this skips ticks that only carry level-1 quotes.
        if tick.bid_price_2:
            depth_fields = (
                f"{side}_{kind}_{level}"
                for side in ("bid", "ask")
                for kind in ("price", "volume")
                for level in range(2, 6)
            )
            for field_name in depth_fields:
                setattr(db_tick, field_name, getattr(tick, field_name))

        return db_tick

    def to_tick(self):
        """
        Generate a TickData object from this DbTickData document.

        The Exchange enum is rebuilt from the stored string and
        ``gateway_name`` is always "DB". Levels 2-5 are copied onto the result
        only when the stored ``bid_price_2`` is truthy.
        """
        tick = TickData(
            symbol=self.symbol,
            exchange=Exchange(self.exchange),
            datetime=self.datetime,
            name=self.name,
            volume=self.volume,
            open_interest=self.open_interest,
            last_price=self.last_price,
            last_volume=self.last_volume,
            limit_up=self.limit_up,
            limit_down=self.limit_down,
            open_price=self.open_price,
            high_price=self.high_price,
            low_price=self.low_price,
            pre_close=self.pre_close,
            bid_price_1=self.bid_price_1,
            ask_price_1=self.ask_price_1,
            bid_volume_1=self.bid_volume_1,
            ask_volume_1=self.ask_volume_1,
            gateway_name="DB",
        )

        if self.bid_price_2:
            depth_fields = (
                f"{side}_{kind}_{level}"
                for side in ("bid", "ask")
                for kind in ("price", "volume")
                for level in range(2, 6)
            )
            for field_name in depth_fields:
                setattr(tick, field_name, getattr(self, field_name))

        return tick
|
|
|
|
|
|
class MongoManager(BaseDatabaseManager):
    """BaseDatabaseManager implementation backed by mongoengine documents."""

    # Attributes of BarData objects with no matching DbBarData field; they are
    # stripped from the generated "set__" updates before upserting.
    _BAR_EXCLUDED_ATTRS = (
        "gateway_name",
        "vt_symbol",
        "full_symbol",
        "adj_close_price",
        "bar_start_time",
    )

    # Attributes of TickData objects with no matching DbTickData field.
    _TICK_EXCLUDED_ATTRS = (
        "gateway_name",
        "vt_symbol",
        "depth",
        "full_symbol",
        "timestamp",
    )

    def load_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval,
        start: datetime,
        end: datetime,
    ) -> Sequence[BarData]:
        """Return bars for symbol/exchange/interval with start <= datetime <= end.

        Results are returned in the order the query yields them; no explicit
        sort is applied.
        """
        s = DbBarData.objects(
            symbol=symbol,
            exchange=exchange.value,
            interval=interval.value,
            datetime__gte=start,
            datetime__lte=end,
        )
        return [db_bar.to_bar() for db_bar in s]

    def load_tick_data(
        self, symbol: str, exchange: Exchange, start: datetime, end: datetime
    ) -> Sequence[TickData]:
        """Return ticks for symbol/exchange with start <= datetime <= end.

        Results are returned in the order the query yields them; no explicit
        sort is applied.
        """
        s = DbTickData.objects(
            symbol=symbol,
            exchange=exchange.value,
            datetime__gte=start,
            datetime__lte=end,
        )
        return [db_tick.to_tick() for db_tick in s]

    @staticmethod
    def to_update_param(d):
        """Map every attribute of *d* to a mongoengine ``set__<name>`` update.

        Enum members are replaced by their ``.value``. Reads ``d.__dict__``, so
        *d* must be an object with an instance dictionary.
        """
        return {
            "set__" + k: v.value if isinstance(v, Enum) else v
            for k, v in d.__dict__.items()
        }

    def save_bar_data(self, datas: Sequence[BarData]):
        """Upsert each bar in *datas*.

        The upsert filter uses (symbol, exchange, interval, datetime), the
        same key as DbBarData's unique index. Without exchange in the filter,
        a bar for the same symbol on another exchange would overwrite an
        existing document.
        """
        for d in datas:
            updates = self.to_update_param(d)
            for attr in self._BAR_EXCLUDED_ATTRS:
                # Missing attributes are tolerated instead of raising KeyError.
                updates.pop("set__" + attr, None)
            DbBarData.objects(
                symbol=d.symbol,
                exchange=d.exchange.value,
                interval=d.interval.value,
                datetime=d.datetime,
            ).update_one(upsert=True, **updates)

    def save_tick_data(self, datas: Sequence[TickData]):
        """Upsert each tick in *datas*, keyed on (symbol, exchange, datetime)."""
        for d in datas:
            updates = self.to_update_param(d)
            for attr in self._TICK_EXCLUDED_ATTRS:
                # Missing attributes are tolerated instead of raising KeyError.
                updates.pop("set__" + attr, None)
            DbTickData.objects(
                symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
            ).update_one(upsert=True, **updates)

    def get_newest_bar_data(
        self, symbol: str, exchange: "Exchange", interval: "Interval"
    ) -> Optional["BarData"]:
        """Return the latest stored bar for symbol/exchange/interval, or None.

        The query filters on *interval*, so bars of other intervals are not
        considered.
        """
        s = (
            DbBarData.objects(
                symbol=symbol, exchange=exchange.value, interval=interval.value
            )
            .order_by("-datetime")
            .first()
        )
        if s:
            return s.to_bar()
        return None

    def get_newest_tick_data(
        self, symbol: str, exchange: "Exchange"
    ) -> Optional["TickData"]:
        """Return the latest stored tick for symbol/exchange, or None."""
        s = (
            DbTickData.objects(symbol=symbol, exchange=exchange.value)
            .order_by("-datetime")
            .first()
        )
        if s:
            return s.to_tick()
        return None

    def clean(self, symbol: str):
        """Delete all tick and bar documents whose symbol equals *symbol*.

        Documents for every exchange and interval are removed.
        """
        DbTickData.objects(symbol=symbol).delete()
        DbBarData.objects(symbol=symbol).delete()
|
|
|