未验证 提交 da4c6757 编写于 作者: B Bai Yifan 提交者: GitHub

Fix pact demo dataloader (#526)

* fix bug

* code fix
上级 3d2a5924
...@@ -157,7 +157,9 @@ def compress(args): ...@@ -157,7 +157,9 @@ def compress(args):
places=places, places=places,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False,
shuffle=True, shuffle=True,
num_workers=1) num_workers=1)
...@@ -166,7 +168,9 @@ def compress(args): ...@@ -166,7 +168,9 @@ def compress(args):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False,
shuffle=False) shuffle=False)
if args.analysis: if args.analysis:
...@@ -372,13 +376,8 @@ def compress(args): ...@@ -372,13 +376,8 @@ def compress(args):
ckpt_path = args.checkpoint_dir ckpt_path = args.checkpoint_dir
assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set" assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
start_epoch = args.checkpoint_epoch start_epoch = args.checkpoint_epoch
paddle.static.load_vars( paddle.static.load(
exe, dirname=args.checkpoint_dir, main_program=val_program) executor=exe, model_path=args.checkpoint_dir, program=val_program)
start_step = start_epoch * int(
math.ceil(float(args.total_images) / args.batch_size))
v = paddle.static.global_scope().find_var(
'@LR_DECAY_COUNTER@').get_tensor()
v.set(np.array([start_step]).astype(np.float32), place)
best_eval_acc1 = 0 best_eval_acc1 = 0
best_acc1_epoch = 0 best_acc1_epoch = 0
...@@ -391,22 +390,20 @@ def compress(args): ...@@ -391,22 +390,20 @@ def compress(args):
_logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format( _logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
best_eval_acc1, best_acc1_epoch)) best_eval_acc1, best_acc1_epoch))
paddle.static.save( paddle.static.save(
exe, model_path=os.path.join(args.output_dir, str(i)),
dirname=os.path.join(args.output_dir, str(i)), program=val_program)
main_program=val_program)
if acc1 > best_acc1: if acc1 > best_acc1:
best_acc1 = acc1 best_acc1 = acc1
best_epoch = i best_epoch = i
paddle.static.save( paddle.static.save(
exe, model_path=os.path.join(args.output_dir, 'best_model'),
dirname=os.path.join(args.output_dir, 'best_model'), program=val_program)
main_program=val_program)
if os.path.exists(os.path.join(args.output_dir, 'best_model')): if os.path.exists(os.path.join(args.output_dir, 'best_model.pdparams')):
paddle.static.load( paddle.static.load(
exe, executor=exe,
dirname=os.path.join(args.output_dir, 'best_model'), model_path=os.path.join(args.output_dir, 'best_model'),
main_program=val_program) program=val_program)
# 3. Freeze the graph after training by adjusting the quantize # 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference. # operators' order for the inference.
......
...@@ -256,9 +256,9 @@ def compress(args): ...@@ -256,9 +256,9 @@ def compress(args):
model_path=os.path.join(args.checkpoint_dir, 'best_model')) model_path=os.path.join(args.checkpoint_dir, 'best_model'))
if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')): if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
paddle.static.load( paddle.static.load(
exe, executor=exe,
dirname=os.path.join(args.checkpoint_dir, 'best_model'), model_path=os.path.join(args.checkpoint_dir, 'best_model'),
main_program=val_program) program=val_program)
############################################################################################################ ############################################################################################################
# 3. Freeze the graph after training by adjusting the quantize # 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference. # operators' order for the inference.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册