未验证 提交 3d2a5924 编写于 作者: B Bai Yifan 提交者: GitHub

Fix save/load_inference_model, dataloader (#523)

* fix dataloader
上级 1a54970a
...@@ -121,6 +121,7 @@ def compress(args): ...@@ -121,6 +121,7 @@ def compress(args):
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
batch_size=args.batch_size, batch_size=args.batch_size,
return_list=False,
shuffle=True, shuffle=True,
use_shared_memory=False, use_shared_memory=False,
num_workers=1) num_workers=1)
...@@ -129,6 +130,7 @@ def compress(args): ...@@ -129,6 +130,7 @@ def compress(args):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False,
use_shared_memory=False, use_shared_memory=False,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False) shuffle=False)
...@@ -217,7 +219,7 @@ def compress(args): ...@@ -217,7 +219,7 @@ def compress(args):
format(epoch_id, step_id, val_loss[0], val_acc1[0], format(epoch_id, step_id, val_loss[0], val_acc1[0],
val_acc5[0])) val_acc5[0]))
if args.save_inference: if args.save_inference:
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), ["image"], os.path.join("./saved_models", str(epoch_id)), ["image"],
[out], exe, student_program) [out], exe, student_program)
_logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format( _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
......
...@@ -65,6 +65,7 @@ def eval(args): ...@@ -65,6 +65,7 @@ def eval(args):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False) shuffle=False)
......
...@@ -150,6 +150,7 @@ def compress(args): ...@@ -150,6 +150,7 @@ def compress(args):
drop_last=True, drop_last=True,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
return_list=False,
use_shared_memory=False, use_shared_memory=False,
num_workers=16) num_workers=16)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
...@@ -157,6 +158,7 @@ def compress(args): ...@@ -157,6 +158,7 @@ def compress(args):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False,
use_shared_memory=False, use_shared_memory=False,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False) shuffle=False)
......
...@@ -426,7 +426,7 @@ def compress(args): ...@@ -426,7 +426,7 @@ def compress(args):
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
dirname=float_path, dirname=float_path,
feeded_var_names=[image.name], feeded_var_names=[image.name],
target_vars=[out], target_vars=[out],
......
...@@ -168,6 +168,7 @@ def compress(args): ...@@ -168,6 +168,7 @@ def compress(args):
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
batch_size=args.batch_size, batch_size=args.batch_size,
return_list=False,
use_shared_memory=False, use_shared_memory=False,
shuffle=True, shuffle=True,
num_workers=1) num_workers=1)
...@@ -176,6 +177,7 @@ def compress(args): ...@@ -176,6 +177,7 @@ def compress(args):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False, use_shared_memory=False,
shuffle=False) shuffle=False)
...@@ -277,7 +279,7 @@ def compress(args): ...@@ -277,7 +279,7 @@ def compress(args):
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
dirname=float_path, dirname=float_path,
feeded_var_names=[image.name], feeded_var_names=[image.name],
target_vars=[out], target_vars=[out],
......
...@@ -40,7 +40,7 @@ def eval(args): ...@@ -40,7 +40,7 @@ def eval(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace() place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
args.model_path, args.model_path,
exe, exe,
model_filename=args.model_name, model_filename=args.model_name,
......
...@@ -62,7 +62,7 @@ def export_model(args): ...@@ -62,7 +62,7 @@ def export_model(args):
else: else:
assert False, "args.pretrained_model must set" assert False, "args.pretrained_model must set"
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
'./inference_model/' + args.model, './inference_model/' + args.model,
feeded_var_names=[image.name], feeded_var_names=[image.name],
target_vars=[out], target_vars=[out],
......
...@@ -56,6 +56,7 @@ class TestAnalysisHelper(StaticCase): ...@@ -56,6 +56,7 @@ class TestAnalysisHelper(StaticCase):
places=places, places=places,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=64) batch_size=64)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
......
...@@ -124,9 +124,14 @@ class TestQuantAwareCase2(StaticCase): ...@@ -124,9 +124,14 @@ class TestQuantAwareCase2(StaticCase):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=64) batch_size=64)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64) test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program): def train(program):
iter = 0 iter = 0
......
...@@ -96,9 +96,14 @@ class TestQuantAwareCase1(StaticCase): ...@@ -96,9 +96,14 @@ class TestQuantAwareCase1(StaticCase):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=64) batch_size=64)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64) test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program): def train(program):
iter = 0 iter = 0
......
...@@ -60,9 +60,14 @@ class TestQuantAwareCase1(StaticCase): ...@@ -60,9 +60,14 @@ class TestQuantAwareCase1(StaticCase):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=64) batch_size=64)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64) test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program): def train(program):
iter = 0 iter = 0
...@@ -97,7 +102,7 @@ class TestQuantAwareCase1(StaticCase): ...@@ -97,7 +102,7 @@ class TestQuantAwareCase1(StaticCase):
train(main_prog) train(main_prog)
top1_1, top5_1 = test(val_prog) top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
dirname='./test_quant_post', dirname='./test_quant_post',
feeded_var_names=[image.name, label.name], feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5], target_vars=[avg_cost, acc_top1, acc_top5],
...@@ -114,7 +119,7 @@ class TestQuantAwareCase1(StaticCase): ...@@ -114,7 +119,7 @@ class TestQuantAwareCase1(StaticCase):
model_filename='model', model_filename='model',
params_filename='params', params_filename='params',
batch_nums=10) batch_nums=10)
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model( quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference', dirname='./test_quant_post_inference',
executor=exe, executor=exe,
model_filename='__model__', model_filename='__model__',
......
...@@ -60,9 +60,14 @@ class TestQuantPostOnlyWeightCase1(StaticCase): ...@@ -60,9 +60,14 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False,
batch_size=64) batch_size=64)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64) test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program): def train(program):
iter = 0 iter = 0
...@@ -97,7 +102,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase): ...@@ -97,7 +102,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
train(main_prog) train(main_prog)
top1_1, top5_1 = test(val_prog) top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model( paddle.fluid.io.save_inference_model(
dirname='./test_quant_post_dynamic', dirname='./test_quant_post_dynamic',
feeded_var_names=[image.name, label.name], feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5], target_vars=[avg_cost, acc_top1, acc_top5],
...@@ -112,7 +117,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase): ...@@ -112,7 +117,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
model_filename='model', model_filename='model',
params_filename='params', params_filename='params',
generate_test_model=True) generate_test_model=True)
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model( quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference/test_model', executor=exe) dirname='./test_quant_post_inference/test_model', executor=exe)
top1_2, top5_2 = test(quant_post_prog, fetch_targets) top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1)) print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册