提交 4ffbe264 编写于 作者: X xiegegege 提交者: qingqing01

add ce for PaddleDetection (#4078)

* add ce for PaddleDetection
上级 95f46d33
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import time import time
import numpy as np import numpy as np
import random
import datetime import datetime
from collections import deque from collections import deque
...@@ -62,10 +63,13 @@ def main(): ...@@ -62,10 +63,13 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist: if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID']) trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id) local_seed = (99 + trainer_id)
random.seed(local_seed) random.seed(local_seed)
np.random.seed(local_seed) np.random.seed(local_seed)
if FLAGS.enable_ce:
random.seed(0)
np.random.seed(0)
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
if 'architecture' in cfg: if 'architecture' in cfg:
...@@ -112,6 +116,9 @@ def main(): ...@@ -112,6 +116,9 @@ def main():
# build program # build program
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
if FLAGS.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch) model = create(main_arch)
...@@ -258,6 +265,11 @@ def main(): ...@@ -258,6 +265,11 @@ def main():
strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format( strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
it, np.mean(outs[-1]), logs, time_cost, eta) it, np.mean(outs[-1]), logs, time_cost, eta)
logger.info(strs) logger.info(strs)
#only for continuous evaluation
if FLAGS.enable_ce and it == cfg.max_iters - 1:
print("kpis\t{}_train_loss\t{}".format(cfg.architecture, stats['loss']))
print("kpis\t{}_train_time\t{}".format(cfg.architecture, time_cost))
# profiler tools, used for benchmark # profiler tools, used for benchmark
if FLAGS.is_profiler and it == 5: if FLAGS.is_profiler and it == 5:
...@@ -342,6 +354,12 @@ if __name__ == '__main__': ...@@ -342,6 +354,12 @@ if __name__ == '__main__':
type=str, type=str,
default="tb_log_dir/scalar", default="tb_log_dir/scalar",
help='Tensorboard logging directory for scalar.') help='Tensorboard logging directory for scalar.')
parser.add_argument(
'--enable_ce',
type=bool,
default=False,
help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.")
#NOTE:args for profiler tools, used for benchmark #NOTE:args for profiler tools, used for benchmark
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册