diff --git a/czsc/cobra/utils.py b/czsc/cobra/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd3ba8d5c358d6cffd67875d899a0485ceb93 --- /dev/null +++ b/czsc/cobra/utils.py @@ -0,0 +1,93 @@ +# coding: utf-8 + +import numpy as np +import pandas as pd +from czsc.ta import KDJ +import traceback +from typing import List, Union + + +def down_cross_count(x1: Union[List, np.array], x2: Union[List, np.array]): + """输入两个序列,计算 x1 下穿 x2 的次数 + + :param x1: list + :param x2: list + :return: int + + example: + ======== + >>> x1 = [1, 1, 3, 4, 5, 12, 9, 8] + >>> x2 = [2, 2, 1, 5, 8, 9, 10, 10] + >>> print("x1 下穿 x2 的次数:{}".format(down_cross_count(x1, x2))) + >>> print("x1 上穿 x2 的次数:{}".format(down_cross_count(x2, x1))) + """ + x = np.array(x1) < np.array(x2) + num = 0 + for i in range(len(x) - 1): + b1, b2 = x[i], x[i + 1] + if b2 and b1 != b2: + num += 1 + return num + + +def kdj_gold_cross(kline: Union[List[dict], pd.DataFrame], just: bool = True): + """输入K线,判断KDJ是否金叉 + + :param kline: pd.DataFrame + :param just: bool + 是否是刚刚形成 + :return: bool + """ + try: + if isinstance(kline, list): + close = [x['close'] for x in kline] + high = [x['high'] for x in kline] + low = [x['low'] for x in kline] + else: + close = kline.close.values + high = kline.high.values + low = kline.low.values + + k, d, j = KDJ(close=close, high=high, low=low) + + if not just and j[-1] > k[-1] > d[-1]: + return True + elif just and j[-1] > k[-1] > d[-1] and not (j[-2] > k[-2] > d[-2]): + return True + else: + return False + except: + print("{}: run error kdj_gold_cross".format(kline.iloc[0]['symbol'])) + traceback.print_exc() + return False + +def kdj_dead_cross(kline: Union[List[dict], pd.DataFrame], just: bool = True): + """输入K线,判断KDJ是否死叉 + + :param kline: pd.DataFrame + :param just: bool + 是否是刚刚形成 + :return: bool + """ + try: + if isinstance(kline, list): + close = [x['close'] for x in kline] + high = [x['high'] for x in kline] + low = [x['low'] for x in kline] + else: + close = kline.close.values + high = kline.high.values + low = kline.low.values + + k, d, j = KDJ(close=close, high=high, low=low) + + if not just and j[-1] < k[-1] < d[-1]: + return True + elif just and j[-1] < k[-1] < d[-1] and not (j[-2] < k[-2] < d[-2]): + return True + else: + return False + except: + print("{}: run error kdj_dead_cross".format(kline.iloc[0]['symbol'])) + traceback.print_exc() + return False diff --git a/test/test_cobra_utils.py b/test/test_cobra_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52cd266874f442e0e88cff27b0dde6aae974e966 --- /dev/null +++ b/test/test_cobra_utils.py @@ -0,0 +1,41 @@ +# coding: utf-8 +import sys +import warnings + +sys.path.insert(0, '.') +sys.path.insert(0, '..') +import os +import numpy as np +import pandas as pd +import czsc +from czsc.cobra.utils import down_cross_count, kdj_gold_cross, kdj_dead_cross + +warnings.warn("czsc version is {}".format(czsc.__version__)) + +# cur_path = os.path.split(os.path.realpath(__file__))[0] +cur_path = "./test" + +def test_kdj_cross(): + file_kline = os.path.join(cur_path, "data/000001.SH_D.csv") + kline = pd.read_csv(file_kline, encoding="utf-8") + bars = kline.to_dict("records") + + assert not kdj_gold_cross(kline, just=False) + assert not kdj_gold_cross(bars, just=False) + assert kdj_dead_cross(kline, just=False) + assert kdj_dead_cross(bars, just=False) + assert not kdj_dead_cross(kline, just=True) + + +def test_cross_count(): + x1 = [1, 1, 3, 4, 5, 12, 9, 8] + x2 = [2, 2, 1, 5, 8, 9, 10, 10] + assert down_cross_count(x1, x2) == 2 + assert down_cross_count(np.array(x1), np.array(x2)) == 2 + assert down_cross_count(x2, x1) == 2 + assert down_cross_count(np.array(x2), np.array(x1)) == 2 + + + + +