提交 c889c814 编写于 作者: M mmglove 提交者: zhoushiyu

add ce for dcn deepfm xdeepfm models (#4102)

上级 c008e600
...@@ -79,5 +79,7 @@ def parse_args(): ...@@ -79,5 +79,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--clip_by_norm', type=float, default=100.0, help="gradient clip norm") '--clip_by_norm', type=float, default=100.0, help="gradient clip norm")
parser.add_argument('--print_steps', type=int, default=100) parser.add_argument('--print_steps', type=int, default=100)
parser.add_argument(
'--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
return parser.parse_args() return parser.parse_args()
...@@ -21,6 +21,12 @@ def train(args): ...@@ -21,6 +21,12 @@ def train(args):
:param args: hyperparams of model :param args: hyperparams of model
:return: :return:
""" """
# ce
if args.enable_ce:
SEED = 102
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
cat_feat_dims_dict = OrderedDict() cat_feat_dims_dict = OrderedDict()
for line in open(args.cat_feat_num): for line in open(args.cat_feat_num):
spls = line.strip().split() spls = line.strip().split()
......
...@@ -67,5 +67,7 @@ def parse_args(): ...@@ -67,5 +67,7 @@ def parse_args():
'--reg', type=float, default=1e-4, help=' (default: 1e-4)') '--reg', type=float, default=1e-4, help=' (default: 1e-4)')
parser.add_argument('--num_field', type=int, default=39) parser.add_argument('--num_field', type=int, default=39)
parser.add_argument('--num_feat', type=int, default=1086460) # 2090493 parser.add_argument('--num_feat', type=int, default=1086460) # 2090493
parser.add_argument(
'--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
return parser.parse_args() return parser.parse_args()
...@@ -11,6 +11,12 @@ import utils ...@@ -11,6 +11,12 @@ import utils
def train(): def train():
args = parse_args() args = parse_args()
# add ce
if args.enable_ce:
SEED = 102
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
print('---------- Configuration Arguments ----------') print('---------- Configuration Arguments ----------')
for key, value in args.__dict__.items(): for key, value in args.__dict__.items():
print(key + ':' + str(value)) print(key + ':' + str(value))
......
...@@ -75,5 +75,7 @@ def parse_args(): ...@@ -75,5 +75,7 @@ def parse_args():
required=False, required=False,
default=False, default=False,
help='embedding will use sparse or not, (default: False)') help='embedding will use sparse or not, (default: False)')
parser.add_argument(
'--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
return parser.parse_args() return parser.parse_args()
...@@ -9,6 +9,12 @@ import utils ...@@ -9,6 +9,12 @@ import utils
def train(): def train():
args = parse_args() args = parse_args()
# add ce
if args.enable_ce:
SEED = 102
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
print(args) print(args)
if not os.path.isdir(args.model_output_dir): if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir) os.mkdir(args.model_output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册