未验证 提交 da4c6757 编写于 作者: B Bai Yifan 提交者: GitHub

Fix pact demo dataloader (#526)

* fix bug

* code fix
上级 3d2a5924
......@@ -157,7 +157,9 @@ def compress(args):
places=places,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=args.batch_size,
use_shared_memory=False,
shuffle=True,
num_workers=1)
......@@ -166,7 +168,9 @@ def compress(args):
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
batch_size=args.batch_size,
use_shared_memory=False,
shuffle=False)
if args.analysis:
......@@ -372,13 +376,8 @@ def compress(args):
ckpt_path = args.checkpoint_dir
assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
start_epoch = args.checkpoint_epoch
paddle.static.load_vars(
exe, dirname=args.checkpoint_dir, main_program=val_program)
start_step = start_epoch * int(
math.ceil(float(args.total_images) / args.batch_size))
v = paddle.static.global_scope().find_var(
'@LR_DECAY_COUNTER@').get_tensor()
v.set(np.array([start_step]).astype(np.float32), place)
paddle.static.load(
executor=exe, model_path=args.checkpoint_dir, program=val_program)
best_eval_acc1 = 0
best_acc1_epoch = 0
......@@ -391,22 +390,20 @@ def compress(args):
_logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
best_eval_acc1, best_acc1_epoch))
paddle.static.save(
exe,
dirname=os.path.join(args.output_dir, str(i)),
main_program=val_program)
model_path=os.path.join(args.output_dir, str(i)),
program=val_program)
if acc1 > best_acc1:
best_acc1 = acc1
best_epoch = i
paddle.static.save(
exe,
dirname=os.path.join(args.output_dir, 'best_model'),
main_program=val_program)
model_path=os.path.join(args.output_dir, 'best_model'),
program=val_program)
if os.path.exists(os.path.join(args.output_dir, 'best_model')):
if os.path.exists(os.path.join(args.output_dir, 'best_model.pdparams')):
paddle.static.load(
exe,
dirname=os.path.join(args.output_dir, 'best_model'),
main_program=val_program)
executor=exe,
model_path=os.path.join(args.output_dir, 'best_model'),
program=val_program)
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
......
......@@ -256,9 +256,9 @@ def compress(args):
model_path=os.path.join(args.checkpoint_dir, 'best_model'))
if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
paddle.static.load(
exe,
dirname=os.path.join(args.checkpoint_dir, 'best_model'),
main_program=val_program)
executor=exe,
model_path=os.path.join(args.checkpoint_dir, 'best_model'),
program=val_program)
############################################################################################################
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册