diff --git a/PaddleRec/word2vec/train.py b/PaddleRec/word2vec/train.py index 087c9b1d9b739df91f05480b2b7a059e59ca24a1..929abf4dc4272fec154bc29415d9651c3dbd5b74 100644 --- a/PaddleRec/word2vec/train.py +++ b/PaddleRec/word2vec/train.py @@ -84,6 +84,9 @@ def parse_args(): required=False, default=False, help='print speed or not , (default: False)') + parser.add_argument( + '--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.') + return parser.parse_args() @@ -195,6 +198,11 @@ def GetFileList(data_path): def train(args): + # add ce + if args.enable_ce: + SEED = 102 + fluid.default_main_program().random_seed = SEED + fluid.default_startup_program().random_seed = SEED if not os.path.isdir(args.model_output_dir): os.mkdir(args.model_output_dir)