未验证 提交 76b49aa7 编写于 作者: J Jason 提交者: GitHub

Merge pull request #23 from Channingss/cpp_trt

support deploy with TensorRT
# 模型部署 # 模型部署
本目录为PaddleX模型部署代码。 本目录为PaddleX模型部署代码, 编译和使用的教程参考:
- [C++部署文档](../docs/deploy/deploy.md#C部署)
...@@ -3,9 +3,10 @@ project(PaddleX CXX C) ...@@ -3,9 +3,10 @@ project(PaddleX CXX C)
option(WITH_MKL "Compile demo with MKL/OpenBlas support,defaultuseMKL." ON) option(WITH_MKL "Compile demo with MKL/OpenBlas support,defaultuseMKL." ON)
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." ON) option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." ON)
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)
option(WITH_TENSORRT "Compile demo with TensorRT." OFF) option(WITH_TENSORRT "Compile demo with TensorRT." OFF)
SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
SET(PADDLE_DIR "" CACHE PATH "Location of libraries") SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
SET(OPENCV_DIR "" CACHE PATH "Location of libraries") SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
SET(CUDA_LIB "" CACHE PATH "Location of libraries") SET(CUDA_LIB "" CACHE PATH "Location of libraries")
...@@ -111,8 +112,8 @@ endif() ...@@ -111,8 +112,8 @@ endif()
if (NOT WIN32) if (NOT WIN32)
if (WITH_TENSORRT AND WITH_GPU) if (WITH_TENSORRT AND WITH_GPU)
include_directories("${PADDLE_DIR}/third_party/install/tensorrt/include") include_directories("${TENSORRT_DIR}/include")
link_directories("${PADDLE_DIR}/third_party/install/tensorrt/lib") link_directories("${TENSORRT_DIR}/lib")
endif() endif()
endif(NOT WIN32) endif(NOT WIN32)
...@@ -169,7 +170,7 @@ endif() ...@@ -169,7 +170,7 @@ endif()
if (NOT WIN32) if (NOT WIN32)
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf z xxhash yaml-cpp glog gflags protobuf z xxhash yaml-cpp
) )
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib") if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
...@@ -194,8 +195,8 @@ endif(NOT WIN32) ...@@ -194,8 +195,8 @@ endif(NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
if(NOT WIN32) if(NOT WIN32)
if (WITH_TENSORRT) if (WITH_TENSORRT)
set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${PADDLE_DIR}/third_party/install/tensorrt/lib/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
endif() endif()
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX})
...@@ -211,7 +212,7 @@ if (NOT WIN32) ...@@ -211,7 +212,7 @@ if (NOT WIN32)
set(DEPS ${DEPS} ${EXTERNAL_LIB}) set(DEPS ${DEPS} ${EXTERNAL_LIB})
endif() endif()
set(DEPS ${DEPS} ${OpenCV_LIBS}) set(DEPS ${DEPS} ${OpenCV_LIBS})
add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp) add_executable(classifier src/classifier.cpp src/transforms.cpp src/paddlex.cpp)
ADD_DEPENDENCIES(classifier ext-yaml-cpp) ADD_DEPENDENCIES(classifier ext-yaml-cpp)
target_link_libraries(classifier ${DEPS}) target_link_libraries(classifier ${DEPS})
...@@ -251,4 +252,3 @@ if (WIN32 AND WITH_MKL) ...@@ -251,4 +252,3 @@ if (WIN32 AND WITH_MKL)
) )
endif() endif()
...@@ -38,12 +38,14 @@ class Model { ...@@ -38,12 +38,14 @@ class Model {
public: public:
void Init(const std::string& model_dir, void Init(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false,
int gpu_id = 0) { int gpu_id = 0) {
create_predictor(model_dir, use_gpu, gpu_id); create_predictor(model_dir, use_gpu, use_trt, gpu_id);
} }
void create_predictor(const std::string& model_dir, void create_predictor(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false,
int gpu_id = 0); int gpu_id = 0);
bool load_config(const std::string& model_dir); bool load_config(const std::string& model_dir);
......
...@@ -35,10 +35,8 @@ class ImageBlob { ...@@ -35,10 +35,8 @@ class ImageBlob {
std::vector<int> ori_im_size_ = std::vector<int>(2); std::vector<int> ori_im_size_ = std::vector<int>(2);
// Newest image height and width after process // Newest image height and width after process
std::vector<int> new_im_size_ = std::vector<int>(2); std::vector<int> new_im_size_ = std::vector<int>(2);
// Image height and width before padding
std::vector<int> im_size_before_padding_ = std::vector<int>(2);
// Image height and width before resize // Image height and width before resize
std::vector<int> im_size_before_resize_ = std::vector<int>(2); std::vector<std::vector<int>> im_size_before_resize_;
// Reshape order // Reshape order
std::vector<std::string> reshape_order_; std::vector<std::string> reshape_order_;
// Resize scale // Resize scale
...@@ -49,7 +47,6 @@ class ImageBlob { ...@@ -49,7 +47,6 @@ class ImageBlob {
void clear() { void clear() {
ori_im_size_.clear(); ori_im_size_.clear();
new_im_size_.clear(); new_im_size_.clear();
im_size_before_padding_.clear();
im_size_before_resize_.clear(); im_size_before_resize_.clear();
reshape_order_.clear(); reshape_order_.clear();
im_data_.clear(); im_data_.clear();
...@@ -155,12 +152,13 @@ class Padding : public Transform { ...@@ -155,12 +152,13 @@ class Padding : public Transform {
virtual void Init(const YAML::Node& item) { virtual void Init(const YAML::Node& item) {
if (item["coarsest_stride"].IsDefined()) { if (item["coarsest_stride"].IsDefined()) {
coarsest_stride_ = item["coarsest_stride"].as<int>(); coarsest_stride_ = item["coarsest_stride"].as<int>();
if (coarsest_stride_ <= 1) { if (coarsest_stride_ < 1) {
std::cerr << "[Padding] coarest_stride should greater than 0" std::cerr << "[Padding] coarest_stride should greater than 0"
<< std::endl; << std::endl;
exit(-1); exit(-1);
} }
} else { }
if (item["target_size"].IsDefined()) {
if (item["target_size"].IsScalar()) { if (item["target_size"].IsScalar()) {
width_ = item["target_size"].as<int>(); width_ = item["target_size"].as<int>();
height_ = item["target_size"].as<int>(); height_ = item["target_size"].as<int>();
......
# 是否使用GPU(即是否使用 CUDA) # 是否使用GPU(即是否使用 CUDA)
WITH_GPU=ON WITH_GPU=OFF
# 使用MKL or openblas
WITH_MKL=ON
# 是否集成 TensorRT(仅WITH_GPU=ON 有效) # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
WITH_TENSORRT=OFF WITH_TENSORRT=OFF
# TensorRT 的lib路径
TENSORRT_DIR=/path/to/TensorRT/
# Paddle 预测库路径 # Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/ PADDLE_DIR=/path/to/fluid_inference/
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=OFF
# CUDA 的 lib 路径 # CUDA 的 lib 路径
CUDA_LIB=/path/to/cuda/lib/ CUDA_LIB=/path/to/cuda/lib/
# CUDNN 的 lib 路径 # CUDNN 的 lib 路径
...@@ -19,8 +26,11 @@ mkdir -p build ...@@ -19,8 +26,11 @@ mkdir -p build
cd build cd build
cmake .. \ cmake .. \
-DWITH_GPU=${WITH_GPU} \ -DWITH_GPU=${WITH_GPU} \
-DWITH_MKL=${WITH_MKL} \
-DWITH_TENSORRT=${WITH_TENSORRT} \ -DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \ -DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \ -DCUDA_LIB=${CUDA_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \ -DCUDNN_LIB=${CUDNN_LIB} \
-DOPENCV_DIR=${OPENCV_DIR} -DOPENCV_DIR=${OPENCV_DIR}
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_int32(gpu_id, 0, "GPU card id"); DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(image, "", "Path of test image file"); DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
...@@ -42,7 +43,7 @@ int main(int argc, char** argv) { ...@@ -42,7 +43,7 @@ int main(int argc, char** argv) {
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id); model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
// 进行预测 // 进行预测
if (FLAGS_image_list != "") { if (FLAGS_image_list != "") {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_int32(gpu_id, 0, "GPU card id"); DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(image, "", "Path of test image file"); DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
...@@ -44,7 +45,7 @@ int main(int argc, char** argv) { ...@@ -44,7 +45,7 @@ int main(int argc, char** argv) {
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id); model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
auto colormap = PaddleX::GenerateColorMap(model.labels.size()); auto colormap = PaddleX::GenerateColorMap(model.labels.size());
std::string save_dir = "output"; std::string save_dir = "output";
...@@ -68,7 +69,7 @@ int main(int argc, char** argv) { ...@@ -68,7 +69,7 @@ int main(int argc, char** argv) {
<< result.boxes[i].coordinate[0] << ", " << result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", " << result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << std::endl; << result.boxes[i].coordinate[3] << ")" << std::endl;
} }
// 可视化 // 可视化
...@@ -91,7 +92,7 @@ int main(int argc, char** argv) { ...@@ -91,7 +92,7 @@ int main(int argc, char** argv) {
<< result.boxes[i].coordinate[0] << ", " << result.boxes[i].coordinate[0] << ", "
<< result.boxes[i].coordinate[1] << ", " << result.boxes[i].coordinate[1] << ", "
<< result.boxes[i].coordinate[2] << ", " << result.boxes[i].coordinate[2] << ", "
<< result.boxes[i].coordinate[3] << std::endl; << result.boxes[i].coordinate[3] << ")" << std::endl;
} }
// 可视化 // 可视化
......
...@@ -18,6 +18,7 @@ namespace PaddleX { ...@@ -18,6 +18,7 @@ namespace PaddleX {
void Model::create_predictor(const std::string& model_dir, void Model::create_predictor(const std::string& model_dir,
bool use_gpu, bool use_gpu,
bool use_trt,
int gpu_id) { int gpu_id) {
// 读取配置文件 // 读取配置文件
if (!load_config(model_dir)) { if (!load_config(model_dir)) {
...@@ -37,6 +38,15 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -37,6 +38,15 @@ void Model::create_predictor(const std::string& model_dir,
config.SwitchSpecifyInputNames(true); config.SwitchSpecifyInputNames(true);
// 开启内存优化 // 开启内存优化
config.EnableMemoryOptim(); config.EnableMemoryOptim();
if (use_trt) {
config.EnableTensorRtEngine(
1 << 20 /* workspace_size*/,
32 /* max_batch_size*/,
20 /* min_subgraph_size*/,
paddle::AnalysisConfig::Precision::kFloat32 /* precision*/,
true /* use_static*/,
false /* use_calib_mode*/);
}
predictor_ = std::move(CreatePaddlePredictor(config)); predictor_ = std::move(CreatePaddlePredictor(config));
} }
...@@ -246,7 +256,6 @@ bool Model::predict(const cv::Mat& im, SegResult* result) { ...@@ -246,7 +256,6 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
auto im_tensor = predictor_->GetInputTensor("image"); auto im_tensor = predictor_->GetInputTensor("image");
im_tensor->Reshape({1, 3, h, w}); im_tensor->Reshape({1, 3, h, w});
im_tensor->copy_from_cpu(inputs_.im_data_.data()); im_tensor->copy_from_cpu(inputs_.im_data_.data());
std::cout << "input image: " << h << " " << w << std::endl;
// 使用加载的模型进行预测 // 使用加载的模型进行预测
predictor_->ZeroCopyRun(); predictor_->ZeroCopyRun();
...@@ -286,19 +295,24 @@ bool Model::predict(const cv::Mat& im, SegResult* result) { ...@@ -286,19 +295,24 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
result->score_map.shape[3], result->score_map.shape[3],
CV_32FC1, CV_32FC1,
result->score_map.data.data()); result->score_map.data.data());
int idx = 1;
int len_postprocess = inputs_.im_size_before_resize_.size();
for (std::vector<std::string>::reverse_iterator iter = for (std::vector<std::string>::reverse_iterator iter =
inputs_.reshape_order_.rbegin(); inputs_.reshape_order_.rbegin();
iter != inputs_.reshape_order_.rend(); iter != inputs_.reshape_order_.rend();
++iter) { ++iter) {
if (*iter == "padding") { if (*iter == "padding") {
auto padding_w = inputs_.im_size_before_padding_[0]; auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
auto padding_h = inputs_.im_size_before_padding_[1]; inputs_.im_size_before_resize_.pop_back();
auto padding_w = before_shape[0];
auto padding_h = before_shape[1];
mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h)); mask_label = mask_label(cv::Rect(0, 0, padding_w, padding_h));
mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h)); mask_score = mask_score(cv::Rect(0, 0, padding_w, padding_h));
} else if (*iter == "resize") { } else if (*iter == "resize") {
auto resize_w = inputs_.im_size_before_resize_[0]; auto before_shape = inputs_.im_size_before_resize_[len_postprocess - idx];
auto resize_h = inputs_.im_size_before_resize_[1]; inputs_.im_size_before_resize_.pop_back();
auto resize_w = before_shape[0];
auto resize_h = before_shape[1];
cv::resize(mask_label, cv::resize(mask_label,
mask_label, mask_label,
cv::Size(resize_h, resize_w), cv::Size(resize_h, resize_w),
...@@ -312,6 +326,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) { ...@@ -312,6 +326,7 @@ bool Model::predict(const cv::Mat& im, SegResult* result) {
0, 0,
cv::INTER_NEAREST); cv::INTER_NEAREST);
} }
++idx;
} }
result->label_map.data.assign(mask_label.begin<uint8_t>(), result->label_map.data.assign(mask_label.begin<uint8_t>(),
mask_label.end<uint8_t>()); mask_label.end<uint8_t>());
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
DEFINE_bool(use_trt, false, "Infering with TensorRT");
DEFINE_int32(gpu_id, 0, "GPU card id"); DEFINE_int32(gpu_id, 0, "GPU card id");
DEFINE_string(image, "", "Path of test image file"); DEFINE_string(image, "", "Path of test image file");
DEFINE_string(image_list, "", "Path of test image list file"); DEFINE_string(image_list, "", "Path of test image list file");
...@@ -44,7 +45,8 @@ int main(int argc, char** argv) { ...@@ -44,7 +45,8 @@ int main(int argc, char** argv) {
// 加载模型 // 加载模型
PaddleX::Model model; PaddleX::Model model;
model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_gpu_id); model.Init(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_trt, FLAGS_gpu_id);
auto colormap = PaddleX::GenerateColorMap(model.labels.size()); auto colormap = PaddleX::GenerateColorMap(model.labels.size());
// 进行预测 // 进行预测
if (FLAGS_image_list != "") { if (FLAGS_image_list != "") {
......
...@@ -56,8 +56,7 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) { ...@@ -56,8 +56,7 @@ float ResizeByShort::GenerateScale(const cv::Mat& im) {
} }
bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) { bool ResizeByShort::Run(cv::Mat* im, ImageBlob* data) {
data->im_size_before_resize_[0] = im->rows; data->im_size_before_resize_.push_back({im->rows, im->cols});
data->im_size_before_resize_[1] = im->cols;
data->reshape_order_.push_back("resize"); data->reshape_order_.push_back("resize");
float scale = GenerateScale(*im); float scale = GenerateScale(*im);
...@@ -88,21 +87,21 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) { ...@@ -88,21 +87,21 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
} }
bool Padding::Run(cv::Mat* im, ImageBlob* data) { bool Padding::Run(cv::Mat* im, ImageBlob* data) {
data->im_size_before_padding_[0] = im->rows; data->im_size_before_resize_.push_back({im->rows, im->cols});
data->im_size_before_padding_[1] = im->cols;
data->reshape_order_.push_back("padding"); data->reshape_order_.push_back("padding");
int padding_w = 0; int padding_w = 0;
int padding_h = 0; int padding_h = 0;
if (width_ > 0 & height_ > 0) { if (width_ > 1 & height_ > 1) {
padding_w = width_ - im->cols; padding_w = width_ - im->cols;
padding_h = height_ - im->rows; padding_h = height_ - im->rows;
} else if (coarsest_stride_ > 0) { } else if (coarsest_stride_ > 1) {
padding_h = padding_h =
ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows; ceil(im->rows * 1.0 / coarsest_stride_) * coarsest_stride_ - im->rows;
padding_w = padding_w =
ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols; ceil(im->cols * 1.0 / coarsest_stride_) * coarsest_stride_ - im->cols;
} }
if (padding_h < 0 || padding_w < 0) { if (padding_h < 0 || padding_w < 0) {
std::cerr << "[Padding] Computed padding_h=" << padding_h std::cerr << "[Padding] Computed padding_h=" << padding_h
<< ", padding_w=" << padding_w << ", padding_w=" << padding_w
...@@ -122,8 +121,7 @@ bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) { ...@@ -122,8 +121,7 @@ bool ResizeByLong::Run(cv::Mat* im, ImageBlob* data) {
<< std::endl; << std::endl;
return false; return false;
} }
data->im_size_before_resize_[0] = im->rows; data->im_size_before_resize_.push_back({im->rows, im->cols});
data->im_size_before_resize_[1] = im->cols;
data->reshape_order_.push_back("resize"); data->reshape_order_.push_back("resize");
int origin_w = im->cols; int origin_w = im->cols;
int origin_h = im->rows; int origin_h = im->rows;
...@@ -149,8 +147,7 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) { ...@@ -149,8 +147,7 @@ bool Resize::Run(cv::Mat* im, ImageBlob* data) {
<< std::endl; << std::endl;
return false; return false;
} }
data->im_size_before_resize_[0] = im->rows; data->im_size_before_resize_.push_back({im->rows, im->cols});
data->im_size_before_resize_[1] = im->cols;
data->reshape_order_.push_back("resize"); data->reshape_order_.push_back("resize");
cv::resize( cv::resize(
......
...@@ -7,20 +7,29 @@ ...@@ -7,20 +7,29 @@
### 导出inference模型 ### 导出inference模型
在服务端部署的模型需要首先将模型导出为inference格式模型,导出的模型将包括`__model__``__params__``model.yml`三个文名,分别为模型的网络结构,模型权重和模型的配置文件(包括数据预处理参数等等)。在安装完PaddleX后,在命令行终端使用如下命令导出模型到当前目录`inferece_model`下。 在服务端部署的模型需要首先将模型导出为inference格式模型,导出的模型将包括`__model__``__params__``model.yml`三个文名,分别为模型的网络结构,模型权重和模型的配置文件(包括数据预处理参数等等)。在安装完PaddleX后,在命令行终端使用如下命令导出模型到当前目录`inferece_model`下。
> 可直接下载小度熊分拣模型测试本文档的流程[xiaoduxiong_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)
> 可直接下载垃圾检测模型测试本文档的流程[garbage_epoch_12.tar.gz](https://bj.bcebos.com/paddlex/models/garbage_epoch_12.tar.gz) ```
paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model
```
使用TensorRT预测时,需指定模型的图像输入shape:[w,h]。
**注**
- 分类模型请保持于训练时输入的shape一致。
- 指定[w,h]时,w和h中间逗号隔开,不允许存在空格等其他字符
``` ```
paddlex --export_inference --model_dir=./garbage_epoch_12 --save_dir=./inference_model paddlex --export_inference --model_dir=./xiaoduxiong_epoch_12 --save_dir=./inference_model --fixed_input_shape=[640,960]
``` ```
### Python部署 ### Python部署
PaddleX已经集成了基于Python的高性能预测接口,在安装PaddleX后,可参照如下代码示例,进行预测。相关的接口文档可参考[paddlex.deploy](apis/deploy.md) PaddleX已经集成了基于Python的高性能预测接口,在安装PaddleX后,可参照如下代码示例,进行预测。相关的接口文档可参考[paddlex.deploy](apis/deploy.md)
> 点击下载测试图片 [garbage.bmp](https://bj.bcebos.com/paddlex/datasets/garbage.bmp) > 点击下载测试图片 [xiaoduxiong_test_image.tar.gz](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_test_image.tar.gz)
``` ```
import paddlex as pdx import paddlex as pdx
predictorpdx.deploy.create_predictor('./inference_model') predictor = pdx.deploy.create_predictor('./inference_model')
result = predictor.predict(image='garbage.bmp') result = predictor.predict(image='xiaoduxiong_test_image/JPEGImages/WeChatIMG110.jpeg')
``` ```
### C++部署 ### C++部署
......
...@@ -19,8 +19,18 @@ ...@@ -19,8 +19,18 @@
### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference ### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
PaddlePaddle C++ 预测库针对不同的`CPU``CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1) PaddlePaddle C++ 预测库针对不同的`CPU``CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,目前PaddleX依赖于Paddle1.7版本,以下提供了多个不同版本的Paddle预测库:
| 版本说明 | 预测库(1.7.2版本) |
| ---- | ---- |
| ubuntu14.04_cpu_avx_mkl | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-mkl/fluid_inference.tgz) |
| ubuntu14.04_cpu_avx_openblas | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-avx-openblas/fluid_inference.tgz) |
| ubuntu14.04_cpu_noavx_openblas | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-cpu-noavx-openblas/fluid_inference.tgz) |
| ubuntu14.04_cuda9.0_cudnn7_avx_mkl | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda9-cudnn7-avx-mkl/fluid_inference.tgz) |
| ubuntu14.04_cuda10.0_cudnn7_avx_mkl | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10-cudnn7-avx-mkl/fluid_inference.tgz ) |
| ubuntu14.04_cuda10.1_cudnn7.6_avx_mkl_trt6 | [fluid_inference.tgz](https://paddle-inference-lib.bj.bcebos.com/1.7.2-gpu-cuda10.1-cudnn7.6-avx-mkl-trt6%2Ffluid_inference.tgz) |
更多和更新的版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html#id1)
下载并解压后`/root/projects/fluid_inference`目录包含内容为: 下载并解压后`/root/projects/fluid_inference`目录包含内容为:
``` ```
...@@ -40,17 +50,24 @@ fluid_inference ...@@ -40,17 +50,24 @@ fluid_inference
编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下: 编译`cmake`的命令在`scripts/build.sh`中,请根据实际情况修改主要参数,其主要内容说明如下:
``` ```
# 是否使用GPU(即是否使用 CUDA) # 是否使用GPU(即是否使用 CUDA)
WITH_GPU=ON WITH_GPU=OFF
# 使用MKL or openblas
WITH_MKL=ON
# 是否集成 TensorRT(仅WITH_GPU=ON 有效) # 是否集成 TensorRT(仅WITH_GPU=ON 有效)
WITH_TENSORRT=OFF WITH_TENSORRT=OFF
# 上一步下载的 Paddle 预测库路径 # TensorRT 的lib路径
PADDLE_DIR=/root/projects/deps/fluid_inference/ TENSORRT_DIR=/path/to/TensorRT/
# Paddle 预测库路径
PADDLE_DIR=/path/to/fluid_inference/
# Paddle 的预测库是否使用静态库来编译
# 使用TensorRT时,Paddle的预测库通常为动态库
WITH_STATIC_LIB=ON
# CUDA 的 lib 路径 # CUDA 的 lib 路径
CUDA_LIB=/usr/local/cuda/lib64/ CUDA_LIB=/path/to/cuda/lib/
# CUDNN 的 lib 路径 # CUDNN 的 lib 路径
CUDNN_LIB=/usr/local/cudnn/lib64/ CUDNN_LIB=/path/to/cudnn/lib/
# OPENCV 路径, 如果使用自带预编译版本可不设置 # OPENCV 路径, 如果使用自带预编译版本可不修改
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/ OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
sh $(pwd)/scripts/bootstrap.sh sh $(pwd)/scripts/bootstrap.sh
...@@ -60,8 +77,11 @@ mkdir -p build ...@@ -60,8 +77,11 @@ mkdir -p build
cd build cd build
cmake .. \ cmake .. \
-DWITH_GPU=${WITH_GPU} \ -DWITH_GPU=${WITH_GPU} \
-DWITH_MKL=${WITH_MKL} \
-DWITH_TENSORRT=${WITH_TENSORRT} \ -DWITH_TENSORRT=${WITH_TENSORRT} \
-DTENSORRT_DIR=${TENSORRT_DIR} \
-DPADDLE_DIR=${PADDLE_DIR} \ -DPADDLE_DIR=${PADDLE_DIR} \
-DWITH_STATIC_LIB=${WITH_STATIC_LIB} \
-DCUDA_LIB=${CUDA_LIB} \ -DCUDA_LIB=${CUDA_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \ -DCUDNN_LIB=${CUDNN_LIB} \
-DOPENCV_DIR=${OPENCV_DIR} -DOPENCV_DIR=${OPENCV_DIR}
...@@ -83,19 +103,20 @@ make ...@@ -83,19 +103,20 @@ make
| image | 要预测的图片文件路径 | | image | 要预测的图片文件路径 |
| image_list | 按行存储图片路径的.txt文件 | | image_list | 按行存储图片路径的.txt文件 |
| use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0) | | use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0) |
| use_trt | 是否使用 TensorTr 预测, 支持值为0或1(默认值为0) |
| gpu_id | GPU 设备ID, 默认值为0 | | gpu_id | GPU 设备ID, 默认值为0 |
| save_dir | 保存可视化结果的路径, 默认值为"output",classfier无该参数 | | save_dir | 保存可视化结果的路径, 默认值为"output",classfier无该参数 |
## 样例 ## 样例
可使用[垃圾检测模型](deploy.md#导出inference模型)中生成的`inference_model`模型和测试图片进行预测。 可使用[小度熊识别模型](deploy.md#导出inference模型)中导出的`inference_model`和测试图片进行预测。
`样例一` `样例一`
不使用`GPU`测试图片 `/path/to/garbage.bmp` 不使用`GPU`测试图片 `/path/to/xiaoduxiong.jpeg`
```shell ```shell
./build/detector --model_dir=/path/to/inference_model --image=/path/to/garbage.bmp --save_dir=output ./build/detector --model_dir=/path/to/inference_model --image=/path/to/xiaoduxiong.jpeg --save_dir=output
``` ```
图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
...@@ -104,13 +125,12 @@ make ...@@ -104,13 +125,12 @@ make
使用`GPU`预测多个图片`/path/to/image_list.txt`,image_list.txt内容的格式如下: 使用`GPU`预测多个图片`/path/to/image_list.txt`,image_list.txt内容的格式如下:
``` ```
/path/to/images/garbage1.jpeg /path/to/images/xiaoduxiong1.jpeg
/path/to/images/garbage2.jpeg /path/to/images/xiaoduxiong2.jpeg
... ...
/path/to/images/garbagen.jpeg /path/to/images/xiaoduxiongn.jpeg
``` ```
```shell ```shell
./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output ./build/detector --model_dir=/path/to/models/inference_model --image_list=/root/projects/images_list.txt --use_gpu=1 --save_dir=output
``` ```
图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
...@@ -27,7 +27,18 @@ git clone https://github.com/PaddlePaddle/PaddleX.git ...@@ -27,7 +27,18 @@ git clone https://github.com/PaddlePaddle/PaddleX.git
### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference ### Step2: 下载PaddlePaddle C++ 预测库 fluid_inference
PaddlePaddle C++ 预测库针对不同的`CPU``CUDA`版本提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/windows_cpp_inference.html) PaddlePaddle C++ 预测库针对不同的`CPU``CUDA`,以及是否支持TensorRT,提供了不同的预编译版本,目前PaddleX依赖于Paddle1.7版本,以下提供了多个不同版本的Paddle预测库:
| 版本说明 | 预测库(1.7.2版本) | 编译器 | 构建工具| cuDNN | CUDA
| ---- | ---- | ---- | ---- | ---- | ---- |
| cpu_avx_mkl | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
| cpu_avx_openblas | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/cpu/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 |
| cuda9.0_cudnn7_avx_mkl | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
| cuda9.0_cudnn7_avx_openblas | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/open/post97/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.4.1 | 9.0 |
| cuda10.0_cudnn7_avx_mkl | [fluid_inference.zip](https://paddle-wheel.bj.bcebos.com/1.7.2/win-infer/mkl/post107/fluid_inference_install_dir.zip) | MSVC 2015 update 3 | CMake v3.16.0 | 7.5.0 | 9.0 |
更多和更新的版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id1)
解压后`D:\projects\fluid_inference*\`目录下主要包含的内容为: 解压后`D:\projects\fluid_inference*\`目录下主要包含的内容为:
``` ```
...@@ -109,14 +120,14 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release ...@@ -109,14 +120,14 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
## 样例 ## 样例
可使用[垃圾检测模型](deploy.md#导出inference模型)中生成的`inference_model`模型和测试图片进行预测。 可使用[小度熊识别模型](deploy.md#导出inference模型)中导出的`inference_model`和测试图片进行预测。
`样例一`: `样例一`:
不使用`GPU`测试图片 `\\path\\to\\garbage.bmp` 不使用`GPU`测试图片 `\\path\\to\\xiaoduxiong.jpeg`
```shell ```shell
.\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\garbage.bmp --save_dir=output .\detector --model_dir=\\path\\to\\inference_model --image=D:\\images\\xiaoduxiong.jpeg --save_dir=output
``` ```
图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
...@@ -126,13 +137,12 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release ...@@ -126,13 +137,12 @@ cd D:\projects\PaddleX\deploy\cpp\out\build\x64-Release
使用`GPU`预测多个图片`\\path\\to\\image_list.txt`,image_list.txt内容的格式如下: 使用`GPU`预测多个图片`\\path\\to\\image_list.txt`,image_list.txt内容的格式如下:
``` ```
\\path\\to\\images\\garbage1.jpeg \\path\\to\\images\\xiaoduxiong1.jpeg
\\path\\to\\images\\garbage2.jpeg \\path\\to\\images\\xiaoduxiong2.jpeg
... ...
\\path\\to\\images\\garbagen.jpeg \\path\\to\\images\\xiaoduxiongn.jpeg
``` ```
```shell ```shell
.\detector --model_dir=\\path\\to\\inference_model --image_list=\\path\\to\\images_list.txt --use_gpu=1 --save_dir=output .\detector --model_dir=\\path\\to\\inference_model --image_list=\\path\\to\\images_list.txt --use_gpu=1 --save_dir=output
``` ```
图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。 图片文件`可视化预测结果`会保存在`save_dir`参数设置的目录下。
...@@ -29,7 +29,11 @@ def arg_parser(): ...@@ -29,7 +29,11 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="export inference model for C++/Python deployment") help="export inference model for C++/Python deployment")
parser.add_argument(
"--fixed_input_shape",
"-fs",
default=None,
help="export inference model with fixed input shape:[w,h]")
return parser return parser
...@@ -53,9 +57,23 @@ def main(): ...@@ -53,9 +57,23 @@ def main():
if args.export_inference: if args.export_inference:
assert args.model_dir is not None, "--model_dir should be defined while exporting inference model" assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
assert args.save_dir is not None, "--save_dir should be defined to save inference model" assert args.save_dir is not None, "--save_dir should be defined to save inference model"
model = pdx.load_model(args.model_dir) fixed_input_shape = eval(args.fixed_input_shape)
assert len(
fixed_input_shape) == 2, "len of fixed input shape must == 2"
model = pdx.load_model(args.model_dir, fixed_input_shape)
model.export_inference_model(args.save_dir) model.export_inference_model(args.save_dir)
if args.export_onnx:
assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
fixed_input_shape = eval(args.fixed_input_shape)
assert len(
fixed_input_shape) == 2, "len of fixed input shape must == 2"
model = pdx.load_model(args.model_dir, fixed_input_shape)
model.export_onnx_model(args.save_dir)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -317,11 +317,11 @@ class BaseAPI: ...@@ -317,11 +317,11 @@ class BaseAPI:
model_info['_ModelInputsOutputs']['test_outputs'] = [ model_info['_ModelInputsOutputs']['test_outputs'] = [
[k, v.name] for k, v in self.test_outputs.items() [k, v.name] for k, v in self.test_outputs.items()
] ]
with open( with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8', osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f: mode='w') as f:
yaml.dump(model_info, f) yaml.dump(model_info, f)
# 模型保存成功的标志 # 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close() open(osp.join(save_dir, '.success'), 'w').close()
logging.info( logging.info(
......
...@@ -46,10 +46,18 @@ class BaseClassifier(BaseAPI): ...@@ -46,10 +46,18 @@ class BaseClassifier(BaseAPI):
self.model_name = model_name self.model_name = model_name
self.labels = None self.labels = None
self.num_classes = num_classes self.num_classes = num_classes
self.fixed_input_shape = None
def build_net(self, mode='train'): def build_net(self, mode='train'):
image = fluid.data( if self.fixed_input_shape is not None:
dtype='float32', shape=[None, 3, None, None], name='image') input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
image = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
image = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if mode != 'test': if mode != 'test':
label = fluid.data(dtype='int64', shape=[None, 1], name='label') label = fluid.data(dtype='int64', shape=[None, 1], name='label')
model = getattr(paddlex.cv.nets, str.lower(self.model_name)) model = getattr(paddlex.cv.nets, str.lower(self.model_name))
......
...@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI): ...@@ -48,7 +48,6 @@ class DeepLabv3p(BaseAPI):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1, 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1,
即平时使用的交叉熵损失函数。 即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。 ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
Raises: Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。 ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25', ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
...@@ -118,6 +117,7 @@ class DeepLabv3p(BaseAPI): ...@@ -118,6 +117,7 @@ class DeepLabv3p(BaseAPI):
self.enable_decoder = enable_decoder self.enable_decoder = enable_decoder
self.labels = None self.labels = None
self.sync_bn = True self.sync_bn = True
self.fixed_input_shape = None
def _get_backbone(self, backbone): def _get_backbone(self, backbone):
def mobilenetv2(backbone): def mobilenetv2(backbone):
...@@ -182,7 +182,8 @@ class DeepLabv3p(BaseAPI): ...@@ -182,7 +182,8 @@ class DeepLabv3p(BaseAPI):
use_bce_loss=self.use_bce_loss, use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss, use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight, class_weight=self.class_weight,
ignore_index=self.ignore_index) ignore_index=self.ignore_index,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict() outputs = OrderedDict()
......
...@@ -57,6 +57,7 @@ class FasterRCNN(BaseAPI): ...@@ -57,6 +57,7 @@ class FasterRCNN(BaseAPI):
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.anchor_sizes = anchor_sizes self.anchor_sizes = anchor_sizes
self.labels = None self.labels = None
self.fixed_input_shape = None
def _get_backbone(self, backbone_name): def _get_backbone(self, backbone_name):
norm_type = None norm_type = None
...@@ -109,7 +110,8 @@ class FasterRCNN(BaseAPI): ...@@ -109,7 +110,8 @@ class FasterRCNN(BaseAPI):
aspect_ratios=self.aspect_ratios, aspect_ratios=self.aspect_ratios,
anchor_sizes=self.anchor_sizes, anchor_sizes=self.anchor_sizes,
train_pre_nms_top_n=train_pre_nms_top_n, train_pre_nms_top_n=train_pre_nms_top_n,
test_pre_nms_top_n=test_pre_nms_top_n) test_pre_nms_top_n=test_pre_nms_top_n,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
if mode == 'train': if mode == 'train':
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
......
...@@ -23,7 +23,7 @@ import paddlex ...@@ -23,7 +23,7 @@ import paddlex
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
def load_model(model_dir): def load_model(model_dir, fixed_input_shape=None):
if not osp.exists(osp.join(model_dir, "model.yml")): if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's not model.yml in {}".format(model_dir)) raise Exception("There's not model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f: with open(osp.join(model_dir, "model.yml")) as f:
...@@ -44,6 +44,7 @@ def load_model(model_dir): ...@@ -44,6 +44,7 @@ def load_model(model_dir):
else: else:
model = getattr(paddlex.cv.models, model = getattr(paddlex.cv.models,
info['Model'])(**info['_init_params']) info['Model'])(**info['_init_params'])
model.fixed_input_shape = fixed_input_shape
if status == "Normal" or \ if status == "Normal" or \
status == "Prune" or status == "fluid.save": status == "Prune" or status == "fluid.save":
startup_prog = fluid.Program() startup_prog = fluid.Program()
...@@ -78,6 +79,8 @@ def load_model(model_dir): ...@@ -78,6 +79,8 @@ def load_model(model_dir):
model.test_outputs[var_desc[0]] = out model.test_outputs[var_desc[0]] = out
if 'Transforms' in info: if 'Transforms' in info:
transforms_mode = info.get('TransformsMode', 'RGB') transforms_mode = info.get('TransformsMode', 'RGB')
# 固定模型的输入shape
fix_input_shape(info, fixed_input_shape=fixed_input_shape)
if transforms_mode == 'RGB': if transforms_mode == 'RGB':
to_rgb = True to_rgb = True
else: else:
...@@ -102,6 +105,33 @@ def load_model(model_dir): ...@@ -102,6 +105,33 @@ def load_model(model_dir):
return model return model
def fix_input_shape(info, fixed_input_shape=None):
if fixed_input_shape is not None:
resize = {'ResizeByShort': {}}
padding = {'Padding': {}}
if info['_Attributes']['model_type'] == 'classifier':
crop_size = 0
for transform in info['Transforms']:
if 'CenterCrop' in transform:
crop_size = transform['CenterCrop']['crop_size']
break
assert crop_size == fixed_input_shape[
0], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
crop_size)
assert crop_size == fixed_input_shape[
1], "fixed_input_shape must == CenterCrop:crop_size:{}".format(
crop_size)
if crop_size == 0:
logging.warning(
"fixed_input_shape must == input shape when trainning")
else:
resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
padding['Padding']['target_size'] = list(fixed_input_shape)
info['Transforms'].append(resize)
info['Transforms'].append(padding)
def build_transforms(model_type, transforms_info, to_rgb=True): def build_transforms(model_type, transforms_info, to_rgb=True):
if model_type == "classifier": if model_type == "classifier":
import paddlex.cv.transforms.cls_transforms as T import paddlex.cv.transforms.cls_transforms as T
......
...@@ -60,6 +60,7 @@ class MaskRCNN(FasterRCNN): ...@@ -60,6 +60,7 @@ class MaskRCNN(FasterRCNN):
self.mask_head_resolution = 28 self.mask_head_resolution = 28
else: else:
self.mask_head_resolution = 14 self.mask_head_resolution = 14
self.fixed_input_shape = None
def build_net(self, mode='train'): def build_net(self, mode='train'):
train_pre_nms_top_n = 2000 if self.with_fpn else 12000 train_pre_nms_top_n = 2000 if self.with_fpn else 12000
...@@ -73,7 +74,8 @@ class MaskRCNN(FasterRCNN): ...@@ -73,7 +74,8 @@ class MaskRCNN(FasterRCNN):
train_pre_nms_top_n=train_pre_nms_top_n, train_pre_nms_top_n=train_pre_nms_top_n,
test_pre_nms_top_n=test_pre_nms_top_n, test_pre_nms_top_n=test_pre_nms_top_n,
num_convs=num_convs, num_convs=num_convs,
mask_head_resolution=self.mask_head_resolution) mask_head_resolution=self.mask_head_resolution,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
if mode == 'train': if mode == 'train':
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
......
...@@ -77,6 +77,7 @@ class UNet(DeepLabv3p): ...@@ -77,6 +77,7 @@ class UNet(DeepLabv3p):
self.class_weight = class_weight self.class_weight = class_weight
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.labels = None self.labels = None
self.fixed_input_shape = None
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.UNet( model = paddlex.cv.nets.segmentation.UNet(
...@@ -86,7 +87,8 @@ class UNet(DeepLabv3p): ...@@ -86,7 +87,8 @@ class UNet(DeepLabv3p):
use_bce_loss=self.use_bce_loss, use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss, use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight, class_weight=self.class_weight,
ignore_index=self.ignore_index) ignore_index=self.ignore_index,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict() outputs = OrderedDict()
......
...@@ -80,6 +80,7 @@ class YOLOv3(BaseAPI): ...@@ -80,6 +80,7 @@ class YOLOv3(BaseAPI):
self.label_smooth = label_smooth self.label_smooth = label_smooth
self.sync_bn = True self.sync_bn = True
self.train_random_shapes = train_random_shapes self.train_random_shapes = train_random_shapes
self.fixed_input_shape = None
def _get_backbone(self, backbone_name): def _get_backbone(self, backbone_name):
if backbone_name == 'DarkNet53': if backbone_name == 'DarkNet53':
...@@ -113,7 +114,8 @@ class YOLOv3(BaseAPI): ...@@ -113,7 +114,8 @@ class YOLOv3(BaseAPI):
nms_topk=self.nms_topk, nms_topk=self.nms_topk,
nms_keep_topk=self.nms_keep_topk, nms_keep_topk=self.nms_keep_topk,
nms_iou_threshold=self.nms_iou_threshold, nms_iou_threshold=self.nms_iou_threshold,
train_random_shapes=self.train_random_shapes) train_random_shapes=self.train_random_shapes,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)]) outputs = OrderedDict([('bbox', model_out)])
......
...@@ -76,7 +76,8 @@ class FasterRCNN(object): ...@@ -76,7 +76,8 @@ class FasterRCNN(object):
fg_thresh=.5, fg_thresh=.5,
bg_thresh_hi=.5, bg_thresh_hi=.5,
bg_thresh_lo=0., bg_thresh_lo=0.,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]): bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
fixed_input_shape=None):
super(FasterRCNN, self).__init__() super(FasterRCNN, self).__init__()
self.backbone = backbone self.backbone = backbone
self.mode = mode self.mode = mode
...@@ -148,6 +149,7 @@ class FasterRCNN(object): ...@@ -148,6 +149,7 @@ class FasterRCNN(object):
self.bg_thresh_lo = bg_thresh_lo self.bg_thresh_lo = bg_thresh_lo
self.bbox_reg_weights = bbox_reg_weights self.bbox_reg_weights = bbox_reg_weights
self.rpn_only = rpn_only self.rpn_only = rpn_only
self.fixed_input_shape = fixed_input_shape
def build_net(self, inputs): def build_net(self, inputs):
im = inputs['image'] im = inputs['image']
...@@ -219,8 +221,16 @@ class FasterRCNN(object): ...@@ -219,8 +221,16 @@ class FasterRCNN(object):
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['im_info'] = fluid.data( inputs['im_info'] = fluid.data(
dtype='float32', shape=[None, 3], name='im_info') dtype='float32', shape=[None, 3], name='im_info')
......
...@@ -86,7 +86,8 @@ class MaskRCNN(object): ...@@ -86,7 +86,8 @@ class MaskRCNN(object):
fg_thresh=.5, fg_thresh=.5,
bg_thresh_hi=.5, bg_thresh_hi=.5,
bg_thresh_lo=0., bg_thresh_lo=0.,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]): bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
fixed_input_shape=None):
super(MaskRCNN, self).__init__() super(MaskRCNN, self).__init__()
self.backbone = backbone self.backbone = backbone
self.mode = mode self.mode = mode
...@@ -167,6 +168,7 @@ class MaskRCNN(object): ...@@ -167,6 +168,7 @@ class MaskRCNN(object):
self.bg_thresh_lo = bg_thresh_lo self.bg_thresh_lo = bg_thresh_lo
self.bbox_reg_weights = bbox_reg_weights self.bbox_reg_weights = bbox_reg_weights
self.rpn_only = rpn_only self.rpn_only = rpn_only
self.fixed_input_shape = fixed_input_shape
def build_net(self, inputs): def build_net(self, inputs):
im = inputs['image'] im = inputs['image']
...@@ -306,8 +308,16 @@ class MaskRCNN(object): ...@@ -306,8 +308,16 @@ class MaskRCNN(object):
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['im_info'] = fluid.data( inputs['im_info'] = fluid.data(
dtype='float32', shape=[None, 3], name='im_info') dtype='float32', shape=[None, 3], name='im_info')
......
...@@ -33,7 +33,8 @@ class YOLOv3: ...@@ -33,7 +33,8 @@ class YOLOv3:
nms_iou_threshold=0.45, nms_iou_threshold=0.45,
train_random_shapes=[ train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]): ],
fixed_input_shape=None):
if anchors is None: if anchors is None:
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]] [59, 119], [116, 90], [156, 198], [373, 326]]
...@@ -54,6 +55,7 @@ class YOLOv3: ...@@ -54,6 +55,7 @@ class YOLOv3:
self.norm_decay = 0.0 self.norm_decay = 0.0
self.prefix_name = '' self.prefix_name = ''
self.train_random_shapes = train_random_shapes self.train_random_shapes = train_random_shapes
self.fixed_input_shape = fixed_input_shape
def _head(self, feats): def _head(self, feats):
outputs = [] outputs = []
...@@ -247,8 +249,15 @@ class YOLOv3: ...@@ -247,8 +249,15 @@ class YOLOv3:
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
inputs['image'] = fluid.data( if self.fixed_input_shape is not None:
dtype='float32', shape=[None, 3, None, None], name='image') input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['gt_box'] = fluid.data( inputs['gt_box'] = fluid.data(
dtype='float32', shape=[None, None, 4], name='gt_box') dtype='float32', shape=[None, None, 4], name='gt_box')
......
...@@ -61,6 +61,7 @@ class DeepLabv3p(object): ...@@ -61,6 +61,7 @@ class DeepLabv3p(object):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1, 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。 即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。 ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
Raises: Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。 ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
...@@ -81,7 +82,8 @@ class DeepLabv3p(object): ...@@ -81,7 +82,8 @@ class DeepLabv3p(object):
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255): ignore_index=255,
fixed_input_shape=None):
# dice_loss或bce_loss只适用两类分割中 # dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss): if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError( raise ValueError(
...@@ -115,6 +117,7 @@ class DeepLabv3p(object): ...@@ -115,6 +117,7 @@ class DeepLabv3p(object):
self.decoder_use_sep_conv = decoder_use_sep_conv self.decoder_use_sep_conv = decoder_use_sep_conv
self.encoder_with_aspp = encoder_with_aspp self.encoder_with_aspp = encoder_with_aspp
self.enable_decoder = enable_decoder self.enable_decoder = enable_decoder
self.fixed_input_shape = fixed_input_shape
def _encoder(self, input): def _encoder(self, input):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv # 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
...@@ -310,8 +313,16 @@ class DeepLabv3p(object): ...@@ -310,8 +313,16 @@ class DeepLabv3p(object):
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -54,6 +54,7 @@ class UNet(object): ...@@ -54,6 +54,7 @@ class UNet(object):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1, 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。 即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。 ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
Raises: Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。 ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
...@@ -69,7 +70,8 @@ class UNet(object): ...@@ -69,7 +70,8 @@ class UNet(object):
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255): ignore_index=255,
fixed_input_shape=None):
# dice_loss或bce_loss只适用两类分割中 # dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss): if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise Exception( raise Exception(
...@@ -97,6 +99,7 @@ class UNet(object): ...@@ -97,6 +99,7 @@ class UNet(object):
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
self.class_weight = class_weight self.class_weight = class_weight
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.fixed_input_shape = fixed_input_shape
def _double_conv(self, data, out_ch): def _double_conv(self, data, out_ch):
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
...@@ -226,8 +229,16 @@ class UNet(object): ...@@ -226,8 +229,16 @@ class UNet(object):
def generate_inputs(self): def generate_inputs(self):
inputs = OrderedDict() inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -93,6 +93,8 @@ class Compose: ...@@ -93,6 +93,8 @@ class Compose:
# make default im_info with [h, w, 1] # make default im_info with [h, w, 1]
im_info['im_resize_info'] = np.array( im_info['im_resize_info'] = np.array(
[im.shape[0], im.shape[1], 1.], dtype=np.float32) [im.shape[0], im.shape[1], 1.], dtype=np.float32)
im_info['image_shape'] = np.array([im.shape[0],
im.shape[1]]).astype('int32')
if not self.use_mixup: if not self.use_mixup:
if 'mixup' in im_info: if 'mixup' in im_info:
del im_info['mixup'] del im_info['mixup']
...@@ -193,11 +195,16 @@ class ResizeByShort: ...@@ -193,11 +195,16 @@ class ResizeByShort:
class Padding: class Padding:
"""将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640], """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
`coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值 `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
进行padding,最终输出图像为[320, 640]。 进行padding,最终输出图像为[320, 640]。
2.或者,将图像的长和宽padding到target_size指定的shape,如输入的图像为[300,640],
a. `target_size` = 960,在图像最右和最下使用0值进行padding,最终输出
图像为[960, 960]。
b. `target_size` = [640, 960],在图像最右和最下使用0值进行padding,最终
输出图像为[640, 960]。
1. 如果coarsest_stride为1则直接返回。 1. 如果coarsest_stride为1,target_size为None则直接返回。
2. 获取图像的高H、宽W。 2. 获取图像的高H、宽W。
3. 计算填充后图像的高H_new、宽W_new。 3. 计算填充后图像的高H_new、宽W_new。
4. 构建大小为(H_new, W_new, 3)像素值为0的np.ndarray, 4. 构建大小为(H_new, W_new, 3)像素值为0的np.ndarray,
...@@ -205,10 +212,26 @@ class Padding: ...@@ -205,10 +212,26 @@ class Padding:
Args: Args:
coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。 coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
target_size (int|list|tuple): 填充后的图像长、宽,默认为None,coarset_stride优先级更高。
Raises:
TypeError: 形参`target_size`数据类型不满足需求。
ValueError: 形参`target_size`为(list|tuple)时,长度不满足需求。
""" """
def __init__(self, coarsest_stride=1): def __init__(self, coarsest_stride=1, target_size=None):
self.coarsest_stride = coarsest_stride self.coarsest_stride = coarsest_stride
if target_size is not None:
if not isinstance(target_size, int):
if not isinstance(target_size, tuple) and not isinstance(
target_size, list):
raise TypeError(
"Padding: Type of target_size must in (int|list|tuple)."
)
elif len(target_size) != 2:
raise ValueError(
"Padding: Length of target_size must equal 2.")
self.target_size = target_size
def __call__(self, im, im_info=None, label_info=None): def __call__(self, im, im_info=None, label_info=None):
""" """
...@@ -225,13 +248,9 @@ class Padding: ...@@ -225,13 +248,9 @@ class Padding:
Raises: Raises:
TypeError: 形参数据类型不满足需求。 TypeError: 形参数据类型不满足需求。
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
ValueError: coarsest_stride,target_size需有且只有一个被指定。
ValueError: target_size小于原图的大小。
""" """
if self.coarsest_stride == 1:
if label_info is None:
return (im, im_info)
else:
return (im, im_info, label_info)
if im_info is None: if im_info is None:
im_info = dict() im_info = dict()
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
...@@ -239,11 +258,29 @@ class Padding: ...@@ -239,11 +258,29 @@ class Padding:
if len(im.shape) != 3: if len(im.shape) != 3:
raise ValueError('Padding: image is not 3-dimensional.') raise ValueError('Padding: image is not 3-dimensional.')
im_h, im_w, im_c = im.shape[:] im_h, im_w, im_c = im.shape[:]
if self.coarsest_stride > 1:
if isinstance(self.target_size, int):
padding_im_h = self.target_size
padding_im_w = self.target_size
elif isinstance(self.target_size, list) or isinstance(
self.target_size, tuple):
padding_im_w = self.target_size[0]
padding_im_h = self.target_size[1]
elif self.coarsest_stride > 0:
padding_im_h = int( padding_im_h = int(
np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
padding_im_w = int( padding_im_w = int(
np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
else:
raise ValueError(
"coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
)
pad_height = padding_im_h - im_h
pad_width = padding_im_w - im_w
if pad_height < 0 or pad_width < 0:
raise ValueError(
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_w, im_h, padding_im_w, padding_im_h))
padding_im = np.zeros((padding_im_h, padding_im_w, im_c), padding_im = np.zeros((padding_im_h, padding_im_w, im_c),
dtype=np.float32) dtype=np.float32)
padding_im[:im_h, :im_w, :] = im padding_im[:im_h, :im_w, :] = im
...@@ -539,7 +576,7 @@ class RandomDistort: ...@@ -539,7 +576,7 @@ class RandomDistort:
params = params_dict[ops[id].__name__] params = params_dict[ops[id].__name__]
prob = prob_dict[ops[id].__name__] prob = prob_dict[ops[id].__name__]
params['im'] = im params['im'] = im
if np.random.uniform(0, 1) < prob: if np.random.uniform(0, 1) < prob:
im = ops[id](**params) im = ops[id](**params)
if label_info is None: if label_info is None:
......
...@@ -285,7 +285,7 @@ class ResizeByLong: ...@@ -285,7 +285,7 @@ class ResizeByLong:
当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
存储与图像相关信息的字典和标注图像np.ndarray数据。 存储与图像相关信息的字典和标注图像np.ndarray数据。
其中,im_info新增字段为: 其中,im_info新增字段为:
-shape_before_resize (tuple): 保存resize之前图像的形状(h, w -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)
""" """
if im_info is None: if im_info is None:
im_info = OrderedDict() im_info = OrderedDict()
...@@ -301,6 +301,83 @@ class ResizeByLong: ...@@ -301,6 +301,83 @@ class ResizeByLong:
return (im, im_info, label) return (im, im_info, label)
class ResizeByShort:
"""根据图像的短边调整图像大小(resize)。
1. 获取图像的长边和短边长度。
2. 根据短边与short_size的比例,计算长边的目标长度,
此时高、宽的resize比例为short_size/原图短边长度。
3. 如果max_size>0,调整resize比例:
如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度。
4. 根据调整大小的比例对图像进行resize。
Args:
target_size (int): 短边目标长度。默认为800。
max_size (int): 长边目标长度的最大限制。默认为1333。
Raises:
TypeError: 形参数据类型不满足需求。
"""
def __init__(self, short_size=800, max_size=1333):
self.max_size = int(max_size)
if not isinstance(short_size, int):
raise TypeError(
"Type of short_size is invalid. Must be Integer, now is {}".
format(type(short_size)))
self.short_size = short_size
if not (isinstance(self.max_size, int)):
raise TypeError("max_size: input type is invalid.")
def __call__(self, im, im_info=None, label=None):
"""
Args:
im (numnp.ndarraypy): 图像np.ndarray数据。
im_info (list): 存储图像reisze或padding前的shape信息,如
[('resize', [200, 300]), ('padding', [400, 600])]表示
图像在过resize前shape为(200, 300), 过padding前shape为
(400, 600)
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
存储与图像相关信息的字典和标注图像np.ndarray数据。
其中,im_info更新字段为:
-shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
Raises:
TypeError: 形参数据类型不满足需求。
ValueError: 数据长度不匹配。
"""
if im_info is None:
im_info = OrderedDict()
if not isinstance(im, np.ndarray):
raise TypeError("ResizeByShort: image type is not numpy.")
if len(im.shape) != 3:
raise ValueError('ResizeByShort: image is not 3-dimensional.')
im_info.append(('resize', im.shape[:2]))
im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size
if self.max_size > 0 and np.round(
scale * im_long_size) > self.max_size:
scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale))
im = cv2.resize(
im, (resized_width, resized_height),
interpolation=cv2.INTER_NEAREST)
if label is not None:
im = cv2.resize(
label, (resized_width, resized_height),
interpolation=cv2.INTER_NEAREST)
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class ResizeRangeScaling: class ResizeRangeScaling:
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册