提交 271fbf44 编写于 作者: B baiyfbupt

replace datafeeder with dataloader

上级 fa44693a
......@@ -15,6 +15,7 @@ from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
sys.path.append('./')
from pact import *
quantization_model_save_dir = './quantization_models/'
......@@ -133,8 +134,8 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
if args.use_pact:
image.stop_gradient = False
......@@ -196,8 +197,7 @@ def compress(args):
if args.pretrained_model:
def if_exist(var):
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
return os.path.exists(os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
......@@ -205,20 +205,29 @@ def compress(args):
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_feeder = feeder = fluid.DataFeeder([image, label], place)
val_feeder = feeder = fluid.DataFeeder(
[image, label], place, program=val_program)
train_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
valid_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
train_loader.set_sample_list_generator(train_reader, places)
valid_loader.set_sample_list_generator(val_reader, place)
def test(epoch, program):
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
for data in val_reader():
for data in valid_loader():
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed=train_feeder.feed(data),
fetch_list=[acc_top1.name, acc_top5.name])
program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
......@@ -230,20 +239,19 @@ def compress(args):
acc_top5_ns.append(np.mean(acc_top5_n))
batch_id += 1
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
format(epoch,
np.mean(np.array(acc_top1_ns)),
np.mean(np.array(acc_top5_ns))))
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def train(epoch, compiled_train_prog):
batch_id = 0
for data in train_reader():
for data in train_loader():
start_time = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
compiled_train_prog,
feed=train_feeder.feed(data),
feed=data,
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
end_time = time.time()
loss_n = np.mean(loss_n)
......@@ -259,8 +267,8 @@ def compress(args):
threshold = {}
for var in val_program.list_vars():
if 'pact' in var.name:
array = np.array(fluid.global_scope().find_var(
var.name).get_tensor())
array = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
threshold[var.name] = array[0]
print(threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册