未验证 提交 b18151e6 编写于 作者: X xiegegege 提交者: GitHub

add ce for PaddleDetection (#343)

上级 2d385d25
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import time import time
import numpy as np import numpy as np
import random
import datetime import datetime
from collections import deque from collections import deque
...@@ -60,11 +61,14 @@ def main(): ...@@ -60,11 +61,14 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist: if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID']) trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id) local_seed = (99 + trainer_id)
random.seed(local_seed) random.seed(local_seed)
np.random.seed(local_seed) np.random.seed(local_seed)
if FLAGS.enable_ce:
random.seed(0)
np.random.seed(0)
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg.architecture main_arch = cfg.architecture
...@@ -101,6 +105,9 @@ def main(): ...@@ -101,6 +105,9 @@ def main():
# build program # build program
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
if FLAGS.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch) model = create(main_arch)
...@@ -319,5 +326,11 @@ if __name__ == '__main__': ...@@ -319,5 +326,11 @@ if __name__ == '__main__':
type=str, type=str,
default="tb_log_dir/scalar", default="tb_log_dir/scalar",
help='Tensorboard logging directory for scalar.') help='Tensorboard logging directory for scalar.')
parser.add_argument(
"--enable_ce",
type=bool,
default=False,
help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册