未验证 提交 a494889c 编写于 作者: C ceci3 提交者: GitHub

fix pact train.py (#5472)

上级 397c3673
......@@ -302,16 +302,16 @@ def main(args):
paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir, 'best.pdopt'))
if args.pact_quant:
input_spec = [
InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
]
quanter.save_quantized_model(
model,
os.path.join(args.output_dir, "qat_inference"),
input_spec=input_spec)
print("QAT inference model saved in {args.output_dir}")
if args.pact_quant:
input_spec = [
InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
]
quanter.save_quantized_model(
model,
os.path.join(args.output_dir, "qat_inference"),
input_spec=input_spec)
print("QAT inference model saved in {args.output_dir}")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册