未验证 提交 a494889c 编写于 作者: C ceci3 提交者: GitHub

fix pact train.py (#5472)

上级 397c3673
...@@ -302,16 +302,16 @@ def main(args): ...@@ -302,16 +302,16 @@ def main(args):
paddle.save(optimizer.state_dict(), paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir, 'best.pdopt')) os.path.join(args.output_dir, 'best.pdopt'))
if args.pact_quant: if args.pact_quant:
input_spec = [ input_spec = [
InputSpec( InputSpec(
shape=[None, 3, 224, 224], dtype='float32') shape=[None, 3, 224, 224], dtype='float32')
] ]
quanter.save_quantized_model( quanter.save_quantized_model(
model, model,
os.path.join(args.output_dir, "qat_inference"), os.path.join(args.output_dir, "qat_inference"),
input_spec=input_spec) input_spec=input_spec)
print("QAT inference model saved in {args.output_dir}") print("QAT inference model saved in {args.output_dir}")
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册