提交 8bd4a577 编写于 作者: J jerrywgz 提交者: GitHub

Fix train+eval in PaddleDetection(#2847)

上级 0b3f8b68
...@@ -9,7 +9,7 @@ log_smooth_window: 20 ...@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir: output save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO metric: COCO
weights: output/fpn/faster_rcnn_r50_fpn_1x/model_final weights: output/faster_rcnn_r50_fpn_1x/model_final
num_classes: 81 num_classes: 81
FasterRCNN: FasterRCNN:
......
...@@ -149,7 +149,11 @@ class MaskRCNN(object): ...@@ -149,7 +149,11 @@ class MaskRCNN(object):
cond = fluid.layers.less_than(x=bbox_size, y=size) cond = fluid.layers.less_than(x=bbox_size, y=size)
mask_pred = fluid.layers.create_global_var( mask_pred = fluid.layers.create_global_var(
shape=[1], value=0.0, dtype='float32', persistable=False) shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name='mask_pred')
with fluid.layers.control_flow.Switch() as switch: with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond): with switch.case(cond):
......
...@@ -86,7 +86,6 @@ def main(): ...@@ -86,7 +86,6 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
model = create(main_arch)
lr_builder = create('LearningRate') lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder') optim_builder = create('OptimizerBuilder')
...@@ -95,6 +94,7 @@ def main(): ...@@ -95,6 +94,7 @@ def main():
train_prog = fluid.Program() train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch)
train_pyreader, feed_vars = create_feed(train_feed) train_pyreader, feed_vars = create_feed(train_feed)
train_fetches = model.train(feed_vars) train_fetches = model.train(feed_vars)
loss = train_fetches['loss'] loss = train_fetches['loss']
...@@ -113,6 +113,7 @@ def main(): ...@@ -113,6 +113,7 @@ def main():
eval_prog = fluid.Program() eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog): with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
model = create(main_arch)
eval_pyreader, feed_vars = create_feed(eval_feed) eval_pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars) fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
...@@ -120,8 +121,9 @@ def main(): ...@@ -120,8 +121,9 @@ def main():
eval_reader = create_reader(eval_feed) eval_reader = create_reader(eval_feed)
eval_pyreader.decorate_sample_list_generator(eval_reader, place) eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse train fetches # parse eval fetches
extra_keys = ['im_info', 'im_id'] if cfg.metric == 'COCO' else [] extra_keys = ['im_info', 'im_id',
'im_shape'] if cfg.metric == 'COCO' else []
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys) extra_keys)
...@@ -132,7 +134,7 @@ def main(): ...@@ -132,7 +134,7 @@ def main():
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
# only enable sync_bn in multi GPU devices # only enable sync_bn in multi GPU devices
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu and cfg.use_gpu
train_compile_program = fluid.compiler.CompiledProgram( train_compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel( train_prog).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册