未验证 提交 c0ebad1a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix image_shape in export_model (#3093)

* fix image_shape in export_model
上级 712e19f3
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
**PaddleDetection、 PaddlePaddle与PaddleSlim 版本关系:** **PaddleDetection、 PaddlePaddle与PaddleSlim 版本关系:**
| PaddleDetection版本 | PaddlePaddle版本 | PaddleSlim版本 | 备注 | | PaddleDetection版本 | PaddlePaddle版本 | PaddleSlim版本 | 备注 |
| :------------------: | :---------------: | :-------: |:---------------: | | :------------------: | :---------------: | :-------: |:---------------: |
| release/2.1 | >= 2.1.0 | 2.1 | -- | | release/2.1 | >= 2.1.0 | 2.1 | 量化模型导出依赖最新Paddle develop分支,可在[PaddlePaddle每日版本](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev)中下载安装 |
| release/2.0 | >= 2.0.1 | 2.0 | 量化依赖Paddle 2.1及PaddleSlim 2.1 | | release/2.0 | >= 2.0.1 | 2.0 | 量化依赖Paddle 2.1及PaddleSlim 2.1 |
...@@ -107,7 +107,7 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{ ...@@ -107,7 +107,7 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{
#### COCO上benchmark #### COCO上benchmark
| 模型 | 压缩策略 | GFLOPs | 模型体积(MB) | 输入尺寸 | 预测时延(SD855) | Box AP | 下载 | 模型配置文件 | 压缩算法配置文件 | | 模型 | 压缩策略 | GFLOPs | 模型体积(MB) | 输入尺寸 | 预测时延(SD855) | Box AP | 下载 | 模型配置文件 | 压缩算法配置文件 |
| :---------: | :-------: | :------------: |:-------------: | :------: | :-------------: | :------: | :-----------------------------------------------------: |:-------------: | :------: | | :---------: | :-------: | :------------: |:-------------: | :------: | :-------------: | :------: | :-----------------------------------------------------: |:-------------: | :------: |
| PP-YOLO-MobileNetV3_large | baseline | -- | 18.5 | 608 | 25.1ms | 24.3 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_mbv3_large_coco.yml) | - | | PP-YOLO-MobileNetV3_large | baseline | -- | 18.5 | 608 | 25.1ms | 23.2 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_mbv3_large_coco.yml) | - |
| PP-YOLO-MobileNetV3_large | 剪裁-FPGM | -37% | 12.6 | 608 | - | 22.3 | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ppyolo_mbv3_large_prune_fpgm.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_mbv3_large_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/prune/ppyolo_mbv3_large_prune_fpgm.yml) | | PP-YOLO-MobileNetV3_large | 剪裁-FPGM | -37% | 12.6 | 608 | - | 22.3 | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ppyolo_mbv3_large_prune_fpgm.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_mbv3_large_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/prune/ppyolo_mbv3_large_prune_fpgm.yml) |
| YOLOv3-DarkNet53 | baseline | -- | 238.2 | 608 | - | 39.0 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | - | | YOLOv3-DarkNet53 | baseline | -- | 238.2 | 608 | - | 39.0 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyolo_mbv3_large_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | - |
| YOLOv3-DarkNet53 | 剪裁-FPGM | -24% | - | 608 | - | 37.6 | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_darknet_prune_fpgm.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/prune/yolov3_darknet_prune_fpgm.yml) | | YOLOv3-DarkNet53 | 剪裁-FPGM | -24% | - | 608 | - | 37.6 | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_darknet_prune_fpgm.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/prune/yolov3_darknet_prune_fpgm.yml) |
......
...@@ -8,7 +8,9 @@ TensorRT是NVIDIA提出的用于统一模型部署的加速库,可以应用于 ...@@ -8,7 +8,9 @@ TensorRT是NVIDIA提出的用于统一模型部署的加速库,可以应用于
- 如果Python和CPP官网没有提供已编译好的安装包或预测库,请参考[源码安装](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html) 自行编译 - 如果Python和CPP官网没有提供已编译好的安装包或预测库,请参考[源码安装](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html) 自行编译
注意,您的机器上TensorRT的版本需要跟您使用的预测库中TensorRT版本保持一致。 **注意:**
- 您的机器上TensorRT的版本需要跟您使用的预测库中TensorRT版本保持一致。
- PaddleDetection中部署预测要求TensorRT版本 > 6.0。
## 2. 导出模型 ## 2. 导出模型
模型导出具体请参考文档[PaddleDetection模型导出教程](../EXPORT_MODEL.md) 模型导出具体请参考文档[PaddleDetection模型导出教程](../EXPORT_MODEL.md)
...@@ -31,7 +33,6 @@ config->EnableTensorRtEngine(1 << 20 /*workspace_size*/, ...@@ -31,7 +33,6 @@ config->EnableTensorRtEngine(1 << 20 /*workspace_size*/,
``` ```
### 3.2 TensorRT固定尺寸预测 ### 3.2 TensorRT固定尺寸预测
TensorRT版本<=5时,使用TensorRT预测时,只支持固定尺寸输入。
在导出模型时指定模型输入尺寸,设置`TestReader.inputs_def.image_shape=[3,640,640]`,具体请参考[PaddleDetection模型导出教程](../EXPORT_MODEL.md) 在导出模型时指定模型输入尺寸,设置`TestReader.inputs_def.image_shape=[3,640,640]`,具体请参考[PaddleDetection模型导出教程](../EXPORT_MODEL.md)
......
...@@ -91,13 +91,6 @@ class ConfigPaser { ...@@ -91,13 +91,6 @@ class ConfigPaser {
return false; return false;
} }
if (config["image_shape"].IsDefined()) {
image_shape_ = config["image_shape"].as<std::vector<int>>();
} else {
std::cerr << "Please set image_shape." << std::endl;
return false;
}
return true; return true;
} }
std::string mode_; std::string mode_;
...@@ -106,7 +99,6 @@ class ConfigPaser { ...@@ -106,7 +99,6 @@ class ConfigPaser {
int min_subgraph_size_; int min_subgraph_size_;
YAML::Node preprocess_info_; YAML::Node preprocess_info_;
std::vector<std::string> label_list_; std::vector<std::string> label_list_;
std::vector<int> image_shape_;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
......
...@@ -82,8 +82,7 @@ class ObjectDetector { ...@@ -82,8 +82,7 @@ class ObjectDetector {
config_.load_config(model_dir); config_.load_config(model_dir);
this->min_subgraph_size_ = config_.min_subgraph_size_; this->min_subgraph_size_ = config_.min_subgraph_size_;
threshold_ = config_.draw_threshold_; threshold_ = config_.draw_threshold_;
image_shape_ = config_.image_shape_; preprocessor_.Init(config_.preprocess_info_);
preprocessor_.Init(config_.preprocess_info_, image_shape_);
LoadModel(model_dir, batch_size, run_mode); LoadModel(model_dir, batch_size, run_mode);
} }
...@@ -134,7 +133,6 @@ class ObjectDetector { ...@@ -134,7 +133,6 @@ class ObjectDetector {
std::vector<int> out_bbox_num_data_; std::vector<int> out_bbox_num_data_;
float threshold_; float threshold_;
ConfigPaser config_; ConfigPaser config_;
std::vector<int> image_shape_;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
...@@ -48,19 +48,19 @@ class ImageBlob { ...@@ -48,19 +48,19 @@ class ImageBlob {
// Abstraction of preprocessing opration class // Abstraction of preprocessing opration class
class PreprocessOp { class PreprocessOp {
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) = 0; virtual void Init(const YAML::Node& item) = 0;
virtual void Run(cv::Mat* im, ImageBlob* data) = 0; virtual void Run(cv::Mat* im, ImageBlob* data) = 0;
}; };
class InitInfo : public PreprocessOp{ class InitInfo : public PreprocessOp{
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {} virtual void Init(const YAML::Node& item) {}
virtual void Run(cv::Mat* im, ImageBlob* data); virtual void Run(cv::Mat* im, ImageBlob* data);
}; };
class NormalizeImage : public PreprocessOp { class NormalizeImage : public PreprocessOp {
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) { virtual void Init(const YAML::Node& item) {
mean_ = item["mean"].as<std::vector<float>>(); mean_ = item["mean"].as<std::vector<float>>();
scale_ = item["std"].as<std::vector<float>>(); scale_ = item["std"].as<std::vector<float>>();
is_scale_ = item["is_scale"].as<bool>(); is_scale_ = item["is_scale"].as<bool>();
...@@ -77,21 +77,18 @@ class NormalizeImage : public PreprocessOp { ...@@ -77,21 +77,18 @@ class NormalizeImage : public PreprocessOp {
class Permute : public PreprocessOp { class Permute : public PreprocessOp {
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) {} virtual void Init(const YAML::Node& item) {}
virtual void Run(cv::Mat* im, ImageBlob* data); virtual void Run(cv::Mat* im, ImageBlob* data);
}; };
class Resize : public PreprocessOp { class Resize : public PreprocessOp {
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) { virtual void Init(const YAML::Node& item) {
interp_ = item["interp"].as<int>(); interp_ = item["interp"].as<int>();
//max_size_ = item["target_size"].as<int>(); //max_size_ = item["target_size"].as<int>();
keep_ratio_ = item["keep_ratio"].as<bool>(); keep_ratio_ = item["keep_ratio"].as<bool>();
target_size_ = item["target_size"].as<std::vector<int>>(); target_size_ = item["target_size"].as<std::vector<int>>();
if (item["keep_ratio"]) {
in_net_shape_ = image_shape;
}
} }
// Compute best resize scale for x-dimension, y-dimension // Compute best resize scale for x-dimension, y-dimension
...@@ -109,7 +106,7 @@ class Resize : public PreprocessOp { ...@@ -109,7 +106,7 @@ class Resize : public PreprocessOp {
// Models with FPN need input shape % stride == 0 // Models with FPN need input shape % stride == 0
class PadStride : public PreprocessOp { class PadStride : public PreprocessOp {
public: public:
virtual void Init(const YAML::Node& item, const std::vector<int> image_shape) { virtual void Init(const YAML::Node& item) {
stride_ = item["stride"].as<int>(); stride_ = item["stride"].as<int>();
} }
...@@ -121,14 +118,14 @@ class PadStride : public PreprocessOp { ...@@ -121,14 +118,14 @@ class PadStride : public PreprocessOp {
class Preprocessor { class Preprocessor {
public: public:
void Init(const YAML::Node& config_node, const std::vector<int> image_shape) { void Init(const YAML::Node& config_node) {
// initialize image info at first // initialize image info at first
ops_["InitInfo"] = std::make_shared<InitInfo>(); ops_["InitInfo"] = std::make_shared<InitInfo>();
for (const auto& item : config_node) { for (const auto& item : config_node) {
auto op_name = item["type"].as<std::string>(); auto op_name = item["type"].as<std::string>();
ops_[op_name] = CreateOp(op_name); ops_[op_name] = CreateOp(op_name);
ops_[op_name]->Init(item, image_shape); ops_[op_name]->Init(item);
} }
} }
......
...@@ -99,8 +99,7 @@ class Detector(object): ...@@ -99,8 +99,7 @@ class Detector(object):
input_im_lst = [] input_im_lst = []
input_im_info_lst = [] input_im_info_lst = []
for im_path in image_list: for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops, im, im_info = preprocess(im_path, preprocess_ops)
self.pred_config.input_shape)
input_im_lst.append(im) input_im_lst.append(im)
input_im_info_lst.append(im_info) input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst) inputs = create_inputs(input_im_lst, input_im_info_lst)
...@@ -141,12 +140,12 @@ class Detector(object): ...@@ -141,12 +140,12 @@ class Detector(object):
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list) inputs = self.preprocess(image_list)
self.det_times.preprocess_time_s.end()
np_boxes, np_masks = None, None np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
self.det_times.preprocess_time_s.end()
for i in range(warmup): for i in range(warmup):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
...@@ -236,14 +235,14 @@ class DetectorSOLOv2(Detector): ...@@ -236,14 +235,14 @@ class DetectorSOLOv2(Detector):
'cate_label': label of segm, shape:[N] 'cate_label': label of segm, shape:[N]
'cate_score': confidence score of segm, shape:[N] 'cate_score': confidence score of segm, shape:[N]
''' '''
self.det_times.postprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image) inputs = self.preprocess(image)
self.det_times.preprocess_time_s.end()
np_label, np_score, np_segms = None, None, None np_label, np_score, np_segms = None, None, None
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
self.det_times.postprocess_time_s.end()
for i in range(warmup): for i in range(warmup):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
...@@ -331,7 +330,6 @@ class PredictConfig(): ...@@ -331,7 +330,6 @@ class PredictConfig():
self.mask = False self.mask = False
if 'mask' in yml_conf: if 'mask' in yml_conf:
self.mask = yml_conf['mask'] self.mask = yml_conf['mask']
self.input_shape = yml_conf['image_shape']
self.print_config() self.print_config()
def check_model(self, yml_conf): def check_model(self, yml_conf):
......
...@@ -88,8 +88,7 @@ class KeyPoint_Detector(object): ...@@ -88,8 +88,7 @@ class KeyPoint_Detector(object):
new_op_info = op_info.copy() new_op_info = op_info.copy()
op_type = new_op_info.pop('type') op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info)) preprocess_ops.append(eval(op_type)(**new_op_info))
im, im_info = preprocess(im, preprocess_ops, im, im_info = preprocess(im, preprocess_ops)
self.pred_config.input_shape)
inputs = create_inputs(im, im_info) inputs = create_inputs(im, im_info)
return inputs return inputs
...@@ -213,7 +212,6 @@ class PredictConfig_KeyPoint(): ...@@ -213,7 +212,6 @@ class PredictConfig_KeyPoint():
self.tagmap = False self.tagmap = False
if 'keypoint_bottomup' == self.archcls: if 'keypoint_bottomup' == self.archcls:
self.tagmap = True self.tagmap = True
self.input_shape = yml_conf['image_shape']
self.print_config() self.print_config()
def check_model(self, yml_conf): def check_model(self, yml_conf):
......
...@@ -47,11 +47,7 @@ class Resize(object): ...@@ -47,11 +47,7 @@ class Resize(object):
interp (int): method of resize interp (int): method of resize
""" """
def __init__( def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR, ):
if isinstance(target_size, int): if isinstance(target_size, int):
target_size = [target_size, target_size] target_size = [target_size, target_size]
self.target_size = target_size self.target_size = target_size
...@@ -81,14 +77,6 @@ class Resize(object): ...@@ -81,14 +77,6 @@ class Resize(object):
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array( im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32') [im_scale_y, im_scale_x]).astype('float32')
# padding im when image_shape fixed by infer_cfg.yml
if self.keep_ratio and im_info['input_shape'][1] != -1:
max_size = im_info['input_shape'][1]
padding_im = np.zeros(
(max_size, max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im
im = padding_im
return im, im_info return im, im_info
def generate_scale(self, im): def generate_scale(self, im):
...@@ -205,13 +193,12 @@ class PadStride(object): ...@@ -205,13 +193,12 @@ class PadStride(object):
return padding_im, im_info return padding_im, im_info
def preprocess(im, preprocess_ops, input_shape): def preprocess(im, preprocess_ops):
# process image by preprocess_ops # process image by preprocess_ops
im_info = { im_info = {
'scale_factor': np.array( 'scale_factor': np.array(
[1., 1.], dtype=np.float32), [1., 1.], dtype=np.float32),
'im_shape': None, 'im_shape': None,
'input_shape': input_shape,
} }
im, im_info = decode_image(im, im_info) im, im_info = decode_image(im, im_info)
for operator in preprocess_ops: for operator in preprocess_ops:
......
...@@ -58,9 +58,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): ...@@ -58,9 +58,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
for key, value in st.items(): for key, value in st.items():
p = {'type': key} p = {'type': key}
if key == 'Resize': if key == 'Resize':
if value.get('keep_ratio', False) and int(image_shape[1]) != -1: if int(image_shape[1]) != -1:
max_size = max(image_shape[1:])
image_shape = [3, max_size, max_size]
value['target_size'] = image_shape[1:] value['target_size'] = image_shape[1:]
p.update(value) p.update(value)
preprocess_list.append(p) preprocess_list.append(p)
...@@ -76,7 +74,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): ...@@ -76,7 +74,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
}) })
break break
return preprocess_list, label_list, image_shape return preprocess_list, label_list
def _dump_infer_config(config, path, image_shape, model): def _dump_infer_config(config, path, image_shape, model):
...@@ -87,7 +85,6 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -87,7 +85,6 @@ def _dump_infer_config(config, path, image_shape, model):
'mode': 'fluid', 'mode': 'fluid',
'draw_threshold': 0.5, 'draw_threshold': 0.5,
'metric': config['metric'], 'metric': config['metric'],
'image_shape': image_shape
}) })
infer_arch = config['architecture'] infer_arch = config['architecture']
...@@ -107,8 +104,7 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -107,8 +104,7 @@ def _dump_infer_config(config, path, image_shape, model):
label_arch = 'detection_arch' label_arch = 'detection_arch'
if infer_arch in KEYPOINT_ARCH: if infer_arch in KEYPOINT_ARCH:
label_arch = 'keypoint_arch' label_arch = 'keypoint_arch'
infer_cfg['Preprocess'], infer_cfg[ infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
'label_list'], image_shape = _parse_reader(
config['TestReader'], config['TestDataset'], config['metric'], config['TestReader'], config['TestDataset'], config['metric'],
label_arch, image_shape) label_arch, image_shape)
...@@ -119,4 +115,3 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -119,4 +115,3 @@ def _dump_infer_config(config, path, image_shape, model):
yaml.dump(infer_cfg, open(path, 'w')) yaml.dump(infer_cfg, open(path, 'w'))
logger.info("Export inference config file to {}".format(os.path.join(path))) logger.info("Export inference config file to {}".format(os.path.join(path)))
return image_shape
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册