未验证 提交 377a5ce1 编写于 作者: Q qingqing01 提交者: GitHub

Clean fluid.compiler.CompiledProgram (#891)

上级 b5fa78bd
...@@ -305,7 +305,7 @@ def main(): ...@@ -305,7 +305,7 @@ def main():
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
# whether output bbox is normalized in model output layer # whether output bbox is normalized in model output layer
is_bbox_normalized = False is_bbox_normalized = False
......
...@@ -276,7 +276,7 @@ def main(): ...@@ -276,7 +276,7 @@ def main():
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
# parse eval fetches # parse eval fetches
extra_keys = [] extra_keys = []
......
...@@ -323,7 +323,7 @@ def main(): ...@@ -323,7 +323,7 @@ def main():
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
if FLAGS.eval: if FLAGS.eval:
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
train_loader.set_sample_list_generator(train_reader, place) train_loader.set_sample_list_generator(train_reader, place)
......
...@@ -119,8 +119,7 @@ def main(): ...@@ -119,8 +119,7 @@ def main():
logger.info("pruned FLOPS: {}".format( logger.info("pruned FLOPS: {}".format(
float(base_flops - pruned_flops) / base_flops)) float(base_flops - pruned_flops) / base_flops))
compile_program = fluid.compiler.CompiledProgram( compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
eval_prog).with_data_parallel()
assert cfg.metric != 'OID', "eval process of OID dataset \ assert cfg.metric != 'OID', "eval process of OID dataset \
is not supported." is not supported."
......
...@@ -215,7 +215,7 @@ def main(): ...@@ -215,7 +215,7 @@ def main():
logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format( logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(
float(base_flops - pruned_flops) / base_flops, base_flops, float(base_flops - pruned_flops) / base_flops, base_flops,
pruned_flops)) pruned_flops))
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
if FLAGS.resume_checkpoint: if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
......
...@@ -132,8 +132,7 @@ def main(): ...@@ -132,8 +132,7 @@ def main():
checkpoint.load_params(exe, eval_prog, cfg.weights) checkpoint.load_params(exe, eval_prog, cfg.weights)
eval_prog = convert(eval_prog, place, config, save_int8=False) eval_prog = convert(eval_prog, place, config, save_int8=False)
compile_program = fluid.compiler.CompiledProgram( compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
eval_prog).with_data_parallel()
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg, results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values) sub_eval_prog, sub_keys, sub_values)
......
...@@ -200,7 +200,7 @@ def main(): ...@@ -200,7 +200,7 @@ def main():
if FLAGS.eval: if FLAGS.eval:
# insert quantize op in eval_prog # insert quantize op in eval_prog
eval_prog = quant_aware(eval_prog, place, config, for_test=True) eval_prog = quant_aware(eval_prog, place, config, for_test=True)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint: if FLAGS.resume_checkpoint:
......
...@@ -122,7 +122,7 @@ def main(): ...@@ -122,7 +122,7 @@ def main():
def test(program): def test(program):
compiled_eval_prog = fluid.compiler.CompiledProgram(program) compiled_eval_prog = fluid.CompiledProgram(program)
results = eval_run( results = eval_run(
exe, exe,
......
...@@ -88,8 +88,7 @@ def main(): ...@@ -88,8 +88,7 @@ def main():
cfg.metric, json_directory=FLAGS.output_eval, dataset=dataset) cfg.metric, json_directory=FLAGS.output_eval, dataset=dataset)
return return
compile_program = fluid.compiler.CompiledProgram( compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
eval_prog).with_data_parallel()
assert cfg.metric != 'OID', "eval process of OID dataset \ assert cfg.metric != 'OID', "eval process of OID dataset \
is not supported." is not supported."
......
...@@ -180,7 +180,7 @@ def main(): ...@@ -180,7 +180,7 @@ def main():
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
if FLAGS.eval: if FLAGS.eval:
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册