Unverified commit 12109399, authored by Chang Xu, committed by GitHub

Update ImageNet Infer in ACT (#1408)

Parent 50ec7275
@@ -128,7 +128,7 @@ python eval.py --config_path='./configs/MobileNetV1/qat_dis.yaml'
Environment setup: to use the TensorRT inference engine, install a Paddle build compiled with ```WITH_TRT=ON```; download it from: [Python inference library](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python)
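As a quick sanity check of the installed build (a hedged sketch: `run_check` only verifies that the base install runs on the available device; TensorRT support still depends on the ```WITH_TRT=ON``` build above):

```python
# Hedged sanity check: confirms the installed Paddle can run a small
# program on the available device. Does not verify TensorRT itself.
import paddle
print(paddle.version.full_version)
paddle.utils.run_check()
```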
-Configuration file: the following fields in ```configs/infer.yaml``` are used to configure prediction parameters:
+The following fields are used to configure prediction parameters:
- ```inference_model_dir```: directory containing the inference model; it must contain both a .pdmodel file and a .pdiparams file
- ```model_filename```: name of the model file under the inference_model_dir directory
- ```params_filename```: name of the parameters file under the inference_model_dir directory
@@ -148,7 +148,13 @@ python eval.py --config_path='./configs/MobileNetV1/qat_dis.yaml'
After preparing the inference model, run prediction with the following command:
```shell
-python infer.py --config_path="configs/infer.yaml"
+python infer.py --model_dir='MobileNetV1_infer' \
+                --model_filename='inference.pdmodel' \
+                --params_filename='inference.pdiparams' \
+                --eval=True \
+                --use_gpu=True \
+                --enable_mkldnn=True \
+                --use_int8=True
```
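When the benchmark run finishes, `infer.py` prints a one-line latency summary; the format comes from the script's own `print` call, but the timing figure below is illustrative only:

```
using tensorrt	INT8	batch size: 1	time(ms): 1.52
```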
#### 4.2 PaddleLite on-device deployment
configs/infer.yaml
model_dir: "./MobileNetV1_infer"
model_filename: "inference.pdmodel"
params_filename: "inference.pdiparams"
batch_size: 1
image_size: 224
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
use_int8: False
ir_optim: True
use_tensorrt: True
gpu_mem: 8000
enable_profile: False
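For reference, the previous entry point consumed this file through PaddleSlim's YAML helper. A minimal sketch, assuming `load_config` returns a plain dict (as its subscripting in the old `__main__` block below suggests):

```python
# Sketch: load configs/infer.yaml the way the old config-driven flow did.
from paddleslim.common import load_config

config = load_config('configs/infer.yaml')
print(config['model_dir'], config['use_int8'])  # ./MobileNetV1_infer False
```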
infer.py
@@ -19,38 +19,89 @@ import time
import sys
import argparse
import yaml
from tqdm import tqdm
from utils import preprocess, postprocess
import paddle
from paddle.inference import create_predictor
from paddleslim.common import load_config
from paddle.io import DataLoader
from imagenet_reader import ImageNetDataset, process_image
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default='./image_classification/configs/infer.yaml',
help='config file path')
parser.add_argument(
'--model_dir',
type=str,
default='./MobileNetV1_infer',
help='model directory')
parser.add_argument(
'--model_filename',
type=str,
default='inference.pdmodel',
help='model file name')
parser.add_argument(
'--params_filename',
type=str,
default='inference.pdiparams',
help='params file name')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--img_size', type=int, default=224)
parser.add_argument('--resize_size', type=int, default=256)
parser.add_argument(
'--eval', type=bool, default=False, help='Whether to evaluate')
parser.add_argument('--data_path', type=str, default='./ILSVRC2012/')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='Whether to use gpu')
parser.add_argument(
'--enable_mkldnn',
type=bool,
default=False,
help='Whether to use mkldnn')
parser.add_argument(
'--cpu_num_threads', type=int, default=10, help='Number of cpu threads')
parser.add_argument(
'--use_fp16', type=bool, default=False, help='Whether to use fp16')
parser.add_argument(
'--use_int8', type=bool, default=False, help='Whether to use int8')
parser.add_argument(
'--use_tensorrt',
type=bool,
default=True,
help='Whether to use tensorrt')
parser.add_argument(
'--enable_profile',
type=bool,
default=False,
help='Whether to enable profile')
parser.add_argument('--gpu_mem', type=int, default=8000, help='GPU memory')
parser.add_argument('--ir_optim', type=bool, default=True)
return parser
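# Editor's note (hedged): argparse's type=bool calls bool() on the raw
# argument string, and any non-empty string -- including "False" -- is
# truthy. So --use_tensorrt=False does NOT disable TensorRT here; only the
# defaults above can keep a flag off. A hypothetical str2bool converter
# (not part of this file) avoids the pitfall:
#     def str2bool(v):
#         return v if isinstance(v, bool) else v.lower() in ('true', '1')
#     parser.add_argument('--use_gpu', type=str2bool, default=False)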
def eval_reader(data_dir, batch_size, crop_size, resize_size):
val_reader = ImageNetDataset(
mode='val',
data_dir=data_dir,
crop_size=crop_size,
resize_size=resize_size)
val_loader = DataLoader(
val_reader,
        batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=0)
return val_loader
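# Usage sketch (assumes the ./ILSVRC2012/ directory layout that
# imagenet_reader.ImageNetDataset expects):
#     val_loader = eval_reader('./ILSVRC2012/', batch_size=1,
#                              crop_size=224, resize_size=256)
#     image, label = next(iter(val_loader))  # image: [1, 3, 224, 224]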
class Predictor(object):
-    def __init__(self, config):
+    def __init__(self, args):
        # HALF precision predict only works when using TensorRT
-        if config['use_fp16'] is True:
-            assert config['use_tensorrt'] is True
-        self.config = config
+        if args.use_fp16 is True:
+            assert args.use_tensorrt is True
+        self.args = args
self.paddle_predictor = self.create_paddle_predictor()
input_names = self.paddle_predictor.get_input_names()
@@ -62,36 +113,34 @@ class Predictor(object):
output_names[0])
def create_paddle_predictor(self):
-        inference_model_dir = self.config['model_dir']
-        model_file = os.path.join(inference_model_dir,
-                                  self.config['model_filename'])
+        inference_model_dir = self.args.model_dir
+        model_file = os.path.join(inference_model_dir, self.args.model_filename)
         params_file = os.path.join(inference_model_dir,
-                                   self.config['params_filename'])
+                                   self.args.params_filename)
config = paddle.inference.Config(model_file, params_file)
precision = paddle.inference.Config.Precision.Float32
-        if self.config['use_int8']:
+        if self.args.use_int8:
             precision = paddle.inference.Config.Precision.Int8
-        elif self.config['use_fp16']:
+        elif self.args.use_fp16:
             precision = paddle.inference.Config.Precision.Half
-        if self.config['use_gpu']:
-            config.enable_use_gpu(self.config['gpu_mem'], 0)
+        if self.args.use_gpu:
+            config.enable_use_gpu(self.args.gpu_mem, 0)
else:
config.disable_gpu()
-        if self.config['enable_mkldnn']:
+        if self.args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
-        config.set_cpu_math_library_num_threads(self.config['cpu_num_threads'])
+        config.set_cpu_math_library_num_threads(self.args.cpu_num_threads)
-        if self.config['enable_profile']:
+        if self.args.enable_profile:
config.enable_profile()
config.disable_glog_info()
-        config.switch_ir_optim(self.config['ir_optim'])  # default true
-        if self.config['use_tensorrt']:
+        config.switch_ir_optim(self.args.ir_optim)  # default true
+        if self.args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=precision,
-                max_batch_size=self.config['batch_size'],
+                max_batch_size=self.args.batch_size,
workspace_size=1 << 30,
min_subgraph_size=30,
use_calib_mode=False)
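            # Editor's note (hedged): workspace_size=1 << 30 caps the
            # TensorRT workspace at 1 GiB, min_subgraph_size=30 keeps TRT
            # from claiming very small subgraphs, and use_calib_mode=False
            # skips TRT's own INT8 calibration (the quantized model already
            # carries its scale information).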
@@ -107,9 +156,8 @@ class Predictor(object):
test_num = 1000
test_time = 0.0
for i in range(0, test_num + 10):
-            inputs = np.random.rand(config['batch_size'], 3,
-                                    config['image_size'],
-                                    config['image_size']).astype(np.float32)
+            inputs = np.random.rand(self.args.batch_size, 3, self.args.img_size,
+                                    self.args.img_size).astype(np.float32)
start_time = time.time()
self.input_tensor.copy_from_cpu(inputs)
self.paddle_predictor.run()
@@ -118,24 +166,66 @@ class Predictor(object):
test_time += time.time() - start_time
time.sleep(0.01) # sleep for T4 GPU
fp_message = "FP16" if config['use_fp16'] else "FP32"
fp_message = "INT8" if config['use_int8'] else fp_message
trt_msg = "using tensorrt" if config[
'use_tensorrt'] else "not using tensorrt"
fp_message = "FP16" if self.args.use_fp16 else "FP32"
fp_message = "INT8" if self.args.use_int8 else fp_message
trt_msg = "using tensorrt" if self.args.use_tensorrt else "not using tensorrt"
print("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, config[
'batch_size'], 1000 * test_time / test_num))
trt_msg, fp_message, args.batch_size, 1000 * test_time / test_num))
def eval(self):
if os.path.exists(self.args.data_path):
val_loader = eval_reader(
self.args.data_path,
batch_size=self.args.batch_size,
crop_size=self.args.img_size,
resize_size=self.args.resize_size)
else:
image = np.ones((1, 3, self.args.img_size,
self.args.img_size)).astype(np.float32)
label = None
val_loader = [[image, label]]
results = []
with tqdm(
total=len(val_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for batch_id, (image, label) in enumerate(val_loader):
input_names = self.paddle_predictor.get_input_names()
input_tensor = self.paddle_predictor.get_input_handle(
input_names[0])
output_names = self.paddle_predictor.get_output_names()
output_tensor = self.paddle_predictor.get_output_handle(
output_names[0])
image = np.array(image)
input_tensor.copy_from_cpu(image)
self.paddle_predictor.run()
batch_output = output_tensor.copy_to_cpu()
sort_array = batch_output.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
if label is None:
results.append(top_1_pred)
break
label = np.array(label)
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
                results.append([top_1, top_5])
                t.update()
result = np.mean(np.array(results), axis=0)
print('Evaluation result: {}'.format(result[0]))
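# Worked example of the top-k logic in eval() (editor's sketch): for logits
# [[0.1, 0.7, 0.2]], argsort(axis=1) gives [[0, 2, 1]]; the last column
# reversed is the top-1 prediction [[1]], and the last five columns reversed
# give the top-5 predictions ordered from most to least confident.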
if __name__ == "__main__":
parser = argsparser()
global args
args = parser.parse_args()
-    config = load_config(args.config_path)
-    if args.model_dir != config['model_dir']:
-        config['model_dir'] = args.model_dir
-    if args.use_fp16 != config['use_fp16']:
-        config['use_fp16'] = args.use_fp16
-    if args.use_int8 != config['use_int8']:
-        config['use_int8'] = args.use_int8
-    predictor = Predictor(config)
+    predictor = Predictor(args)
predictor.predict()
if args.eval:
predictor.eval()