提交 8bd4a577 编写于 作者: J jerrywgz 提交者: GitHub

Fix train+eval in PaddleDetection(#2847)

上级 0b3f8b68
......@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/fpn/faster_rcnn_r50_fpn_1x/model_final
weights: output/faster_rcnn_r50_fpn_1x/model_final
num_classes: 81
FasterRCNN:
......
......@@ -149,7 +149,11 @@ class MaskRCNN(object):
cond = fluid.layers.less_than(x=bbox_size, y=size)
mask_pred = fluid.layers.create_global_var(
shape=[1], value=0.0, dtype='float32', persistable=False)
shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name='mask_pred')
with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond):
......
......@@ -86,7 +86,6 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model = create(main_arch)
lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder')
......@@ -95,6 +94,7 @@ def main():
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
train_pyreader, feed_vars = create_feed(train_feed)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
......@@ -113,6 +113,7 @@ def main():
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
eval_pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
......@@ -120,8 +121,9 @@ def main():
eval_reader = create_reader(eval_feed)
eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse train fetches
extra_keys = ['im_info', 'im_id'] if cfg.metric == 'COCO' else []
# parse eval fetches
extra_keys = ['im_info', 'im_id',
'im_shape'] if cfg.metric == 'COCO' else []
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)
......@@ -132,7 +134,7 @@ def main():
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
# only enable sync_bn in multi GPU devices
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
and cfg.use_gpu
and cfg.use_gpu
train_compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册