提交 4ffbe264 编写于 作者: X xiegegege 提交者: qingqing01

add ce for PaddleDetection (#4078)

* add ce for PaddleDetection
上级 95f46d33
......@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import time
import numpy as np
import random
import datetime
from collections import deque
......@@ -62,10 +63,13 @@ def main():
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
import random
local_seed = (99 + trainer_id)
random.seed(local_seed)
np.random.seed(local_seed)
if FLAGS.enable_ce:
random.seed(0)
np.random.seed(0)
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
......@@ -112,6 +116,9 @@ def main():
# build program
startup_prog = fluid.Program()
train_prog = fluid.Program()
if FLAGS.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
......@@ -258,6 +265,11 @@ def main():
strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
it, np.mean(outs[-1]), logs, time_cost, eta)
logger.info(strs)
#only for continuous evaluation
if FLAGS.enable_ce and it == cfg.max_iters - 1:
print("kpis\t{}_train_loss\t{}".format(cfg.architecture, stats['loss']))
print("kpis\t{}_train_time\t{}".format(cfg.architecture, time_cost))
# profiler tools, used for benchmark
if FLAGS.is_profiler and it == 5:
......@@ -342,6 +354,12 @@ if __name__ == '__main__':
type=str,
default="tb_log_dir/scalar",
help='Tensorboard logging directory for scalar.')
parser.add_argument(
'--enable_ce',
type=bool,
default=False,
help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.")
#NOTE:args for profiler tools, used for benchmark
parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册