提交 6e3062be 编写于 作者: C Channingss

optimize code structure for TensorRT(deploy)

上级 b55a5ea5
...@@ -114,8 +114,6 @@ if (NOT WIN32) ...@@ -114,8 +114,6 @@ if (NOT WIN32)
if (WITH_TENSORRT AND WITH_GPU) if (WITH_TENSORRT AND WITH_GPU)
include_directories("${TENSORRT_DIR}/include") include_directories("${TENSORRT_DIR}/include")
link_directories("${TENSORRT_DIR}/lib") link_directories("${TENSORRT_DIR}/lib")
#include_directories("${PADDLE_DIR}/third_party/install/tensorrt/include")
#link_directories("${PADDLE_DIR}/third_party/install/tensorrt/lib")
endif() endif()
endif(NOT WIN32) endif(NOT WIN32)
...@@ -172,7 +170,7 @@ endif() ...@@ -172,7 +170,7 @@ endif()
if (NOT WIN32) if (NOT WIN32)
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf z xxhash yaml-cpp glog gflags protobuf z xxhash yaml-cpp
) )
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib") if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
...@@ -199,8 +197,6 @@ if(WITH_GPU) ...@@ -199,8 +197,6 @@ if(WITH_GPU)
if (WITH_TENSORRT) if (WITH_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
#set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
#set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
endif() endif()
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
...@@ -216,7 +212,7 @@ if (NOT WIN32) ...@@ -216,7 +212,7 @@ if (NOT WIN32)
set(DEPS ${DEPS} ${EXTERNAL_LIB}) set(DEPS ${DEPS} ${EXTERNAL_LIB})
endif() endif()
set(DEPS ${DEPS} ${OpenCV_LIBS}) set(DEPS ${DEPS} ${OpenCV_LIBS})
add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp) add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp) ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS}) target_link_libraries(classifier ${DEPS})
...@@ -256,4 +252,3 @@ if (WIN32 AND WITH_MKL) ...@@ -256,4 +252,3 @@ if (WIN32 AND WITH_MKL)
) )
endif() endif()
...@@ -152,18 +152,19 @@ class Padding : public Transform { ...@@ -152,18 +152,19 @@ class Padding : public Transform {
virtual void Init(const YAML::Node& item) { virtual void Init(const YAML::Node& item) {
if (item["coarsest_stride"].IsDefined()) { if (item["coarsest_stride"].IsDefined()) {
coarsest_stride_ = item["coarsest_stride"].as<int>(); coarsest_stride_ = item["coarsest_stride"].as<int>();
if (coarsest_stride_ <= 1) { if (coarsest_stride_ < 1) {
std::cerr << "[Padding] coarest_stride should greater than 0" std::cerr << "[Padding] coarest_stride should greater than 0"
<< std::endl; << std::endl;
exit(-1); exit(-1);
} }
} else { }
if (item["target_size"].IsDefined()){
if (item["target_size"].IsScalar()) { if (item["target_size"].IsScalar()) {
width_ = item["target_size"].as<int>(); width_ = item["target_size"].as<int>();
height_ = item["target_size"].as<int>(); height_ = item["target_size"].as<int>();
} else if (item["target_size"].IsSequence()) { } else if (item["target_size"].IsSequence()) {
width_ = item["target_size"].as<std::vector<int>>()[1]; width_ = item["target_size"].as<std::vector<int>>()[0];
height_ = item["target_size"].as<std::vector<int>>()[0]; height_ = item["target_size"].as<std::vector<int>>()[1];
} }
} }
if (item["im_padding_value"].IsDefined()) { if (item["im_padding_value"].IsDefined()) {
......
...@@ -6,6 +6,9 @@ WITH_TENSORRT=OFF ...@@ -6,6 +6,9 @@ WITH_TENSORRT=OFF
TENSORRT_DIR=/path/to/TensorRT/ TENSORRT_DIR=/path/to/TensorRT/
# Paddle 预测库路径 # Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/ PADDLE_DIR=/path/to/fluid_inference/
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=ON
# CUDA 的 lib 路径 # CUDA 的 lib 路径
CUDA_LIB=/path/to/cuda/lib/ CUDA_LIB=/path/to/cuda/lib/
# CUDNN 的 lib 路径 # CUDNN 的 lib 路径
...@@ -24,6 +27,7 @@ cmake .. \ ...@@ -24,6 +27,7 @@ cmake .. \
-DWITH_TENSORRT=${WITH_TENSORRT} \ -DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \ -DTENSORRT_DIR=${TENSORRT_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \ -DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \ -DCUDA_LIB=${CUDA_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \ -DCUDNN_LIB=${CUDNN_LIB} \
-DOPENCV_DIR=${OPENCV_DIR} -DOPENCV_DIR=${OPENCV_DIR}
......
...@@ -69,7 +69,7 @@ int main(int argc, char** argv) { ...@@ -69,7 +69,7 @@ int main(int argc, char** argv) {
<< result.boxes[i].coordinate[0] << ", " << result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", " << result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << std::endl; << result.boxes[i].coordinate[3] << ")" << std::endl;
} }
// 可视化 // 可视化
...@@ -92,7 +92,7 @@ int main(int argc, char** argv) { ...@@ -92,7 +92,7 @@ int main(int argc, char** argv) {
<< result.boxes[i].coordinate[0] << ", " << result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", " << result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << std::endl; << result.boxes[i].coordinate[3] << ")" << std::endl;
} }
// 可视化 // 可视化
......
...@@ -39,11 +39,11 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -39,11 +39,11 @@ void Model::create_predictor(const std::string& model_dir,
// 开启内存优化 // 开启内存优化
config.EnableMemoryOptim(); config.EnableMemoryOptim();
if (use_trt){ if (use_trt){
config.EnableTensorRtEngine(1 << 20 /* workspace_size*/, config.EnableTensorRtEngine(1 << 20 /* workspace_size*/,
32 /* max_batch_size*/, 32 /* max_batch_size*/,
20 /* min_subgraph_size*/, 20 /* min_subgraph_size*/,
paddle::AnalysisConfig::Precision::kFloat32 /* precision*/, paddle::AnalysisConfig::Precision::kFloat32 /* precision*/,
false /* use_static*/, true /* use_static*/,
false /* use_calib_mode*/); false /* use_calib_mode*/);
} }
predictor_ = std::move(CreatePaddlePredictor(config)); predictor_ = std::move(CreatePaddlePredictor(config));
......
...@@ -92,15 +92,16 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) { ...@@ -92,15 +92,16 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
int padding_w = 0; int padding_w = 0;
int padding_h = 0; int padding_h = 0;
if (width_ > 0 & height_ > 0) { if (width_ > 1 & height_ > 1) {
padding_w = width_ - im->cols; padding_w = width_ - im->cols;
padding_h = height_ - im->rows; padding_h = height_ - im->rows;
} else if (coarsest_stride_ > 0) { } else if (coarsest_stride_ > 1) {
padding_h = padding_h =
ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows; ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
padding_w = padding_w =
ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols; ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
} }
if (padding_h < 0 || padding_w < 0) { if (padding_h < 0 || padding_w < 0) {
std::cerr << "[Padding] Computed padding_h=" << padding_h std::cerr << "[Padding] Computed padding_h=" << padding_h
<< ", padding_w=" << padding_w << ", padding_w=" << padding_w
......
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model
``` ```
使用TensorRT预测时,需指定模型的图像输入shape:[w,h]。需要注意的是,分类模型请保持与训练时输入的shape一致。
```
paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model --fixed_input_shape=[640,960]
```
### Python部署 ### Python部署
PaddleX已经集成了基于Python的高性能预测接口,在安装PaddleX后,可参照如下代码示例,进行预测。相关的接口文档可参考[paddlex.deploy](apis/deploy.md) PaddleX已经集成了基于Python的高性能预测接口,在安装PaddleX后,可参照如下代码示例,进行预测。相关的接口文档可参考[paddlex.deploy](apis/deploy.md)
> 点击下载测试图片 [garbage.bmp](https://bj.bcebos.com/paddlex/datasets/garbage.bmp) > 点击下载测试图片 [garbage.bmp](https://bj.bcebos.com/paddlex/datasets/garbage.bmp)
......
...@@ -39,18 +39,24 @@ fluid_inference ...@@ -39,18 +39,24 @@ fluid_inference
编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下: 编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下:
``` ```
# 是否使用GPU(即是否使用 CUDA) # 是否使用GPU(即是否使用 CUDA)
WITH_GPU=ON WITH_GPU=OFF
# 是否集成 TensorRT(仅WITH_GPU=ON 有效) # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
WITH_TENSORRT=OFF WITH_TENSORRT=OFF
# 上一步下载的 Paddle 预测库路径 # TensorRT 的lib路径
PADDLE_DIR=/root/projects/deps/fluid_inference/ TENSORRT_DIR=/path/to/TensorRT/
# Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=ON
# CUDA 的 lib 路径 # CUDA 的 lib 路径
CUDA_LIB=/usr/local/cuda/lib64/ CUDA_LIB=/path/to/cuda/lib/
# CUDNN 的 lib 路径 # CUDNN 的 lib 路径
CUDNN_LIB=/usr/local/cudnn/lib64/ CUDNN_LIB=/path/to/cudnn/lib/
# OPENCV 路径, 如果使用自带预编译版本可不设置 # OPENCV 路径, 如果使用自带预编译版本可不修改
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/ OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
sh $(pwd)/scripts/bootstrap.sh sh $(pwd)/scripts/bootstrap.sh
...@@ -61,7 +67,9 @@ cd build ...@@ -61,7 +67,9 @@ cd build
cmake .. \ cmake .. \
-DWITH_GPU=${WITH_GPU} \ -DWITH_GPU=${WITH_GPU} \
-DWITH_TENSORRT=${WITH_TENSORRT} \ -DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \ -DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \ -DCUDA_LIB=${CUDA_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \ -DCUDNN_LIB=${CUDNN_LIB} \
-DOPENCV_DIR=${OPENCV_DIR} -DOPENCV_DIR=${OPENCV_DIR}
...@@ -83,6 +91,7 @@ make ...@@ -83,6 +91,7 @@ make
| image | 要预测的图片文件路径 | | image | 要预测的图片文件路径 |
| image_list | 按行存储图片路径的.txt文件 | | image_list | 按行存储图片路径的.txt文件 |
| use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0) | | use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0) |
| use_trt | 是否使用 TensorRT 预测, 支持值为0或1(默认值为0) |
| gpu_id | GPU 设备ID, 默认值为0 | | gpu_id | GPU 设备ID, 默认值为0 |
| save_dir | 保存可视化结果的路径, 默认值为"output",classifier无该参数 | | save_dir | 保存可视化结果的路径, 默认值为"output",classifier无该参数 |
...@@ -113,4 +122,3 @@ make ...@@ -113,4 +122,3 @@ make
./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output ./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output
``` ```
图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
...@@ -33,7 +33,7 @@ def arg_parser(): ...@@ -33,7 +33,7 @@ def arg_parser():
"--fixed_input_shape", "--fixed_input_shape",
"-fs", "-fs",
default=None, default=None,
help="export inference model with fixed input shape(TensorRT need)") help="export inference model with fixed input shape:[w,h]")
return parser return parser
...@@ -58,10 +58,11 @@ def main(): ...@@ -58,10 +58,11 @@ def main():
assert args.model_dir is not None, "--model_dir should be defined while exporting inference model" assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
assert args.save_dir is not None, "--save_dir should be defined to save inference model" assert args.save_dir is not None, "--save_dir should be defined to save inference model"
fixed_input_shape = eval(args.fixed_input_shape) fixed_input_shape = eval(args.fixed_input_shape)
assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2" assert len(
fixed_input_shape) == 2, "len of fixed input shape must == 2"
model = pdx.load_model(args.model_dir, fixed_input_shape) model = pdx.load_model(args.model_dir, fixed_input_shape)
model.export_inference_model(args.save_dir, fixed_input_shape) model.export_inference_model(args.save_dir)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -316,25 +316,6 @@ class BaseAPI: ...@@ -316,25 +316,6 @@ class BaseAPI:
model_info['_ModelInputsOutputs']['test_outputs'] = [ model_info['_ModelInputsOutputs']['test_outputs'] = [
[k, v.name] for k, v in self.test_outputs.items() [k, v.name] for k, v in self.test_outputs.items()
] ]
resize = {'ResizeByShort': {}}
padding = {'Padding':{}}
if model_info['_Attributes']['model_type'] == 'classifier':
crop_size = 0
for transform in model_info['Transforms']:
if 'CenterCrop' in transform:
crop_size = transform['CenterCrop']['crop_size']
break
assert crop_size == fixed_input_shape[0], "fixed_input_shape must == CenterCrop:crop_size:{}".format(crop_size)
assert crop_size == fixed_input_shape[1], "fixed_input_shape must == CenterCrop:crop_size:{}".format(crop_size)
if crop_size == 0:
logging.warning("fixed_input_shape must == input shape when trainning")
else:
resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
padding['Padding']['target_size'] = list(fixed_input_shape)
model_info['Transforms'].append(resize)
model_info['Transforms'].append(padding)
with open( with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8', osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f: mode='w') as f:
......
...@@ -35,10 +35,9 @@ class BaseClassifier(BaseAPI): ...@@ -35,10 +35,9 @@ class BaseClassifier(BaseAPI):
'MobileNetV1', 'MobileNetV2', 'Xception41', 'MobileNetV1', 'MobileNetV2', 'Xception41',
'Xception65', 'Xception71']。默认为'ResNet50'。 'Xception65', 'Xception71']。默认为'ResNet50'。
num_classes (int): 类别数。默认为1000。 num_classes (int): 类别数。默认为1000。
fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
""" """
def __init__(self, model_name='ResNet50', num_classes=1000, fixed_input_shape=None): def __init__(self, model_name='ResNet50', num_classes=1000):
self.init_params = locals() self.init_params = locals()
super(BaseClassifier, self).__init__('classifier') super(BaseClassifier, self).__init__('classifier')
if not hasattr(paddlex.cv.nets, str.lower(model_name)): if not hasattr(paddlex.cv.nets, str.lower(model_name)):
...@@ -47,11 +46,13 @@ class BaseClassifier(BaseAPI): ...@@ -47,11 +46,13 @@ class BaseClassifier(BaseAPI):
self.model_name = model_name self.model_name = model_name
self.labels = None self.labels = None
self.num_classes = num_classes self.num_classes = num_classes
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = None
def build_net(self, mode='train'): def build_net(self, mode='train'):
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
image = fluid.data( image = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI): ...@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1, 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1,
即平时使用的交叉熵损失函数。 即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。 ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
Raises: Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。 ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25', ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
...@@ -69,8 +68,7 @@ class DeepLabv3p(BaseAPI): ...@@ -69,8 +68,7 @@ class DeepLabv3p(BaseAPI):
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255, ignore_index=255):
fixed_input_shape=None):
self.init_params = locals() self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter') super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中 # dice_loss或bce_loss只适用两类分割中
...@@ -119,7 +117,7 @@ class DeepLabv3p(BaseAPI): ...@@ -119,7 +117,7 @@ class DeepLabv3p(BaseAPI):
self.enable_decoder = enable_decoder self.enable_decoder = enable_decoder
self.labels = None self.labels = None
self.sync_bn = True self.sync_bn = True
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = None
def _get_backbone(self, backbone): def _get_backbone(self, backbone):
def mobilenetv2(backbone): def mobilenetv2(backbone):
...@@ -185,7 +183,7 @@ class DeepLabv3p(BaseAPI): ...@@ -185,7 +183,7 @@ class DeepLabv3p(BaseAPI):
use_dice_loss=self.use_dice_loss, use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight, class_weight=self.class_weight,
ignore_index=self.ignore_index, ignore_index=self.ignore_index,
fixed_input_shape = self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict() outputs = OrderedDict()
......
...@@ -44,8 +44,7 @@ class FasterRCNN(BaseAPI): ...@@ -44,8 +44,7 @@ class FasterRCNN(BaseAPI):
backbone='ResNet50', backbone='ResNet50',
with_fpn=True, with_fpn=True,
aspect_ratios=[0.5, 1.0, 2.0], aspect_ratios=[0.5, 1.0, 2.0],
anchor_sizes=[32, 64, 128, 256, 512], anchor_sizes=[32, 64, 128, 256, 512]):
fixed_input_shape=None):
self.init_params = locals() self.init_params = locals()
super(FasterRCNN, self).__init__('detector') super(FasterRCNN, self).__init__('detector')
backbones = [ backbones = [
...@@ -59,7 +58,7 @@ class FasterRCNN(BaseAPI): ...@@ -59,7 +58,7 @@ class FasterRCNN(BaseAPI):
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.anchor_sizes = anchor_sizes self.anchor_sizes = anchor_sizes
self.labels = None self.labels = None
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = None
def _get_backbone(self, backbone_name): def _get_backbone(self, backbone_name):
norm_type = None norm_type = None
...@@ -113,7 +112,7 @@ class FasterRCNN(BaseAPI): ...@@ -113,7 +112,7 @@ class FasterRCNN(BaseAPI):
anchor_sizes=self.anchor_sizes, anchor_sizes=self.anchor_sizes,
train_pre_nms_top_n=train_pre_nms_top_n, train_pre_nms_top_n=train_pre_nms_top_n,
test_pre_nms_top_n=test_pre_nms_top_n, test_pre_nms_top_n=test_pre_nms_top_n,
fixed_input_shape = self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
if mode == 'train': if mode == 'train':
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
......
...@@ -39,13 +39,12 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -39,13 +39,12 @@ def load_model(model_dir, fixed_input_shape=None):
raise Exception("There's no attribute {} in paddlex.cv.models".format( raise Exception("There's no attribute {} in paddlex.cv.models".format(
info['Model'])) info['Model']))
info['_init_params']['fixed_input_shape'] = fixed_input_shape
if info['_Attributes']['model_type'] == 'classifier': if info['_Attributes']['model_type'] == 'classifier':
model = paddlex.cv.models.BaseClassifier(**info['_init_params']) model = paddlex.cv.models.BaseClassifier(**info['_init_params'])
else: else:
model = getattr(paddlex.cv.models, model = getattr(paddlex.cv.models,
info['Model'])(**info['_init_params']) info['Model'])(**info['_init_params'])
model.fixed_input_shape = fixed_input_shape
if status == "Normal" or \ if status == "Normal" or \
status == "Prune" or status == "fluid.save": status == "Prune" or status == "fluid.save":
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -80,6 +79,8 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -80,6 +79,8 @@ def load_model(model_dir, fixed_input_shape=None):
model.test_outputs[var_desc[0]] = out model.test_outputs[var_desc[0]] = out
if 'Transforms' in info: if 'Transforms' in info:
transforms_mode = info.get('TransformsMode', 'RGB') transforms_mode = info.get('TransformsMode', 'RGB')
# 固定模型的输入shape
fix_input_shape(info, fixed_input_shape=fixed_input_shape)
if transforms_mode == 'RGB': if transforms_mode == 'RGB':
to_rgb = True to_rgb = True
else: else:
...@@ -104,6 +105,34 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -104,6 +105,34 @@ def load_model(model_dir, fixed_input_shape=None):
return model return model
def fix_input_shape(info, fixed_input_shape=None):
    """Adjust the transforms recorded in ``info`` for a fixed input shape.

    Args:
        info (dict): Model description. This function reads
            ``info['_Attributes']['model_type']`` and ``info['Transforms']``
            and, for non-classifier models, appends to ``info['Transforms']``
            in place.
        fixed_input_shape (list|None): Two-element input shape. When None,
            ``info`` is left untouched.

    Raises:
        AssertionError: For a classifier whose ``CenterCrop`` crop_size
            differs from either element of ``fixed_input_shape``.
    """
    if fixed_input_shape is None:
        return

    if info['_Attributes']['model_type'] == 'classifier':
        # Classifiers get no extra transforms; only check that the requested
        # shape agrees with the first CenterCrop found in the pipeline.
        crop_size = 0
        for transform in info['Transforms']:
            if 'CenterCrop' in transform:
                crop_size = transform['CenterCrop']['crop_size']
                break
        if crop_size == 0:
            # No usable CenterCrop to compare against. Warn instead of
            # failing the asserts below, which previously made this
            # warning unreachable.
            logging.warning(
                "fixed_input_shape must == input shape when trainning")
            return
        assert crop_size == fixed_input_shape[
            0], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
                crop_size)
        assert crop_size == fixed_input_shape[
            1], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
                crop_size)
        return

    # Non-classifier models: resize by the short side, then pad up to the
    # requested shape so every input ends up with the same size.
    resize = {
        'ResizeByShort': {
            'short_size': min(fixed_input_shape),
            'max_size': max(fixed_input_shape),
        }
    }
    padding = {'Padding': {'target_size': list(fixed_input_shape)}}
    info['Transforms'].append(resize)
    info['Transforms'].append(padding)
def build_transforms(model_type, transforms_info, to_rgb=True): def build_transforms(model_type, transforms_info, to_rgb=True):
if model_type == "classifier": if model_type == "classifier":
import paddlex.cv.transforms.cls_transforms as T import paddlex.cv.transforms.cls_transforms as T
......
...@@ -36,7 +36,6 @@ class MaskRCNN(FasterRCNN): ...@@ -36,7 +36,6 @@ class MaskRCNN(FasterRCNN):
with_fpn (bool): 是否使用FPN结构。默认为True。 with_fpn (bool): 是否使用FPN结构。默认为True。
aspect_ratios (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。 aspect_ratios (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
anchor_sizes (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。 anchor_sizes (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
""" """
def __init__(self, def __init__(self,
...@@ -44,8 +43,7 @@ class MaskRCNN(FasterRCNN): ...@@ -44,8 +43,7 @@ class MaskRCNN(FasterRCNN):
backbone='ResNet50', backbone='ResNet50',
with_fpn=True, with_fpn=True,
aspect_ratios=[0.5, 1.0, 2.0], aspect_ratios=[0.5, 1.0, 2.0],
anchor_sizes=[32, 64, 128, 256, 512], anchor_sizes=[32, 64, 128, 256, 512]):
fixed_input_shape=None):
self.init_params = locals() self.init_params = locals()
backbones = [ backbones = [
'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd' 'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd'
...@@ -62,7 +60,7 @@ class MaskRCNN(FasterRCNN): ...@@ -62,7 +60,7 @@ class MaskRCNN(FasterRCNN):
self.mask_head_resolution = 28 self.mask_head_resolution = 28
else: else:
self.mask_head_resolution = 14 self.mask_head_resolution = 14
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = None
def build_net(self, mode='train'): def build_net(self, mode='train'):
train_pre_nms_top_n = 2000 if self.with_fpn else 12000 train_pre_nms_top_n = 2000 if self.with_fpn else 12000
...@@ -77,7 +75,7 @@ class MaskRCNN(FasterRCNN): ...@@ -77,7 +75,7 @@ class MaskRCNN(FasterRCNN):
test_pre_nms_top_n=test_pre_nms_top_n, test_pre_nms_top_n=test_pre_nms_top_n,
num_convs=num_convs, num_convs=num_convs,
mask_head_resolution=self.mask_head_resolution, mask_head_resolution=self.mask_head_resolution,
fixed_input_shape = self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
if mode == 'train': if mode == 'train':
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
......
...@@ -60,8 +60,7 @@ class YOLOv3(BaseAPI): ...@@ -60,8 +60,7 @@ class YOLOv3(BaseAPI):
label_smooth=False, label_smooth=False,
train_random_shapes=[ train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
], ]):
fixed_input_shape=None):
self.init_params = locals() self.init_params = locals()
super(YOLOv3, self).__init__('detector') super(YOLOv3, self).__init__('detector')
backbones = [ backbones = [
...@@ -81,7 +80,7 @@ class YOLOv3(BaseAPI): ...@@ -81,7 +80,7 @@ class YOLOv3(BaseAPI):
self.label_smooth = label_smooth self.label_smooth = label_smooth
self.sync_bn = True self.sync_bn = True
self.train_random_shapes = train_random_shapes self.train_random_shapes = train_random_shapes
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = None
def _get_backbone(self, backbone_name): def _get_backbone(self, backbone_name):
if backbone_name == 'DarkNet53': if backbone_name == 'DarkNet53':
...@@ -116,7 +115,7 @@ class YOLOv3(BaseAPI): ...@@ -116,7 +115,7 @@ class YOLOv3(BaseAPI):
nms_keep_topk=self.nms_keep_topk, nms_keep_topk=self.nms_keep_topk,
nms_iou_threshold=self.nms_iou_threshold, nms_iou_threshold=self.nms_iou_threshold,
train_random_shapes=self.train_random_shapes, train_random_shapes=self.train_random_shapes,
fixed_input_shape = self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)]) outputs = OrderedDict([('bbox', model_out)])
......
...@@ -223,7 +223,9 @@ class FasterRCNN(object): ...@@ -223,7 +223,9 @@ class FasterRCNN(object):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -310,7 +310,9 @@ class MaskRCNN(object): ...@@ -310,7 +310,9 @@ class MaskRCNN(object):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -250,7 +250,9 @@ class YOLOv3: ...@@ -250,7 +250,9 @@ class YOLOv3:
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -315,7 +315,9 @@ class DeepLabv3p(object): ...@@ -315,7 +315,9 @@ class DeepLabv3p(object):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -231,7 +231,9 @@ class UNet(object): ...@@ -231,7 +231,9 @@ class UNet(object):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape =[None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]] input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
......
...@@ -211,7 +211,7 @@ class Padding: ...@@ -211,7 +211,7 @@ class Padding:
target_size (int|list): 填充后的图像长、宽,默认为1。 target_size (int|list): 填充后的图像长、宽,默认为1。
""" """
def __init__(self, coarsest_stride=1, target_size=None): def __init__(self, coarsest_stride=1, target_size=1):
self.coarsest_stride = coarsest_stride self.coarsest_stride = coarsest_stride
self.target_size = target_size self.target_size = target_size
...@@ -233,11 +233,12 @@ class Padding: ...@@ -233,11 +233,12 @@ class Padding:
ValueError: target_size小于原图的大小。 ValueError: target_size小于原图的大小。
""" """
if self.coarsest_stride == 1 and self.target_size is None: if self.coarsest_stride == 1:
if label_info is None: if isinstance(self.target_size, int) and self.target_size == 1:
return (im, im_info) if label_info is None:
else: return (im, im_info)
return (im, im_info, label_info) else:
return (im, im_info, label_info)
if im_info is None: if im_info is None:
im_info = dict() im_info = dict()
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
...@@ -250,18 +251,17 @@ class Padding: ...@@ -250,18 +251,17 @@ class Padding:
np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
padding_im_w = int( padding_im_w = int(
np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
if self.target_size is not None:
if isinstance(self.target_size, int):
padding_im_h = self.target_size
padding_im_w = self.target_size
else:
padding_im_h = self.target_size[0]
padding_im_w = self.target_size[1]
pad_height = padding_im_h - im_h
pad_width = padding_im_w - im_w
if pad_height < 0 or pad_width < 0: if isinstance(self.target_size, int) and self.target_size != 1:
raise ValueError( padding_im_h = self.target_size
padding_im_w = self.target_size
elif isinstance(self.target_size, list):
padding_im_w = self.target_size[0]
padding_im_h = self.target_size[1]
pad_height = padding_im_h - im_h
pad_width = padding_im_w - im_w
if pad_height < 0 or pad_width < 0:
raise ValueError(
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_w, im_h, padding_im_w, padding_im_h)) .format(im_w, im_h, padding_im_w, padding_im_h))
padding_im = np.zeros((padding_im_h, padding_im_w, im_c), padding_im = np.zeros((padding_im_h, padding_im_w, im_c),
...@@ -562,7 +562,7 @@ class RandomDistort: ...@@ -562,7 +562,7 @@ class RandomDistort:
params = params_dict[ops[id].__name__] params = params_dict[ops[id].__name__]
prob = prob_dict[ops[id].__name__] prob = prob_dict[ops[id].__name__]
params['im'] = im params['im'] = im
if np.random.uniform(0, 1) < prob: if np.random.uniform(0, 1) < prob:
im = ops[id](**params) im = ops[id](**params)
if label_info is None: if label_info is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册