未验证 提交 6a161828 编写于 作者: C Chang Xu 提交者: GitHub

speedup eval in demo (#1159)

上级 a620089a
...@@ -26,14 +26,15 @@ add_arg('save_dir', str, None, "directory to save ...@@ -26,14 +26,15 @@ add_arg('save_dir', str, None, "directory to save
add_arg('batch_size', int, 1, "train batch size.") add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.") add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('data_dir', str, None, "path of dataset") add_arg('data_dir', str, None, "path of dataset")
add_arg('input_name', str, "inputs", "input name of the model")
# yapf: enable # yapf: enable
def reader_wrapper(reader): def reader_wrapper(reader, input_name):
def gen(): def gen():
for i, data in enumerate(reader()): for i, data in enumerate(reader()):
imgs = np.float32([item[0] for item in data]) imgs = np.float32([item[0] for item in data])
yield {"inputs": imgs} yield {input_name: imgs}
return gen return gen
...@@ -45,16 +46,17 @@ def eval_reader(data_dir, batch_size): ...@@ -45,16 +46,17 @@ def eval_reader(data_dir, batch_size):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = eval_reader(data_dir, batch_size=1) val_reader = eval_reader(data_dir, batch_size=args.batch_size)
image = paddle.static.data( image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32') name=args.input_name, shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
results = [] results = []
for batch_id, data in enumerate(val_reader()): for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc # top1_acc, top5_acc
if len(test_feed_names) == 1: if len(test_feed_names) == 1:
image = data[0][0].reshape((1, 3, 224, 224)) image = np.array([[d[0]] for d in data])
image = image.reshape((len(data), 3, 224, 224))
label = [[d[1]] for d in data] label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program, pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image}, feed={test_feed_names[0]: image},
...@@ -73,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -73,7 +75,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
results.append([top_1, top_5]) results.append([top_1, top_5])
else: else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = data[0][0].reshape((1, 3, 224, 224)) image = np.array([[d[0]] for d in data])
image = image.reshape((len(data), 3, 224, 224))
label = [[d[1]] for d in data] label = [[d[1]] for d in data]
result = exe.run( result = exe.run(
compiled_test_program, compiled_test_program,
...@@ -82,7 +85,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -82,7 +85,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list) fetch_list=test_fetch_list)
result = [np.mean(r) for r in result] result = [np.mean(r) for r in result]
results.append(result) results.append(result)
if batch_id % 5000 == 0: if batch_id % 50 == 0:
print('Eval iter: ', batch_id) print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0) result = np.mean(np.array(results), axis=0)
return result[0] return result[0]
...@@ -97,7 +100,7 @@ if __name__ == '__main__': ...@@ -97,7 +100,7 @@ if __name__ == '__main__':
train_reader = paddle.batch( train_reader = paddle.batch(
reader.train(data_dir=data_dir), batch_size=args.batch_size) reader.train(data_dir=data_dir), batch_size=args.batch_size)
train_dataloader = reader_wrapper(train_reader) train_dataloader = reader_wrapper(train_reader, args.input_name)
ac = AutoCompression( ac = AutoCompression(
model_dir=args.model_dir, model_dir=args.model_dir,
...@@ -108,6 +111,6 @@ if __name__ == '__main__': ...@@ -108,6 +111,6 @@ if __name__ == '__main__':
train_config=train_config, train_config=train_config,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_callback=eval_function, eval_callback=eval_function,
eval_dataloader=reader_wrapper(eval_reader(data_dir, 64))) eval_dataloader=reader_wrapper(eval_reader(data_dir, args.batch_size)), args.input_name)
ac.compress() ac.compress()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册