未验证 提交 3d2a5924 编写于 作者: B Bai Yifan 提交者: GitHub

Fix save/load_inference_model, dataloader (#523)

* fix dataloader
上级 1a54970a
......@@ -121,6 +121,7 @@ def compress(args):
feed_list=[image, label],
drop_last=True,
batch_size=args.batch_size,
return_list=False,
shuffle=True,
use_shared_memory=False,
num_workers=1)
......@@ -129,6 +130,7 @@ def compress(args):
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=False,
batch_size=args.batch_size,
shuffle=False)
......@@ -217,7 +219,7 @@ def compress(args):
format(epoch_id, step_id, val_loss[0], val_acc1[0],
val_acc5[0]))
if args.save_inference:
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), ["image"],
[out], exe, student_program)
_logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
......
......@@ -65,6 +65,7 @@ def eval(args):
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
batch_size=args.batch_size,
shuffle=False)
......
......@@ -150,6 +150,7 @@ def compress(args):
drop_last=True,
batch_size=args.batch_size,
shuffle=True,
return_list=False,
use_shared_memory=False,
num_workers=16)
valid_loader = paddle.io.DataLoader(
......@@ -157,6 +158,7 @@ def compress(args):
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
use_shared_memory=False,
batch_size=args.batch_size,
shuffle=False)
......
......@@ -426,7 +426,7 @@ def compress(args):
if not os.path.isdir(model_path):
os.makedirs(model_path)
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out],
......
......@@ -168,6 +168,7 @@ def compress(args):
feed_list=[image, label],
drop_last=True,
batch_size=args.batch_size,
return_list=False,
use_shared_memory=False,
shuffle=True,
num_workers=1)
......@@ -176,6 +177,7 @@ def compress(args):
places=place,
feed_list=[image, label],
drop_last=False,
return_list=False,
batch_size=args.batch_size,
use_shared_memory=False,
shuffle=False)
......@@ -277,7 +279,7 @@ def compress(args):
if not os.path.isdir(model_path):
os.makedirs(model_path)
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out],
......
......@@ -40,7 +40,7 @@ def eval(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
args.model_path,
exe,
model_filename=args.model_name,
......
......@@ -62,7 +62,7 @@ def export_model(args):
else:
assert False, "args.pretrained_model must set"
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
'./inference_model/' + args.model,
feeded_var_names=[image.name],
target_vars=[out],
......
......@@ -56,6 +56,7 @@ class TestAnalysisHelper(StaticCase):
places=places,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
exe.run(paddle.static.default_startup_program())
......
......@@ -124,9 +124,14 @@ class TestQuantAwareCase2(StaticCase):
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64)
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program):
iter = 0
......
......@@ -96,9 +96,14 @@ class TestQuantAwareCase1(StaticCase):
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64)
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program):
iter = 0
......
......@@ -60,9 +60,14 @@ class TestQuantAwareCase1(StaticCase):
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64)
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program):
iter = 0
......@@ -97,7 +102,7 @@ class TestQuantAwareCase1(StaticCase):
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
dirname='./test_quant_post',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
......@@ -114,7 +119,7 @@ class TestQuantAwareCase1(StaticCase):
model_filename='model',
params_filename='params',
batch_nums=10)
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model(
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference',
executor=exe,
model_filename='__model__',
......
......@@ -60,9 +60,14 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
valid_loader = paddle.io.DataLoader(
test_dataset, places=place, feed_list=[image, label], batch_size=64)
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def train(program):
iter = 0
......@@ -97,7 +102,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model(
paddle.fluid.io.save_inference_model(
dirname='./test_quant_post_dynamic',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
......@@ -112,7 +117,7 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
model_filename='model',
params_filename='params',
generate_test_model=True)
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model(
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference/test_model', executor=exe)
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册