未验证 提交 4d70d009 编写于 作者: Q qingqing01 提交者: GitHub

Clean fluid.compiler.CompiledProgram (#892)

上级 bb409b0b
......@@ -305,7 +305,7 @@ def main():
build_strategy=build_strategy,
exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
......
......@@ -276,7 +276,7 @@ def main():
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
# parse eval fetches
extra_keys = []
......
......@@ -323,7 +323,7 @@ def main():
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval:
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
train_loader.set_sample_list_generator(train_reader, place)
......
......@@ -119,8 +119,7 @@ def main():
logger.info("pruned FLOPS: {}".format(
float(base_flops - pruned_flops) / base_flops))
compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel()
compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
assert cfg.metric != 'OID', "eval process of OID dataset \
is not supported."
......
......@@ -215,7 +215,7 @@ def main():
logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(
float(base_flops - pruned_flops) / base_flops, base_flops,
pruned_flops))
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
......
......@@ -132,8 +132,7 @@ def main():
checkpoint.load_params(exe, eval_prog, cfg.weights)
eval_prog = convert(eval_prog, place, config, save_int8=False)
compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel()
compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values)
......
......@@ -200,7 +200,7 @@ def main():
if FLAGS.eval:
# insert quantize op in eval_prog
eval_prog = quant_aware(eval_prog, place, config, for_test=True)
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
start_iter = 0
if FLAGS.resume_checkpoint:
......
......@@ -122,7 +122,7 @@ def main():
def test(program):
compiled_eval_prog = fluid.compiler.CompiledProgram(program)
compiled_eval_prog = fluid.CompiledProgram(program)
results = eval_run(
exe,
......
......@@ -88,8 +88,7 @@ def main():
cfg.metric, json_directory=FLAGS.output_eval, dataset=dataset)
return
compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel()
compile_program = fluid.CompiledProgram(eval_prog).with_data_parallel()
assert cfg.metric != 'OID', "eval process of OID dataset \
is not supported."
......
......@@ -180,7 +180,7 @@ def main():
exec_strategy=exec_strategy)
if FLAGS.eval:
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
compiled_eval_prog = fluid.CompiledProgram(eval_prog)
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册