未验证 提交 8635791f 编写于 作者: W whs 提交者: GitHub

Fix pruning export (#2350)

上级 73c861ef
......@@ -63,6 +63,9 @@ def main():
test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True)
exe.run(startup_prog)
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
pruned_params = FLAGS.pruned_params
assert (
FLAGS.pruned_params is not None
......@@ -90,13 +93,9 @@ def main():
logger.info("pruned FLOPS: {}".format(
float(base_flops - pruned_flops) / base_flops))
exe.run(startup_prog)
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
dump_infer_config(FLAGS, cfg)
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
if __name__ == '__main__':
enable_static_mode()
parser = ArgsParser()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册