未验证 提交 b18151e6 编写于 作者: X xiegegege 提交者: GitHub

add ce for PaddleDetection (#343)

上级 2d385d25
......@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import time
import numpy as np
import random
import datetime
from collections import deque
......@@ -60,11 +61,14 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id)
random.seed(local_seed)
np.random.seed(local_seed)
if FLAGS.enable_ce:
random.seed(0)
np.random.seed(0)
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
......@@ -101,6 +105,9 @@ def main():
# build program
startup_prog = fluid.Program()
train_prog = fluid.Program()
if FLAGS.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
......@@ -319,5 +326,11 @@ if __name__ == '__main__':
type=str,
default="tb_log_dir/scalar",
help='Tensorboard logging directory for scalar.')
parser.add_argument(
"--enable_ce",
type=bool,
default=False,
help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.")
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册