From b18151e630ab9a5853d7067dce4d68b901ac4da0 Mon Sep 17 00:00:00 2001 From: xiegegege <46314656+xiegegege@users.noreply.github.com> Date: Tue, 17 Mar 2020 16:46:56 +0800 Subject: [PATCH] add ce for PaddleDetection (#343) --- tools/train.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tools/train.py b/tools/train.py index cfca3565d..4345d0933 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import time import numpy as np +import random import datetime from collections import deque @@ -60,11 +61,14 @@ def main(): FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env if FLAGS.dist: trainer_id = int(env['PADDLE_TRAINER_ID']) - import random local_seed = (99 + trainer_id) random.seed(local_seed) np.random.seed(local_seed) + if FLAGS.enable_ce: + random.seed(0) + np.random.seed(0) + cfg = load_config(FLAGS.config) if 'architecture' in cfg: main_arch = cfg.architecture @@ -101,6 +105,9 @@ def main(): # build program startup_prog = fluid.Program() train_prog = fluid.Program() + if FLAGS.enable_ce: + startup_prog.random_seed = 1000 + train_prog.random_seed = 1000 with fluid.program_guard(train_prog, startup_prog): with fluid.unique_name.guard(): model = create(main_arch) @@ -319,5 +326,11 @@ if __name__ == '__main__': type=str, default="tb_log_dir/scalar", help='Tensorboard logging directory for scalar.') + parser.add_argument( + "--enable_ce", + type=bool, + default=False, + help="If set True, enable continuous evaluation job." + "This flag is only used for internal test.") FLAGS = parser.parse_args() main() -- GitLab