未验证 提交 6a161828 编写于 作者: C Chang Xu 提交者: GitHub

speedup eval in demo (#1159)

上级 a620089a
......@@ -26,14 +26,15 @@ add_arg('save_dir', str, None, "directory to save
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('data_dir', str, None, "path of dataset")
add_arg('input_name', str, "inputs", "input name of the model")
# yapf: enable
def reader_wrapper(reader):
def reader_wrapper(reader, input_name):
def gen():
for i, data in enumerate(reader()):
imgs = np.float32([item[0] for item in data])
yield {"inputs": imgs}
yield {input_name: imgs}
return gen
......@@ -45,16 +46,17 @@ def eval_reader(data_dir, batch_size):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = eval_reader(data_dir, batch_size=1)
val_reader = eval_reader(data_dir, batch_size=args.batch_size)
image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32')
name=args.input_name, shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
results = []
for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
image = data[0][0].reshape((1, 3, 224, 224))
image = np.array([[d[0]] for d in data])
image = image.reshape((len(data), 3, 224, 224))
label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
......@@ -73,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = data[0][0].reshape((1, 3, 224, 224))
image = np.array([[d[0]] for d in data])
image = image.reshape((len(data), 3, 224, 224))
label = [[d[1]] for d in data]
result = exe.run(
compiled_test_program,
......@@ -82,7 +85,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 5000 == 0:
if batch_id % 50 == 0:
print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0)
return result[0]
......@@ -97,7 +100,7 @@ if __name__ == '__main__':
train_reader = paddle.batch(
reader.train(data_dir=data_dir), batch_size=args.batch_size)
train_dataloader = reader_wrapper(train_reader)
train_dataloader = reader_wrapper(train_reader, args.input_name)
ac = AutoCompression(
model_dir=args.model_dir,
......@@ -108,6 +111,6 @@ if __name__ == '__main__':
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function,
eval_dataloader=reader_wrapper(eval_reader(data_dir, 64)))
eval_dataloader=reader_wrapper(eval_reader(data_dir, args.batch_size)), args.input_name)
ac.compress()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册