未验证 提交 e21933ff 编写于 作者: B Bai Yifan 提交者: GitHub

fix dygraph quant demo save model (#549)

上级 42eb23d3
...@@ -319,20 +319,22 @@ def compress(args): ...@@ -319,20 +319,22 @@ def compress(args):
paddle.save(net.state_dict(), model_prefix + ".pdparams") paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(opt.state_dict(), model_prefix + ".pdopt") paddle.save(opt.state_dict(), model_prefix + ".pdopt")
# load best model
load_dygraph_pretrain(net, os.path.join(args.model_save_dir, "best_model"))
############################################################################################################ ############################################################################################################
# 3. Save quant aware model # 3. Save quant aware model
############################################################################################################ ############################################################################################################
path = os.path.join(args.model_save_dir, "inference_model", 'qat_model') if paddle.distributed.get_rank() == 0:
quanter.save_quantized_model( # load best model
net, load_dygraph_pretrain(net,
path, os.path.join(args.model_save_dir, "best_model"))
input_spec=[
paddle.static.InputSpec( path = os.path.join(args.model_save_dir, "inference_model", 'qat_model')
shape=[None, 3, 224, 224], dtype='float32') quanter.save_quantized_model(
]) net,
path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
])
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册