diff --git a/tools/train.py b/tools/train.py index cfca3565d6560492e4a1d0fd5c415439a8f85c3c..4345d09331f7b2e3c53ea89b0f0c3eba7824a532 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import time import numpy as np +import random import datetime from collections import deque @@ -60,11 +61,14 @@ def main(): FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env if FLAGS.dist: trainer_id = int(env['PADDLE_TRAINER_ID']) - import random local_seed = (99 + trainer_id) random.seed(local_seed) np.random.seed(local_seed) + if FLAGS.enable_ce: + random.seed(0) + np.random.seed(0) + cfg = load_config(FLAGS.config) if 'architecture' in cfg: main_arch = cfg.architecture @@ -101,6 +105,9 @@ def main(): # build program startup_prog = fluid.Program() train_prog = fluid.Program() + if FLAGS.enable_ce: + startup_prog.random_seed = 1000 + train_prog.random_seed = 1000 with fluid.program_guard(train_prog, startup_prog): with fluid.unique_name.guard(): model = create(main_arch) @@ -319,5 +326,11 @@ if __name__ == '__main__': type=str, default="tb_log_dir/scalar", help='Tensorboard logging directory for scalar.') + parser.add_argument( + "--enable_ce", + type=bool, + default=False, + help="If set True, enable continuous evaluation job." + "This flag is only used for internal test.") FLAGS = parser.parse_args() main()