提交 3db1e41e 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into refine/op/lstm

...@@ -305,9 +305,9 @@ paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'neg ...@@ -305,9 +305,9 @@ paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'neg
paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)) paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0))
paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)) paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None))
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'anchor_var', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3)) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'gt_boxes', 'im_scales', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None)) paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
......
...@@ -120,13 +120,20 @@ void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map, ...@@ -120,13 +120,20 @@ void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
outputs.insert(node); outputs.insert(node);
} }
// update the dst and src node's inlinks and outlinks. // update the dst and src node's inlinks and outlinks.
#ifdef __clang__
src_node->inlinks = std::vector<BriefNode *>(inputs.begin(), inputs.end());
src_node->outlinks = std::vector<BriefNode *>(outputs.begin(), outputs.end());
dst_node->inlinks.clear();
dst_node->outlinks.clear();
#else
src_node->inlinks = src_node->inlinks =
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end())); std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
src_node->outlinks = src_node->outlinks =
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end())); std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
dst_node->inlinks.clear(); dst_node->inlinks.clear();
dst_node->outlinks.clear(); dst_node->outlinks.clear();
#endif
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) { auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
for (auto *&n : nodes) { for (auto *&n : nodes) {
......
...@@ -77,6 +77,9 @@ bool AnalysisPredictor::Init( ...@@ -77,6 +77,9 @@ bool AnalysisPredictor::Init(
OptimizeInferenceProgram(); OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0); ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_._use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
VLOG(5) << "to create variables"; VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get()); PADDLE_ENFORCE(scope_.get());
......
...@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -64,13 +64,15 @@ PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) { ...@@ -64,13 +64,15 @@ PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
void PaddleBuf::Resize(size_t length) { void PaddleBuf::Resize(size_t length) {
// Only the owned memory can be reset, the external memory can't be changed. // Only the owned memory can be reset, the external memory can't be changed.
if (length_ == length) return; if (length_ >= length) return;
if (memory_owned_) { if (memory_owned_) {
Free(); Free();
data_ = malloc(length);
length_ = length;
memory_owned_ = true;
} else {
PADDLE_THROW("The memory is allocated externally, can not Resized");
} }
data_ = new char[length];
length_ = length;
memory_owned_ = true;
} }
void PaddleBuf::Reset(void* data, size_t length) { void PaddleBuf::Reset(void* data, size_t length) {
...@@ -82,8 +84,8 @@ void PaddleBuf::Reset(void* data, size_t length) { ...@@ -82,8 +84,8 @@ void PaddleBuf::Reset(void* data, size_t length) {
void PaddleBuf::Free() { void PaddleBuf::Free() {
if (memory_owned_ && data_) { if (memory_owned_ && data_) {
assert(length_ > 0); PADDLE_ENFORCE_GT(length_, 0);
delete[] static_cast<char*>(data_); free(static_cast<char*>(data_));
data_ = nullptr; data_ = nullptr;
length_ = 0; length_ = 0;
} }
......
...@@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init( ...@@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init(
} }
ctx_ = executor_->Prepare(*inference_program_, 0); ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_._use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
executor_->CreateVariables(*inference_program_, executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0); sub_scope_ ? sub_scope_ : scope_.get(), 0);
......
...@@ -45,7 +45,7 @@ class PaddleBuf { ...@@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf(void* data, size_t length) PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {} : data_(data), length_(length), memory_owned_{false} {}
// Own memory. // Own memory.
PaddleBuf(size_t length) explicit PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {} : data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes. // Resize to `length` bytes.
void Resize(size_t length); void Resize(size_t length);
...@@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config { ...@@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config {
bool use_gpu{false}; bool use_gpu{false};
int device{0}; int device{0};
float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization. float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization.
// NOTE: NOT use it, just for the internal test, will discard later
bool _use_mkldnn{false};
// Specify the variable's name of each input. // Specify the variable's name of each input.
bool specify_input_name{false}; bool specify_input_name{false};
......
...@@ -53,5 +53,21 @@ set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classifi ...@@ -53,5 +53,21 @@ set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classifi
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz") download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model
--infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt) --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt)
# ocr
set(OCR_MODEL_URL "http://paddlemodels.cdn.bcebos.com/inference-vis-demos%2Focr.tar.gz")
set(OCR_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR} AND WITH_INFERENCE)
get_filename_component(filename ${OCR_MODEL_URL} NAME)
message(STATUS "Download inference test stuff ${filename} from ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "mkdir -p ${OCR_INSTALL_DIR}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && wget -q ${OCR_MODEL_URL}")
execute_process(COMMAND bash -c "cd ${OCR_INSTALL_DIR} && tar xzf ${filename}")
message(STATUS "finish downloading ${filename}")
endif()
inference_analysis_test(test_analyzer_ocr SRCS analyzer_vis_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model
--infer_data=${OCR_INSTALL_DIR}/data.txt)
...@@ -110,8 +110,7 @@ const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, ...@@ -110,8 +110,7 @@ const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
void TestLACPrediction(const std::string &model_path, void TestLACPrediction(const std::string &model_path,
const std::string &data_file, const int batch_size, const std::string &data_file, const int batch_size,
const int repeat, bool test_all_data, const int repeat, bool use_analysis = false) {
bool use_analysis = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
cfg.model_dir = model_path; cfg.model_dir = model_path;
cfg.use_gpu = false; cfg.use_gpu = false;
...@@ -199,13 +198,13 @@ void TestLACPrediction(const std::string &model_path, ...@@ -199,13 +198,13 @@ void TestLACPrediction(const std::string &model_path,
TEST(Analyzer_LAC, native) { TEST(Analyzer_LAC, native) {
LOG(INFO) << "LAC with native"; LOG(INFO) << "LAC with native";
TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size, TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
FLAGS_repeat, FLAGS_test_all_data); FLAGS_repeat);
} }
TEST(Analyzer_LAC, analysis) { TEST(Analyzer_LAC, analysis) {
LOG(INFO) << "LAC with analysis"; LOG(INFO) << "LAC with analysis";
TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size, TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
FLAGS_repeat, FLAGS_test_all_data, true); FLAGS_repeat, true);
} }
} // namespace analysis } // namespace analysis
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iostream>
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
struct Record {
std::vector<float> data;
std::vector<int32_t> shape;
};
Record ProcessALine(const std::string &line) {
VLOG(3) << "process a line";
std::vector<std::string> columns;
split(line, '\t', &columns);
CHECK_EQ(columns.size(), 2UL)
<< "data format error, should be <data>\t<shape>";
Record record;
std::vector<std::string> data_strs;
split(columns[0], ' ', &data_strs);
for (auto &d : data_strs) {
record.data.push_back(std::stof(d));
}
std::vector<std::string> shape_strs;
split(columns[1], ' ', &shape_strs);
for (auto &s : shape_strs) {
record.shape.push_back(std::stoi(s));
}
VLOG(3) << "data size " << record.data.size();
VLOG(3) << "data shape size " << record.shape.size();
return record;
}
/*
* Use the native and analysis fluid engine to inference the demo.
* ocr, mobilenet and se_resnext50
*/
void TestVisualPrediction(bool use_mkldnn) {
std::unique_ptr<PaddlePredictor> predictor;
AnalysisConfig cfg;
cfg.param_file = FLAGS_infer_model + "/__params__";
cfg.prog_file = FLAGS_infer_model + "/__model__";
cfg.use_gpu = false;
cfg._use_mkldnn = use_mkldnn;
cfg.device = 0;
cfg.enable_ir_optim = true;
// TODO(TJ): fix fusion gru
cfg.ir_passes.push_back("fc_gru_fuse_pass");
#ifdef PADDLE_WITH_MKLDNN
// disable mkldnn fuse since it should have some bugs
cfg.ir_passes.push_back("conv_relu_mkldnn_fuse_pass");
#endif
predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
// Only have single batch of data.
std::string line;
std::ifstream file(FLAGS_infer_data);
std::getline(file, line);
auto record = ProcessALine(line);
file.close();
// Inference.
PaddleTensor input;
input.shape = record.shape;
input.data =
PaddleBuf(record.data.data(), record.data.size() * sizeof(float));
input.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> outputs_slots;
Timer timer;
timer.tic();
for (int i = 0; i < FLAGS_repeat; i++) {
predictor->Run({input}, &outputs_slots);
}
PrintTime(/*batch size*/ 1, FLAGS_repeat, /*num threads*/ 1, /*thread id*/ 0,
timer.toc() / FLAGS_repeat);
VLOG(3) << "output.size " << outputs_slots.size();
// run native as reference
auto ref_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
std::vector<PaddleTensor> ref_outputs_slots;
ref_predictor->Run({input}, &ref_outputs_slots);
CompareResult(outputs_slots, ref_outputs_slots);
// print what are fused
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
.Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr);
for (auto &item : fuse_statis) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
int num_ops = 0;
for (auto &node :
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
if (node->IsFunction()) {
++num_ops;
}
}
LOG(INFO) << "has num ops: " << num_ops;
}
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_vis, analysis_mkldnn) {
TestVisualPrediction(/*use_mkldnn*/ true);
}
#endif
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -37,22 +37,37 @@ namespace paddle { ...@@ -37,22 +37,37 @@ namespace paddle {
namespace inference { namespace inference {
void CompareResult(const std::vector<PaddleTensor> &outputs, void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &base_outputs) { const std::vector<PaddleTensor> &ref_outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0); EXPECT_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size()); EXPECT_EQ(outputs.size(), ref_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) { for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i]; auto &out = outputs[i];
auto &base_out = base_outputs[i]; auto &ref_out = ref_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; }); [](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(), size_t ref_size =
1, [](int a, int b) { return a * b; }); std::accumulate(ref_out.shape.begin(), ref_out.shape.end(), 1,
PADDLE_ENFORCE_EQ(size, size1); [](int a, int b) { return a * b; });
PADDLE_ENFORCE_GT(size, 0); EXPECT_GT(size, 0);
float *data = static_cast<float *>(out.data.data()); EXPECT_EQ(size, ref_size);
float *base_data = static_cast<float *>(base_out.data.data()); EXPECT_EQ(out.dtype, ref_out.dtype);
for (size_t i = 0; i < size; i++) { switch (out.dtype) {
EXPECT_NEAR(data[i], base_data[i], 1e-3); case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3);
}
break;
}
} }
} }
} }
......
...@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation // TODO: add support for dilation
...@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, strides, paddings, mkldnn_engine,
paddings, mkldnn_engine, fuse_relu); fuse_relu, fuse_eltwise);
} else { } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, conv_pd =
paddings, mkldnn_engine, fuse_relu); ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_eltwise);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
private: private:
mkldnn::primitive_attr AddRelu() const { mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
// Fusion with ReLU layer is executed through the PostOps feature. Create a bool fuse_eltwise) const {
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, // Fusion with Elementwise layer relies on adding a sum post-operation with
negative_slope, placeholder); // the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if (fuse_eltwise) {
post_operations.append_sum(1.0f);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
} }
...@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_relu) const { const bool fuse_eltwise) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst, const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_relu) const { const bool fuse_eltwise) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims, bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
......
...@@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() { ...@@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() {
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel") AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_eltwise",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is connected via skip connection "
"to a previous layer.")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
......
...@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -21,7 +22,7 @@ namespace operators { ...@@ -21,7 +22,7 @@ namespace operators {
*/ */
template <typename T> template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes, inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes, const T* weights, const framework::Tensor& gt_boxes, const float* weights,
const bool normalized, framework::Tensor* box_delta) { const bool normalized, framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes); auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes); auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
...@@ -62,5 +63,35 @@ void Gather(const T* in, const int in_stride, const int* index, const int num, ...@@ -62,5 +63,35 @@ void Gather(const T* in, const int in_stride, const int* index, const int num,
} }
} }
template <typename T>
void BboxOverlaps(const framework::Tensor& r_boxes,
const framework::Tensor& c_boxes,
framework::Tensor* overlaps) {
auto r_boxes_et = framework::EigenTensor<T, 2>::From(r_boxes);
auto c_boxes_et = framework::EigenTensor<T, 2>::From(c_boxes);
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
auto zero = static_cast<T>(0.0);
T r_box_area, c_box_area, x_min, y_min, x_max, y_max, inter_w, inter_h,
inter_area;
for (int i = 0; i < r_num; ++i) {
r_box_area = (r_boxes_et(i, 2) - r_boxes_et(i, 0) + 1) *
(r_boxes_et(i, 3) - r_boxes_et(i, 1) + 1);
for (int j = 0; j < c_num; ++j) {
c_box_area = (c_boxes_et(j, 2) - c_boxes_et(j, 0) + 1) *
(c_boxes_et(j, 3) - c_boxes_et(j, 1) + 1);
x_min = std::max(r_boxes_et(i, 0), c_boxes_et(j, 0));
y_min = std::max(r_boxes_et(i, 1), c_boxes_et(j, 1));
x_max = std::min(r_boxes_et(i, 2), c_boxes_et(j, 2));
y_max = std::min(r_boxes_et(i, 3), c_boxes_et(j, 3));
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
}
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -42,10 +42,11 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { ...@@ -42,10 +42,11 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
"Input(RpnRois) shouldn't be null."); "Input(RpnRois) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("GtClasses"), PADDLE_ENFORCE(ctx->HasInput("GtClasses"),
"Input(GtClasses) shouldn't be null."); "Input(GtClasses) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("IsCrowd"),
"Input(IsCrowd) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("GtBoxes"), PADDLE_ENFORCE(ctx->HasInput("GtBoxes"),
"Input(GtBoxes) shouldn't be null."); "Input(GtBoxes) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("ImScales"), PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
"Input(ImScales) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Rois"), PADDLE_ENFORCE(ctx->HasOutput("Rois"),
"Output(Rois) of RpnTargetAssignOp should not be null"); "Output(Rois) of RpnTargetAssignOp should not be null");
...@@ -64,22 +65,21 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { ...@@ -64,22 +65,21 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
auto rpn_rois_dims = ctx->GetInputDim("RpnRois"); auto rpn_rois_dims = ctx->GetInputDim("RpnRois");
auto gt_classes_dims = ctx->GetInputDim("GtClasses"); auto gt_classes_dims = ctx->GetInputDim("GtClasses");
auto is_crowd_dims = ctx->GetInputDim("IsCrowd");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
auto im_scales_dims = ctx->GetInputDim("ImScales"); auto im_info_dims = ctx->GetInputDim("ImInfo");
PADDLE_ENFORCE_EQ(rpn_rois_dims.size(), 2, PADDLE_ENFORCE_EQ(rpn_rois_dims.size(), 2,
"The rank of Input(RpnRois) must be 2."); "The rank of Input(RpnRois) must be 2.");
PADDLE_ENFORCE_EQ(gt_classes_dims.size(), 1,
"The rank of Input(GtClasses) must be 1.");
PADDLE_ENFORCE_EQ(gt_boxes_dims.size(), 2, PADDLE_ENFORCE_EQ(gt_boxes_dims.size(), 2,
"The rank of Input(GtBoxes) must be 2."); "The rank of Input(GtBoxes) must be 2.");
PADDLE_ENFORCE_EQ(im_scales_dims.size(), 1, PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(ImScales) must be 1."); "The rank of Input(ImInfo) must be 2.");
int class_nums = ctx->Attrs().Get<int>("class_nums"); int class_nums = ctx->Attrs().Get<int>("class_nums");
ctx->SetOutputDim("Rois", {-1, 4}); ctx->SetOutputDim("Rois", {-1, 4});
ctx->SetOutputDim("LabelsInt32", {-1}); ctx->SetOutputDim("LabelsInt32", {-1, 1});
ctx->SetOutputDim("BboxTargets", {-1, 4 * class_nums}); ctx->SetOutputDim("BboxTargets", {-1, 4 * class_nums});
ctx->SetOutputDim("BboxInsideWeights", {-1, 4 * class_nums}); ctx->SetOutputDim("BboxInsideWeights", {-1, 4 * class_nums});
ctx->SetOutputDim("BboxOutsideWeights", {-1, 4 * class_nums}); ctx->SetOutputDim("BboxOutsideWeights", {-1, 4 * class_nums});
...@@ -105,45 +105,18 @@ void Concat(const platform::CPUDeviceContext& context, ...@@ -105,45 +105,18 @@ void Concat(const platform::CPUDeviceContext& context,
concat_functor(context, inputs, axis, out_tensor); concat_functor(context, inputs, axis, out_tensor);
} }
template <typename T>
void BboxOverlaps(const Tensor& r_boxes, const Tensor& c_boxes,
Tensor* overlaps) {
auto r_boxes_et = framework::EigenTensor<T, 2>::From(r_boxes);
auto c_boxes_et = framework::EigenTensor<T, 2>::From(c_boxes);
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
auto zero = static_cast<T>(0.0);
T r_box_area, c_box_area, x_min, y_min, x_max, y_max, inter_w, inter_h,
inter_area;
for (int i = 0; i < r_num; ++i) {
r_box_area = (r_boxes_et(i, 2) - r_boxes_et(i, 0) + 1) *
(r_boxes_et(i, 3) - r_boxes_et(i, 1) + 1);
for (int j = 0; j < c_num; ++j) {
c_box_area = (c_boxes_et(j, 2) - c_boxes_et(j, 0) + 1) *
(c_boxes_et(j, 3) - c_boxes_et(j, 1) + 1);
x_min = std::max(r_boxes_et(i, 0), c_boxes_et(j, 0));
y_min = std::max(r_boxes_et(i, 1), c_boxes_et(j, 1));
x_max = std::min(r_boxes_et(i, 2), c_boxes_et(j, 2));
y_max = std::min(r_boxes_et(i, 3), c_boxes_et(j, 3));
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
}
}
}
template <typename T> template <typename T>
std::vector<std::vector<int>> SampleFgBgGt( std::vector<std::vector<int>> SampleFgBgGt(
const platform::CPUDeviceContext& context, Tensor* iou, const platform::CPUDeviceContext& context, Tensor* iou,
const int batch_size_per_im, const float fg_fraction, const float fg_thresh, const Tensor& is_crowd, const int batch_size_per_im,
const float bg_thresh_hi, const float bg_thresh_lo, const float fg_fraction, const float fg_thresh, const float bg_thresh_hi,
std::minstd_rand engine) { const float bg_thresh_lo, std::minstd_rand engine, const bool use_random) {
std::vector<int> fg_inds; std::vector<int> fg_inds;
std::vector<int> bg_inds; std::vector<int> bg_inds;
std::vector<int> gt_inds; std::vector<int> gt_inds;
T* proposal_to_gt_overlaps = iou->mutable_data<T>(context.GetPlace()); int64_t gt_num = is_crowd.numel();
const int* crowd_data = is_crowd.data<int>();
T* proposal_to_gt_overlaps = iou->data<T>();
int64_t row = iou->dims()[0]; int64_t row = iou->dims()[0];
int64_t col = iou->dims()[1]; int64_t col = iou->dims()[1];
float epsilon = 0.00001; float epsilon = 0.00001;
...@@ -152,6 +125,9 @@ std::vector<std::vector<int>> SampleFgBgGt( ...@@ -152,6 +125,9 @@ std::vector<std::vector<int>> SampleFgBgGt(
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
const T* v = proposal_to_gt_overlaps + i * col; const T* v = proposal_to_gt_overlaps + i * col;
T max_overlap = *std::max_element(v, v + col); T max_overlap = *std::max_element(v, v + col);
if ((i < gt_num) && (crowd_data[i])) {
max_overlap = -1.0;
}
if (max_overlap > fg_thresh) { if (max_overlap > fg_thresh) {
for (int64_t j = 0; j < col; ++j) { for (int64_t j = 0; j < col; ++j) {
T val = proposal_to_gt_overlaps[i * col + j]; T val = proposal_to_gt_overlaps[i * col + j];
...@@ -170,17 +146,19 @@ std::vector<std::vector<int>> SampleFgBgGt( ...@@ -170,17 +146,19 @@ std::vector<std::vector<int>> SampleFgBgGt(
} }
// Reservoir Sampling // Reservoir Sampling
std::uniform_real_distribution<float> uniform(0, 1);
int fg_rois_per_im = std::floor(batch_size_per_im * fg_fraction); int fg_rois_per_im = std::floor(batch_size_per_im * fg_fraction);
int fg_rois_this_image = fg_inds.size(); int fg_rois_this_image = fg_inds.size();
int fg_rois_per_this_image = std::min(fg_rois_per_im, fg_rois_this_image); int fg_rois_per_this_image = std::min(fg_rois_per_im, fg_rois_this_image);
std::uniform_real_distribution<float> uniform(0, 1); if (use_random) {
const int64_t fg_size = static_cast<int64_t>(fg_inds.size()); const int64_t fg_size = static_cast<int64_t>(fg_inds.size());
if (fg_size > fg_rois_per_this_image) { if (fg_size > fg_rois_per_this_image) {
for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) { for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) {
int rng_ind = std::floor(uniform(engine) * i); int rng_ind = std::floor(uniform(engine) * i);
if (rng_ind < fg_rois_per_this_image) { if (rng_ind < fg_rois_per_this_image) {
std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i); std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i);
std::iter_swap(gt_inds.begin() + rng_ind, gt_inds.begin() + i); std::iter_swap(gt_inds.begin() + rng_ind, gt_inds.begin() + i);
}
} }
} }
} }
...@@ -192,12 +170,14 @@ std::vector<std::vector<int>> SampleFgBgGt( ...@@ -192,12 +170,14 @@ std::vector<std::vector<int>> SampleFgBgGt(
int bg_rois_per_image = batch_size_per_im - fg_rois_per_this_image; int bg_rois_per_image = batch_size_per_im - fg_rois_per_this_image;
int bg_rois_this_image = bg_inds.size(); int bg_rois_this_image = bg_inds.size();
int bg_rois_per_this_image = std::min(bg_rois_per_image, bg_rois_this_image); int bg_rois_per_this_image = std::min(bg_rois_per_image, bg_rois_this_image);
const int64_t bg_size = static_cast<int64_t>(bg_inds.size()); if (use_random) {
if (bg_size > bg_rois_per_this_image) { const int64_t bg_size = static_cast<int64_t>(bg_inds.size());
for (int64_t i = bg_rois_per_this_image; i < bg_size; ++i) { if (bg_size > bg_rois_per_this_image) {
int rng_ind = std::floor(uniform(engine) * i); for (int64_t i = bg_rois_per_this_image; i < bg_size; ++i) {
if (rng_ind < fg_rois_per_this_image) int rng_ind = std::floor(uniform(engine) * i);
std::iter_swap(bg_inds.begin() + rng_ind, bg_inds.begin() + i); if (rng_ind < fg_rois_per_this_image)
std::iter_swap(bg_inds.begin() + rng_ind, bg_inds.begin() + i);
}
} }
} }
std::vector<int> new_bg_inds(bg_inds.begin(), std::vector<int> new_bg_inds(bg_inds.begin(),
...@@ -248,14 +228,14 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context, ...@@ -248,14 +228,14 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
template <typename T> template <typename T>
std::vector<Tensor> SampleRoisForOneImage( std::vector<Tensor> SampleRoisForOneImage(
const platform::CPUDeviceContext& context, Tensor* rpn_rois, const platform::CPUDeviceContext& context, Tensor* rpn_rois,
Tensor* gt_classes, Tensor* gt_boxes, Tensor* im_scale, Tensor* gt_classes, Tensor* is_crowd, Tensor* gt_boxes, Tensor* im_info,
const int batch_size_per_im, const float fg_fraction, const float fg_thresh, const int batch_size_per_im, const float fg_fraction, const float fg_thresh,
const float bg_thresh_hi, const float bg_thresh_lo, const float bg_thresh_hi, const float bg_thresh_lo,
const std::vector<float>& bbox_reg_weights, const int class_nums, const std::vector<float>& bbox_reg_weights, const int class_nums,
std::minstd_rand engine) { std::minstd_rand engine, bool use_random) {
auto rpn_rois_et = framework::EigenTensor<T, 2>::From(*rpn_rois); auto rpn_rois_et = framework::EigenTensor<T, 2>::From(*rpn_rois);
auto im_scale_data = im_scale->data<T>()[0]; auto im_scale = im_info->data<T>()[2];
rpn_rois_et = rpn_rois_et / im_scale_data; rpn_rois_et = rpn_rois_et / im_scale;
Tensor boxes; Tensor boxes;
int proposals_num = gt_boxes->dims()[0] + rpn_rois->dims()[0]; int proposals_num = gt_boxes->dims()[0] + rpn_rois->dims()[0];
...@@ -270,8 +250,8 @@ std::vector<Tensor> SampleRoisForOneImage( ...@@ -270,8 +250,8 @@ std::vector<Tensor> SampleRoisForOneImage(
// Generate proposal index // Generate proposal index
std::vector<std::vector<int>> fg_bg_gt = SampleFgBgGt<T>( std::vector<std::vector<int>> fg_bg_gt = SampleFgBgGt<T>(
context, &proposal_to_gt_overlaps, batch_size_per_im, fg_fraction, context, &proposal_to_gt_overlaps, *is_crowd, batch_size_per_im,
fg_thresh, bg_thresh_hi, bg_thresh_lo, engine); fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, engine, use_random);
std::vector<int> fg_inds = fg_bg_gt[0]; std::vector<int> fg_inds = fg_bg_gt[0];
std::vector<int> bg_inds = fg_bg_gt[1]; std::vector<int> bg_inds = fg_bg_gt[1];
std::vector<int> gt_inds = fg_bg_gt[2]; std::vector<int> gt_inds = fg_bg_gt[2];
...@@ -291,15 +271,15 @@ std::vector<Tensor> SampleRoisForOneImage( ...@@ -291,15 +271,15 @@ std::vector<Tensor> SampleRoisForOneImage(
// Compute targets // Compute targets
Tensor bbox_targets_single; Tensor bbox_targets_single;
bbox_targets_single.mutable_data<T>(bbox_dim, context.GetPlace()); bbox_targets_single.mutable_data<T>(bbox_dim, context.GetPlace());
BoxToDelta<T>(fg_num, sampled_boxes, sampled_gts, nullptr, false, BoxToDelta<T>(fg_num, sampled_boxes, sampled_gts, bbox_reg_weights.data(),
&bbox_targets_single); false, &bbox_targets_single);
// Scale rois // Scale rois
Tensor sampled_rois; Tensor sampled_rois;
sampled_rois.mutable_data<T>(sampled_boxes.dims(), context.GetPlace()); sampled_rois.mutable_data<T>(sampled_boxes.dims(), context.GetPlace());
auto sampled_rois_et = framework::EigenTensor<T, 2>::From(sampled_rois); auto sampled_rois_et = framework::EigenTensor<T, 2>::From(sampled_rois);
auto sampled_boxes_et = framework::EigenTensor<T, 2>::From(sampled_boxes); auto sampled_boxes_et = framework::EigenTensor<T, 2>::From(sampled_boxes);
sampled_rois_et = sampled_boxes_et * im_scale_data; sampled_rois_et = sampled_boxes_et * im_scale;
// Expand box targets // Expand box targets
Tensor bbox_targets, bbox_inside_weights, bbox_outside_weights; Tensor bbox_targets, bbox_inside_weights, bbox_outside_weights;
...@@ -351,8 +331,9 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> { ...@@ -351,8 +331,9 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* rpn_rois = context.Input<LoDTensor>("RpnRois"); auto* rpn_rois = context.Input<LoDTensor>("RpnRois");
auto* gt_classes = context.Input<LoDTensor>("GtClasses"); auto* gt_classes = context.Input<LoDTensor>("GtClasses");
auto* is_crowd = context.Input<LoDTensor>("IsCrowd");
auto* gt_boxes = context.Input<LoDTensor>("GtBoxes"); auto* gt_boxes = context.Input<LoDTensor>("GtBoxes");
auto* im_scales = context.Input<LoDTensor>("ImScales"); auto* im_info = context.Input<LoDTensor>("ImInfo");
auto* rois = context.Output<LoDTensor>("Rois"); auto* rois = context.Output<LoDTensor>("Rois");
auto* labels_int32 = context.Output<LoDTensor>("LabelsInt32"); auto* labels_int32 = context.Output<LoDTensor>("LabelsInt32");
...@@ -369,18 +350,21 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> { ...@@ -369,18 +350,21 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
std::vector<float> bbox_reg_weights = std::vector<float> bbox_reg_weights =
context.Attr<std::vector<float>>("bbox_reg_weights"); context.Attr<std::vector<float>>("bbox_reg_weights");
int class_nums = context.Attr<int>("class_nums"); int class_nums = context.Attr<int>("class_nums");
bool use_random = context.Attr<bool>("use_random");
PADDLE_ENFORCE_EQ(rpn_rois->lod().size(), 1UL, PADDLE_ENFORCE_EQ(rpn_rois->lod().size(), 1UL,
"GenerateProposalLabelsOp rpn_rois needs 1 level of LoD"); "GenerateProposalLabelsOp rpn_rois needs 1 level of LoD");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
gt_classes->lod().size(), 1UL, gt_classes->lod().size(), 1UL,
"GenerateProposalLabelsOp gt_classes needs 1 level of LoD"); "GenerateProposalLabelsOp gt_classes needs 1 level of LoD");
PADDLE_ENFORCE_EQ(is_crowd->lod().size(), 1UL,
"GenerateProposalLabelsOp is_crowd needs 1 level of LoD");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"GenerateProposalLabelsOp gt_boxes needs 1 level of LoD"); "GenerateProposalLabelsOp gt_boxes needs 1 level of LoD");
int64_t n = static_cast<int64_t>(rpn_rois->lod().back().size() - 1); int64_t n = static_cast<int64_t>(rpn_rois->lod().back().size() - 1);
rois->mutable_data<T>({n * batch_size_per_im, kBoxDim}, context.GetPlace()); rois->mutable_data<T>({n * batch_size_per_im, kBoxDim}, context.GetPlace());
labels_int32->mutable_data<int>({n * batch_size_per_im}, labels_int32->mutable_data<int>({n * batch_size_per_im, 1},
context.GetPlace()); context.GetPlace());
bbox_targets->mutable_data<T>({n * batch_size_per_im, kBoxDim * class_nums}, bbox_targets->mutable_data<T>({n * batch_size_per_im, kBoxDim * class_nums},
context.GetPlace()); context.GetPlace());
...@@ -391,8 +375,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> { ...@@ -391,8 +375,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
std::random_device rnd; std::random_device rnd;
std::minstd_rand engine; std::minstd_rand engine;
int seed = int seed = rnd();
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
engine.seed(seed); engine.seed(seed);
framework::LoD lod; framework::LoD lod;
...@@ -403,19 +386,23 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> { ...@@ -403,19 +386,23 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
auto rpn_rois_lod = rpn_rois->lod().back(); auto rpn_rois_lod = rpn_rois->lod().back();
auto gt_classes_lod = gt_classes->lod().back(); auto gt_classes_lod = gt_classes->lod().back();
auto is_crowd_lod = is_crowd->lod().back();
auto gt_boxes_lod = gt_boxes->lod().back(); auto gt_boxes_lod = gt_boxes->lod().back();
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
Tensor rpn_rois_slice = Tensor rpn_rois_slice =
rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]); rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]);
Tensor gt_classes_slice = Tensor gt_classes_slice =
gt_classes->Slice(gt_classes_lod[i], gt_classes_lod[i + 1]); gt_classes->Slice(gt_classes_lod[i], gt_classes_lod[i + 1]);
Tensor is_crowd_slice =
is_crowd->Slice(is_crowd_lod[i], is_crowd_lod[i + 1]);
Tensor gt_boxes_slice = Tensor gt_boxes_slice =
gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]); gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
Tensor im_scales_slice = im_scales->Slice(i, i + 1); Tensor im_info_slice = im_info->Slice(i, i + 1);
std::vector<Tensor> tensor_output = SampleRoisForOneImage<T>( std::vector<Tensor> tensor_output = SampleRoisForOneImage<T>(
dev_ctx, &rpn_rois_slice, &gt_classes_slice, &gt_boxes_slice, dev_ctx, &rpn_rois_slice, &gt_classes_slice, &is_crowd_slice,
&im_scales_slice, batch_size_per_im, fg_fraction, fg_thresh, &gt_boxes_slice, &im_info_slice, batch_size_per_im, fg_fraction,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums, engine); fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
engine, use_random);
Tensor sampled_rois = tensor_output[0]; Tensor sampled_rois = tensor_output[0];
Tensor sampled_labels_int32 = tensor_output[1]; Tensor sampled_labels_int32 = tensor_output[1];
Tensor sampled_bbox_targets = tensor_output[2]; Tensor sampled_bbox_targets = tensor_output[2];
...@@ -442,7 +429,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> { ...@@ -442,7 +429,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
bbox_inside_weights->set_lod(lod); bbox_inside_weights->set_lod(lod);
bbox_outside_weights->set_lod(lod); bbox_outside_weights->set_lod(lod);
rois->Resize({num_rois, kBoxDim}); rois->Resize({num_rois, kBoxDim});
labels_int32->Resize({num_rois}); labels_int32->Resize({num_rois, 1});
bbox_targets->Resize({num_rois, kBoxDim * class_nums}); bbox_targets->Resize({num_rois, kBoxDim * class_nums});
bbox_inside_weights->Resize({num_rois, kBoxDim * class_nums}); bbox_inside_weights->Resize({num_rois, kBoxDim * class_nums});
bbox_outside_weights->Resize({num_rois, kBoxDim * class_nums}); bbox_outside_weights->Resize({num_rois, kBoxDim * class_nums});
...@@ -455,8 +442,9 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -455,8 +442,9 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
// TODO(buxingyuan): Add Document // TODO(buxingyuan): Add Document
AddInput("RpnRois", "RpnRois."); AddInput("RpnRois", "RpnRois.");
AddInput("GtClasses", "GtClasses."); AddInput("GtClasses", "GtClasses.");
AddInput("IsCrowd", "IsCrowd.");
AddInput("GtBoxes", "GtBoxes."); AddInput("GtBoxes", "GtBoxes.");
AddInput("ImScales", "ImScales."); AddInput("ImInfo", "ImInfo.");
AddOutput("Rois", "Rois."); AddOutput("Rois", "Rois.");
AddOutput("LabelsInt32", "LabelsInt32."); AddOutput("LabelsInt32", "LabelsInt32.");
...@@ -471,8 +459,7 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -471,8 +459,7 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("bg_thresh_lo", "bg_thresh_lo"); AddAttr<float>("bg_thresh_lo", "bg_thresh_lo");
AddAttr<std::vector<float>>("bbox_reg_weights", "bbox_reg_weights"); AddAttr<std::vector<float>>("bbox_reg_weights", "bbox_reg_weights");
AddAttr<int>("class_nums", "class_nums"); AddAttr<int>("class_nums", "class_nums");
AddAttr<bool>("fix_seed", "fix_seed").SetDefault(false); AddAttr<bool>("use_random", "use_random").SetDefault(true);
AddAttr<int>("seed", "seed").SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
Generate Proposals Labels Operator. Generate Proposals Labels Operator.
......
...@@ -89,12 +89,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, ...@@ -89,12 +89,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
} }
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len]; T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1]; T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2; T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T bbox_center_x = 0, bbox_center_y = 0; T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0; T bbox_width = 0, bbox_height = 0;
...@@ -106,25 +105,31 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, ...@@ -106,25 +105,31 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y = variances_data[i * len + 1] * bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height + bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y; anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] * bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) * bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width; anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] * bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) * bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height; anchor_height;
} else { } else {
bbox_center_x = bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x; bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y = bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width; bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height; std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
} }
proposals_data[i * len] = bbox_center_x - bbox_width / 2; proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2; proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2; proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2; proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
} }
// return proposals; // return proposals;
} }
...@@ -156,18 +161,23 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, ...@@ -156,18 +161,23 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) { float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>(); const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace()); T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2]; T im_scale = im_info_data[2];
keep->Resize({boxes->dims()[0], 1}); keep->Resize({boxes->dims()[0], 1});
min_size = std::max(min_size, 1.0f);
int *keep_data = keep->mutable_data<int>(ctx.GetPlace()); int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0; int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) { for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1; T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1; T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2; T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2; T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] && if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
y_ctr <= im_info_data[0]) { x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i; keep_data[keep_len++] = i;
} }
} }
...@@ -218,8 +228,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) { ...@@ -218,8 +228,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]); const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]); const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]); const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin; const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = inter_ymax - inter_ymin; const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h; const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized); const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized); const T bbox2_area = BBoxArea<T>(box2, normalized);
......
...@@ -82,8 +82,10 @@ class ProtoEncodeHelper { ...@@ -82,8 +82,10 @@ class ProtoEncodeHelper {
: base_(buf), p_(buf), limit_(base_ + max_size) {} : base_(buf), p_(buf), limit_(base_ + max_size) {}
~ProtoEncodeHelper() { ~ProtoEncodeHelper() {
#define REPLACE_ENFORCE_GLOG 1
// Make sure callers didn't do operations that went over max_size promised // Make sure callers didn't do operations that went over max_size promised
PADDLE_ENFORCE_LE(p_, limit_); paddle::platform::throw_on_error(p_ <= limit_);
#undef REPLACE_ENFORCE_GLOG
} }
const char* data() const { return base_; } const char* data() const { return base_; }
......
...@@ -59,17 +59,16 @@ static void ParallelExecuteBlocks( ...@@ -59,17 +59,16 @@ static void ParallelExecuteBlocks(
framework::ProgramDesc *program, framework::Scope *scope) { framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) { for (size_t idx : parallel_blkids) {
fs.push_back( fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
framework::Async([&executor, &prepared, &program, &scope, idx]() { int run_block = idx; // thread local
int run_block = idx; // thread local try {
try { VLOG(3) << "running server block: " << run_block
VLOG(3) << "running server block: " << run_block << "pointer: " << prepared[run_block].get();
<< "pointer: " << prepared[run_block].get(); executor->RunPreparedContext(prepared[run_block].get(), scope);
executor->RunPreparedContext(prepared[run_block].get(), scope); } catch (const std::exception &e) {
} catch (const std::exception &e) { LOG(ERROR) << "run sub program error " << e.what();
LOG(ERROR) << "run sub program error " << e.what(); }
} }));
}));
} }
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} }
......
...@@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel {
std::string mode = ctx->Attrs().Get<std::string>("mode"); std::string mode = ctx->Attrs().Get<std::string>("mode");
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); "Input(X) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("Alpha"),
"Input(Alpha) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PreluOp should not be null");
if (mode == "all") { if (mode == "all") {
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
"For mode 'all', size of weight Alpha must be one."); "For mode 'all', size of weight Alpha must be one.");
......
...@@ -55,15 +55,19 @@ for _OP in set(__auto__): ...@@ -55,15 +55,19 @@ for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP) globals()[_OP] = generate_layer_fn(_OP)
def rpn_target_assign(loc, def rpn_target_assign(bbox_pred,
scores, cls_logits,
anchor_box, anchor_box,
anchor_var, anchor_var,
gt_box, gt_boxes,
is_crowd,
im_info,
rpn_batch_size_per_im=256, rpn_batch_size_per_im=256,
fg_fraction=0.25, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3): rpn_negative_overlap=0.3,
use_random=True):
""" """
** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. ** ** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. **
...@@ -83,14 +87,13 @@ def rpn_target_assign(loc, ...@@ -83,14 +87,13 @@ def rpn_target_assign(loc,
the positive anchors. the positive anchors.
Args: Args:
loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes. N is the batch size, predicted locations of M bounding bboxes. N is the batch size,
and each bounding box has four coordinate values and the layout and each bounding box has four coordinate values and the layout
is [xmin, ymin, xmax, ymax]. is [xmin, ymin, xmax, ymax].
scores(Variable): A 3-D Tensor with shape [N, M, C] represents the cls_logits(Variable): A 3-D Tensor with shape [N, M, 1] represents the
predicted confidence predictions. N is the batch size, C is the predicted confidence predictions. N is the batch size, 1 is the
class number, M is number of bounding boxes. For each category frontground and background sigmoid, M is number of bounding boxes.
there are total M scores which corresponding M bounding boxes.
anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes, anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax], each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box, [xmin, ymin] is the left top coordinate of the anchor box,
...@@ -99,11 +102,16 @@ def rpn_target_assign(loc, ...@@ -99,11 +102,16 @@ def rpn_target_assign(loc,
coordinate of the anchor box. coordinate of the anchor box.
anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded
variances of anchors. variances of anchors.
gt_box (Variable): The ground-truth boudding boxes (bboxes) are a 2D gt_boxes (Variable): The ground-truth boudding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input. bboxes of mini-batch input.
is_crowd (Variable): A 1-D LoDTensor which indicates groud-truth is crowd.
im_info (Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size,
3 is the height, width and scale.
rpn_batch_size_per_im(int): Total number of RPN examples per image. rpn_batch_size_per_im(int): Total number of RPN examples per image.
fg_fraction(float): Target fraction of RoI minibatch that is labeled rpn_straddle_thresh(float): Remove RPN anchors that go outside the image
by straddle_thresh pixels.
rpn_fg_fraction(float): Target fraction of RoI minibatch that is labeled
foreground (i.e. class > 0), 0-th class is background. foreground (i.e. class > 0), 0-th class is background.
rpn_positive_overlap(float): Minimum overlap required between an anchor rpn_positive_overlap(float): Minimum overlap required between an anchor
and ground-truth box for the (anchor, gt box) pair to be a positive and ground-truth box for the (anchor, gt box) pair to be a positive
...@@ -129,45 +137,48 @@ def rpn_target_assign(loc, ...@@ -129,45 +137,48 @@ def rpn_target_assign(loc,
Examples: Examples:
.. code-block:: python .. code-block:: python
loc = layers.data(name='location', shape=[2, 80], bbox_pred = layers.data(name='bbox_pred', shape=[100, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 40], cls_logits = layers.data(name='cls_logits', shape=[100, 1],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
anchor_box = layers.data(name='anchor_box', shape=[20, 4], anchor_box = layers.data(name='anchor_box', shape=[20, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
gt_box = layers.data(name='gt_box', shape=[10, 4], gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target = loc_pred, score_pred, loc_target, score_target =
fluid.layers.detection_output(loc=location, fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
scores=scores, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
gt_box=gt_box) gt_boxes=gt_boxes)
""" """
helper = LayerHelper('rpn_target_assign', **locals()) helper = LayerHelper('rpn_target_assign', **locals())
# Compute overlaps between the prior boxes and the gt boxes overlaps
iou = iou_similarity(x=gt_box, y=anchor_box)
# Assign target label to anchors # Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32') loc_index = helper.create_tmp_variable(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32') score_index = helper.create_tmp_variable(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int64') target_label = helper.create_tmp_variable(dtype='int32')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
helper.append_op( helper.append_op(
type="rpn_target_assign", type="rpn_target_assign",
inputs={'Anchor': anchor_box, inputs={
'GtBox': gt_box, 'Anchor': anchor_box,
'DistMat': iou}, 'GtBoxes': gt_boxes,
'IsCrowd': is_crowd,
'ImInfo': im_info
},
outputs={ outputs={
'LocationIndex': loc_index, 'LocationIndex': loc_index,
'ScoreIndex': score_index, 'ScoreIndex': score_index,
'TargetLabel': target_label, 'TargetLabel': target_label,
'TargetBBox': target_bbox, 'TargetBBox': target_bbox
}, },
attrs={ attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
'rpn_straddle_thresh': rpn_straddle_thresh,
'rpn_positive_overlap': rpn_positive_overlap, 'rpn_positive_overlap': rpn_positive_overlap,
'rpn_negative_overlap': rpn_negative_overlap, 'rpn_negative_overlap': rpn_negative_overlap,
'fg_fraction': fg_fraction 'rpn_fg_fraction': rpn_fg_fraction,
'use_random': use_random
}) })
loc_index.stop_gradient = True loc_index.stop_gradient = True
...@@ -175,12 +186,12 @@ def rpn_target_assign(loc, ...@@ -175,12 +186,12 @@ def rpn_target_assign(loc,
target_label.stop_gradient = True target_label.stop_gradient = True
target_bbox.stop_gradient = True target_bbox.stop_gradient = True
scores = nn.reshape(x=scores, shape=(-1, 1)) cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
loc = nn.reshape(x=loc, shape=(-1, 4)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_scores = nn.gather(scores, score_index) predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_location = nn.gather(loc, loc_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_scores, predicted_location, target_label, target_bbox return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
def detection_output(loc, def detection_output(loc,
...@@ -1258,15 +1269,17 @@ def anchor_generator(input, ...@@ -1258,15 +1269,17 @@ def anchor_generator(input,
def generate_proposal_labels(rpn_rois, def generate_proposal_labels(rpn_rois,
gt_classes, gt_classes,
is_crowd,
gt_boxes, gt_boxes,
im_scales, im_info,
batch_size_per_im=256, batch_size_per_im=256,
fg_fraction=0.25, fg_fraction=0.25,
fg_thresh=0.25, fg_thresh=0.25,
bg_thresh_hi=0.5, bg_thresh_hi=0.5,
bg_thresh_lo=0.0, bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=None): class_nums=None,
use_random=True):
""" """
** Generate proposal labels Faster-RCNN ** ** Generate proposal labels Faster-RCNN **
TODO(buxingyuan): Add Document TODO(buxingyuan): Add Document
...@@ -1285,8 +1298,9 @@ def generate_proposal_labels(rpn_rois, ...@@ -1285,8 +1298,9 @@ def generate_proposal_labels(rpn_rois,
inputs={ inputs={
'RpnRois': rpn_rois, 'RpnRois': rpn_rois,
'GtClasses': gt_classes, 'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes, 'GtBoxes': gt_boxes,
'ImScales': im_scales 'ImInfo': im_info
}, },
outputs={ outputs={
'Rois': rois, 'Rois': rois,
...@@ -1302,7 +1316,8 @@ def generate_proposal_labels(rpn_rois, ...@@ -1302,7 +1316,8 @@ def generate_proposal_labels(rpn_rois,
'bg_thresh_hi': bg_thresh_hi, 'bg_thresh_hi': bg_thresh_hi,
'bg_thresh_lo': bg_thresh_lo, 'bg_thresh_lo': bg_thresh_lo,
'bbox_reg_weights': bbox_reg_weights, 'bbox_reg_weights': bbox_reg_weights,
'class_nums': class_nums 'class_nums': class_nums,
'use_random': use_random
}) })
rois.stop_gradient = True rois.stop_gradient = True
......
...@@ -148,51 +148,60 @@ class TestAnchorGenerator(unittest.TestCase): ...@@ -148,51 +148,60 @@ class TestAnchorGenerator(unittest.TestCase):
class TestGenerateProposalLabels(unittest.TestCase): class TestGenerateProposalLabels(unittest.TestCase):
def test_generate_proposal_labels(self): def test_generate_proposal_labels(self):
rpn_rois = layers.data( program = Program()
name='rpn_rois', with program_guard(program):
shape=[4, 4], rpn_rois = layers.data(
dtype='float32', name='rpn_rois',
lod_level=1, shape=[4, 4],
append_batch_size=False) dtype='float32',
gt_classes = layers.data( lod_level=1,
name='gt_classes', append_batch_size=False)
shape=[6], gt_classes = layers.data(
dtype='int32', name='gt_classes',
lod_level=1, shape=[6],
append_batch_size=False) dtype='int32',
gt_boxes = layers.data( lod_level=1,
name='gt_boxes', append_batch_size=False)
shape=[6, 4], is_crowd = layers.data(
dtype='float32', name='is_crowd',
lod_level=1, shape=[6],
append_batch_size=False) dtype='int32',
im_scales = layers.data( lod_level=1,
name='im_scales', append_batch_size=False)
shape=[1], gt_boxes = layers.data(
dtype='float32', name='gt_boxes',
lod_level=1, shape=[6, 4],
append_batch_size=False) dtype='float32',
class_nums = 5 lod_level=1,
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels( append_batch_size=False)
rpn_rois=rpn_rois, im_info = layers.data(
gt_classes=gt_classes, name='im_info',
gt_boxes=gt_boxes, shape=[1, 3],
im_scales=im_scales, dtype='float32',
batch_size_per_im=2, lod_level=1,
fg_fraction=0.5, append_batch_size=False)
fg_thresh=0.5, class_nums = 5
bg_thresh_hi=0.5, rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
bg_thresh_lo=0.0, rpn_rois=rpn_rois,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], gt_classes=gt_classes,
class_nums=class_nums) is_crowd=is_crowd,
assert rois.shape[1] == 4 gt_boxes=gt_boxes,
assert rois.shape[0] == labels_int32.shape[0] im_info=im_info,
assert rois.shape[0] == bbox_targets.shape[0] batch_size_per_im=2,
assert rois.shape[0] == bbox_inside_weights.shape[0] fg_fraction=0.5,
assert rois.shape[0] == bbox_outside_weights.shape[0] fg_thresh=0.5,
assert bbox_targets.shape[1] == 4 * class_nums bg_thresh_hi=0.5,
assert bbox_inside_weights.shape[1] == 4 * class_nums bg_thresh_lo=0.0,
assert bbox_outside_weights.shape[1] == 4 * class_nums bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestMultiBoxHead(unittest.TestCase): class TestMultiBoxHead(unittest.TestCase):
...@@ -254,18 +263,18 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -254,18 +263,18 @@ class TestRpnTargetAssign(unittest.TestCase):
def test_rpn_target_assign(self): def test_rpn_target_assign(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
loc_shape = [10, 50, 4] bbox_pred_shape = [10, 50, 4]
score_shape = [10, 50, 2] cls_logits_shape = [10, 50, 2]
anchor_shape = [50, 4] anchor_shape = [50, 4]
loc = layers.data( bbox_pred = layers.data(
name='loc', name='bbox_pred',
shape=loc_shape, shape=bbox_pred_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
scores = layers.data( cls_logits = layers.data(
name='scores', name='cls_logits',
shape=score_shape, shape=cls_logits_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
anchor_box = layers.data( anchor_box = layers.data(
...@@ -278,17 +287,31 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -278,17 +287,31 @@ class TestRpnTargetAssign(unittest.TestCase):
shape=anchor_shape, shape=anchor_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
gt_box = layers.data( gt_boxes = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32') name='gt_boxes', shape=[4], lod_level=1, dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[10],
dtype='int32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
loc=loc, bbox_pred=bbox_pred,
scores=scores, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
anchor_var=anchor_var, anchor_var=anchor_var,
gt_box=gt_box, gt_boxes=gt_boxes,
is_crowd=is_crowd,
im_info=im_info,
rpn_batch_size_per_im=256, rpn_batch_size_per_im=256,
fg_fraction=0.25, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3) rpn_negative_overlap=0.3)
......
...@@ -20,10 +20,10 @@ import paddle.fluid as fluid ...@@ -20,10 +20,10 @@ import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
def generate_proposal_labels_in_python( def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
rpn_rois, gt_classes, gt_boxes, im_scales, batch_size_per_im, im_info, batch_size_per_im, fg_fraction,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, fg_thresh, bg_thresh_hi, bg_thresh_lo,
class_nums): bbox_reg_weights, class_nums):
rois = [] rois = []
labels_int32 = [] labels_int32 = []
bbox_targets = [] bbox_targets = []
...@@ -31,13 +31,13 @@ def generate_proposal_labels_in_python( ...@@ -31,13 +31,13 @@ def generate_proposal_labels_in_python(
bbox_outside_weights = [] bbox_outside_weights = []
lod = [] lod = []
assert len(rpn_rois) == len( assert len(rpn_rois) == len(
im_scales), 'batch size of rpn_rois and ground_truth is not matched' im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_scales)): for im_i in range(len(im_info)):
frcn_blobs = _sample_rois( frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], gt_boxes[im_i], im_scales[im_i], rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_lo, bbox_reg_weights, class_nums) bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
lod.append(frcn_blobs['rois'].shape[0]) lod.append(frcn_blobs['rois'].shape[0])
...@@ -50,13 +50,14 @@ def generate_proposal_labels_in_python( ...@@ -50,13 +50,14 @@ def generate_proposal_labels_in_python(
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod
def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im, def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bbox_reg_weights, class_nums): bg_thresh_lo, bbox_reg_weights, class_nums):
rois_per_image = int(batch_size_per_im) rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb # Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale rpn_rois = rpn_rois * inv_im_scale
...@@ -78,6 +79,9 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im, ...@@ -78,6 +79,9 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[ box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind] overlapped_boxes_ind]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1
max_overlaps = gt_overlaps.max(axis=1) max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1) max_classes = gt_overlaps.argmax(axis=1)
...@@ -85,9 +89,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im, ...@@ -85,9 +89,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_inds = np.where(max_overlaps >= fg_thresh)[0] fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0]) fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many # Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image: # if fg_inds.shape[0] > fg_rois_per_this_image:
fg_inds = np.random.choice( # fg_inds = np.random.choice(
fg_inds, size=fg_rois_per_this_image, replace=False) # fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background # Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
...@@ -96,9 +101,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im, ...@@ -96,9 +101,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0]) bg_inds.shape[0])
# Sample background if there are too many # Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image: # if bg_inds.shape[0] > bg_rois_per_this_image:
bg_inds = np.random.choice( # bg_inds = np.random.choice(
bg_inds, size=bg_rois_per_this_image, replace=False) # bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds) keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds] sampled_labels = max_classes[keep_inds]
...@@ -208,8 +214,9 @@ class TestGenerateProposalLabelsOp(OpTest): ...@@ -208,8 +214,9 @@ class TestGenerateProposalLabelsOp(OpTest):
self.inputs = { self.inputs = {
'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod), 'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod),
'GtClasses': (self.gt_classes[0], self.gts_lod), 'GtClasses': (self.gt_classes[0], self.gts_lod),
'IsCrowd': (self.is_crowd[0], self.gts_lod),
'GtBoxes': (self.gt_boxes[0], self.gts_lod), 'GtBoxes': (self.gt_boxes[0], self.gts_lod),
'ImScales': self.im_scales[0] 'ImInfo': self.im_info
} }
self.attrs = { self.attrs = {
'batch_size_per_im': self.batch_size_per_im, 'batch_size_per_im': self.batch_size_per_im,
...@@ -218,14 +225,15 @@ class TestGenerateProposalLabelsOp(OpTest): ...@@ -218,14 +225,15 @@ class TestGenerateProposalLabelsOp(OpTest):
'bg_thresh_hi': self.bg_thresh_hi, 'bg_thresh_hi': self.bg_thresh_hi,
'bg_thresh_lo': self.bg_thresh_lo, 'bg_thresh_lo': self.bg_thresh_lo,
'bbox_reg_weights': self.bbox_reg_weights, 'bbox_reg_weights': self.bbox_reg_weights,
'class_nums': self.class_nums 'class_nums': self.class_nums,
'use_random': False
} }
self.outputs = { self.outputs = {
'Rois': (self.rois[0], [self.lod]), 'Rois': (self.rois, [self.lod]),
'LabelsInt32': (self.labels_int32[0], [self.lod]), 'LabelsInt32': (self.labels_int32, [self.lod]),
'BboxTargets': (self.bbox_targets[0], [self.lod]), 'BboxTargets': (self.bbox_targets, [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights[0], [self.lod]), 'BboxInsideWeights': (self.bbox_inside_weights, [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights[0], [self.lod]), 'BboxOutsideWeights': (self.bbox_outside_weights, [self.lod]),
} }
def test_check_output(self): def test_check_output(self):
...@@ -236,8 +244,8 @@ class TestGenerateProposalLabelsOp(OpTest): ...@@ -236,8 +244,8 @@ class TestGenerateProposalLabelsOp(OpTest):
self.set_data() self.set_data()
def init_test_params(self): def init_test_params(self):
self.batch_size_per_im = 10 self.batch_size_per_im = 512
self.fg_fraction = 1.0 self.fg_fraction = 0.25
self.fg_thresh = 0.5 self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5 self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0 self.bg_thresh_lo = 0.0
...@@ -246,14 +254,14 @@ class TestGenerateProposalLabelsOp(OpTest): ...@@ -246,14 +254,14 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_input(self): def init_test_input(self):
np.random.seed(0) np.random.seed(0)
image_nums = 1
gt_nums = 6 # Keep same with batch_size_per_im for unittest gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = self.batch_size_per_im - gt_nums proposal_nums = 2000 #self.batch_size_per_im - gt_nums
images_shape = [] images_shape = [[64, 64]]
self.im_scales = [] self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(image_nums): for i in range(len(images_shape)):
images_shape.append(np.random.randint(200, size=2)) self.im_info[i, 0] = images_shape[i][0]
self.im_scales.append(np.ones((1)).astype(np.float32)) self.im_info[i, 1] = images_shape[i][1]
self.im_info[i, 2] = 0.8 #scale
self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape, self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape,
proposal_nums) proposal_nums)
...@@ -261,16 +269,23 @@ class TestGenerateProposalLabelsOp(OpTest): ...@@ -261,16 +269,23 @@ class TestGenerateProposalLabelsOp(OpTest):
images_shape, self.class_nums, gt_nums) images_shape, self.class_nums, gt_nums)
self.gt_classes = [gt['gt_classes'] for gt in ground_truth] self.gt_classes = [gt['gt_classes'] for gt in ground_truth]
self.gt_boxes = [gt['boxes'] for gt in ground_truth] self.gt_boxes = [gt['boxes'] for gt in ground_truth]
self.is_crowd = [gt['is_crowd'] for gt in ground_truth]
def init_test_output(self): def init_test_output(self):
self.rois, self.labels_int32, self.bbox_targets, \ self.rois, self.labels_int32, self.bbox_targets, \
self.bbox_inside_weights, self.bbox_outside_weights, \ self.bbox_inside_weights, self.bbox_outside_weights, \
self.lod = generate_proposal_labels_in_python( self.lod = generate_proposal_labels_in_python(
self.rpn_rois, self.gt_classes, self.gt_boxes, self.im_scales, self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction, self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo, self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums self.bbox_reg_weights, self.class_nums
) )
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)
self.labels_int32 = self.labels_int32[:, np.newaxis]
self.bbox_targets = np.vstack(self.bbox_targets)
self.bbox_inside_weights = np.vstack(self.bbox_inside_weights)
self.bbox_outside_weights = np.vstack(self.bbox_outside_weights)
def _generate_proposals(images_shape, proposal_nums): def _generate_proposals(images_shape, proposal_nums):
...@@ -280,7 +295,7 @@ def _generate_proposals(images_shape, proposal_nums): ...@@ -280,7 +295,7 @@ def _generate_proposals(images_shape, proposal_nums):
for i, image_shape in enumerate(images_shape): for i, image_shape in enumerate(images_shape):
proposals = _generate_boxes(image_shape, proposal_nums) proposals = _generate_boxes(image_shape, proposal_nums)
rpn_rois.append(proposals) rpn_rois.append(proposals)
num_proposals += len(proposals) num_proposals = len(proposals)
rpn_rois_lod.append(num_proposals) rpn_rois_lod.append(num_proposals)
return rpn_rois, [rpn_rois_lod] return rpn_rois, [rpn_rois_lod]
...@@ -294,7 +309,11 @@ def _generate_groundtruth(images_shape, class_nums, gt_nums): ...@@ -294,7 +309,11 @@ def _generate_groundtruth(images_shape, class_nums, gt_nums):
gt_classes = np.random.randint( gt_classes = np.random.randint(
low=1, high=class_nums, size=gt_nums).astype(np.int32) low=1, high=class_nums, size=gt_nums).astype(np.int32)
gt_boxes = _generate_boxes(image_shape, gt_nums) gt_boxes = _generate_boxes(image_shape, gt_nums)
ground_truth.append(dict(gt_classes=gt_classes, boxes=gt_boxes)) is_crowd = np.zeros((gt_nums), dtype=np.int32)
is_crowd[0] = 1
ground_truth.append(
dict(
gt_classes=gt_classes, boxes=gt_boxes, is_crowd=is_crowd))
num_gts += len(gt_classes) num_gts += len(gt_classes)
gts_lod.append(num_gts) gts_lod.append(num_gts)
return ground_truth, [gts_lod] return ground_truth, [gts_lod]
......
...@@ -114,10 +114,10 @@ def box_coder(all_anchors, bbox_deltas, variances): ...@@ -114,10 +114,10 @@ def box_coder(all_anchors, bbox_deltas, variances):
#anchor_loc: width, height, center_x, center_y #anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32) anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + 1
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + 1
anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2 anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0]
anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2 anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1]
#predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height #predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32) pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
...@@ -127,23 +127,29 @@ def box_coder(all_anchors, bbox_deltas, variances): ...@@ -127,23 +127,29 @@ def box_coder(all_anchors, bbox_deltas, variances):
i, 0] + anchor_loc[i, 2] i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[ pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3] i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(variances[i, 2] * pred_bbox[i, 2] = math.exp(
bbox_deltas[i, 2]) * anchor_loc[i, 0] min(variances[i, 2] * bbox_deltas[i, 2], math.log(
pred_bbox[i, 3] = math.exp(variances[i, 3] * 1000 / 16.0))) * anchor_loc[i, 0]
bbox_deltas[i, 3]) * anchor_loc[i, 1] pred_bbox[i, 3] = math.exp(
min(variances[i, 3] * bbox_deltas[i, 3], math.log(
1000 / 16.0))) * anchor_loc[i, 1]
else: else:
for i in range(bbox_deltas.shape[0]): for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[ pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2] i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[ pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3] i, 3]
pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0] pred_bbox[i, 2] = math.exp(
pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1] min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i,
0]
pred_bbox[i, 3] = math.exp(
min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i,
1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2 proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2 proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - 1
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - 1
return proposals return proposals
...@@ -170,13 +176,16 @@ def filter_boxes(boxes, min_size, im_info): ...@@ -170,13 +176,16 @@ def filter_boxes(boxes, min_size, im_info):
"""Only keep boxes with both sides >= min_size and center within the image. """Only keep boxes with both sides >= min_size and center within the image.
""" """
# Scale min_size to match image scale # Scale min_size to match image scale
min_size *= im_info[2] im_scale = im_info[2]
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1 ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1 hs = boxes[:, 3] - boxes[:, 1] + 1
ws_orig_scale = (boxes[:, 2] - boxes[:, 0]) / im_scale + 1
hs_orig_scale = (boxes[:, 3] - boxes[:, 1]) / im_scale + 1
x_ctr = boxes[:, 0] + ws / 2. x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2. y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) & keep = np.where((ws_orig_scale >= min_size) & (hs_orig_scale >= min_size) &
(y_ctr < im_info[0]))[0] (x_ctr < im_info[1]) & (y_ctr < im_info[0]))[0]
return keep return keep
...@@ -204,7 +213,7 @@ def iou(box_a, box_b): ...@@ -204,7 +213,7 @@ def iou(box_a, box_b):
xb = min(xmax_a, xmax_b) xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b) yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) inter_area = max(xb - xa + 1, 0.0) * max(yb - ya + 1, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area) iou_ratio = inter_area / (area_a + area_b - inter_area)
......
...@@ -19,48 +19,58 @@ import numpy as np ...@@ -19,48 +19,58 @@ import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
from test_anchor_generator_op import anchor_generator_in_python from test_anchor_generator_op import anchor_generator_in_python
from test_generate_proposal_labels import _generate_groundtruth from test_generate_proposal_labels_op import _generate_groundtruth
from test_generate_proposal_labels import _bbox_overlaps, _box_to_delta from test_generate_proposal_labels_op import _bbox_overlaps, _box_to_delta
def rpn_target_assign(gt_anchor_iou, rpn_batch_size_per_im, def rpn_target_assign(anchor_by_gt_overlap,
rpn_positive_overlap, rpn_negative_overlap, fg_fraction): rpn_batch_size_per_im,
iou = np.transpose(gt_anchor_iou) rpn_positive_overlap,
anchor_to_gt_max = iou.max(axis=1) rpn_negative_overlap,
anchor_to_gt_argmax = iou.argmax(axis=1) rpn_fg_fraction,
use_random=True):
gt_to_anchor_argmax = iou.argmax(axis=0) anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1)
gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])] anchor_to_gt_max = anchor_by_gt_overlap[np.arange(
anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0] anchor_by_gt_overlap.shape[0]), anchor_to_gt_argmax]
tgt_lbl = np.ones((iou.shape[0], ), dtype=np.int32) * -1 gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0)
tgt_lbl[anchors_with_max_overlap] = 1 gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, np.arange(
tgt_lbl[anchor_to_gt_max >= rpn_positive_overlap] = 1 anchor_by_gt_overlap.shape[1])]
anchors_with_max_overlap = np.where(
num_fg = int(fg_fraction * rpn_batch_size_per_im) anchor_by_gt_overlap == gt_to_anchor_max)[0]
fg_inds = np.where(tgt_lbl == 1)[0]
if len(fg_inds) > num_fg: labels = np.ones((anchor_by_gt_overlap.shape[0], ), dtype=np.int32) * -1
labels[anchors_with_max_overlap] = 1
labels[anchor_to_gt_max >= rpn_positive_overlap] = 1
num_fg = int(rpn_fg_fraction * rpn_batch_size_per_im)
fg_inds = np.where(labels == 1)[0]
if len(fg_inds) > num_fg and use_random:
disable_inds = np.random.choice( disable_inds = np.random.choice(
fg_inds, size=(len(fg_inds) - num_fg), replace=False) fg_inds, size=(len(fg_inds) - num_fg), replace=False)
tgt_lbl[disable_inds] = -1 else:
fg_inds = np.where(tgt_lbl == 1)[0] disable_inds = fg_inds[num_fg:]
labels[disable_inds] = -1
fg_inds = np.where(labels == 1)[0]
num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1) num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
tgt_lbl[bg_inds] = 0 if len(bg_inds) > num_bg and use_random:
if len(bg_inds) > num_bg:
enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
tgt_lbl[enable_inds] = 0 else:
bg_inds = np.where(tgt_lbl == 0)[0] enable_inds = bg_inds[:num_bg]
tgt_lbl[bg_inds] = 0 labels[enable_inds] = 0
fg_inds = np.where(labels == 1)[0]
bg_inds = np.where(labels == 0)[0]
loc_index = fg_inds loc_index = fg_inds
score_index = np.hstack((fg_inds, bg_inds)) score_index = np.hstack((fg_inds, bg_inds))
tgt_lbl = np.expand_dims(tgt_lbl, axis=1) labels = labels[score_index]
assert not np.any(labels == -1), "Wrong labels with -1"
gt_inds = anchor_to_gt_argmax[fg_inds] gt_inds = anchor_to_gt_argmax[fg_inds]
return loc_index, score_index, tgt_lbl, gt_inds return loc_index, score_index, labels, gt_inds
def get_anchor(n, c, h, w): def get_anchor(n, c, h, w):
...@@ -75,85 +85,129 @@ def get_anchor(n, c, h, w): ...@@ -75,85 +85,129 @@ def get_anchor(n, c, h, w):
return anchors return anchors
def rpn_blob(anchor, gt_boxes, iou, lod, rpn_batch_size_per_im, def rpn_target_assign_in_python(all_anchors,
rpn_positive_overlap, rpn_negative_overlap, fg_fraction): gt_boxes,
is_crowd,
loc_indexes = [] im_info,
score_indexes = [] lod,
tmp_tgt_labels = [] rpn_straddle_thresh,
tgt_bboxes = [] rpn_batch_size_per_im,
anchor_num = anchor.shape[0] rpn_positive_overlap,
rpn_negative_overlap,
rpn_fg_fraction,
use_random=True):
anchor_num = all_anchors.shape[0]
batch_size = len(lod) - 1 batch_size = len(lod) - 1
for i in range(batch_size): for i in range(batch_size):
im_height = im_info[i][0]
im_width = im_info[i][1]
im_scale = im_info[i][2]
if rpn_straddle_thresh >= 0:
# Only keep anchors inside the image by a margin of straddle_thresh
inds_inside = np.where(
(all_anchors[:, 0] >= -rpn_straddle_thresh) &
(all_anchors[:, 1] >= -rpn_straddle_thresh) & (
all_anchors[:, 2] < im_width + rpn_straddle_thresh) & (
all_anchors[:, 3] < im_height + rpn_straddle_thresh))[0]
# keep only inside anchors
inside_anchors = all_anchors[inds_inside, :]
else:
inds_inside = np.arange(all_anchors.shape[0])
inside_anchors = all_anchors
b, e = lod[i], lod[i + 1] b, e = lod[i], lod[i + 1]
iou_slice = iou[b:e, :] gt_boxes_slice = gt_boxes[b:e, :] * im_scale
bboxes_slice = gt_boxes[b:e, :] is_crowd_slice = is_crowd[b:e]
loc_idx, score_idx, tgt_lbl, gt_inds = rpn_target_assign( not_crowd_inds = np.where(is_crowd_slice == 0)[0]
iou_slice, rpn_batch_size_per_im, rpn_positive_overlap, gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
rpn_negative_overlap, fg_fraction) iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
fg_bboxes = bboxes_slice[gt_inds] loc_inds, score_inds, labels, gt_inds = rpn_target_assign(
fg_anchors = anchor[loc_idx] iou, rpn_batch_size_per_im, rpn_positive_overlap,
box_deltas = _box_to_delta(fg_anchors, fg_bboxes, [1., 1., 1., 1.]) rpn_negative_overlap, rpn_fg_fraction, use_random)
# unmap to all anchor
loc_inds = inds_inside[loc_inds]
score_inds = inds_inside[score_inds]
sampled_gt = gt_boxes_slice[gt_inds]
sampled_anchor = all_anchors[loc_inds]
box_deltas = _box_to_delta(sampled_anchor, sampled_gt, [1., 1., 1., 1.])
if i == 0: if i == 0:
loc_indexes = loc_idx loc_indexes = loc_inds
score_indexes = score_idx score_indexes = score_inds
tmp_tgt_labels = tgt_lbl tgt_labels = labels
tgt_bboxes = box_deltas tgt_bboxes = box_deltas
else: else:
loc_indexes = np.concatenate( loc_indexes = np.concatenate(
[loc_indexes, loc_idx + i * anchor_num]) [loc_indexes, loc_inds + i * anchor_num])
score_indexes = np.concatenate( score_indexes = np.concatenate(
[score_indexes, score_idx + i * anchor_num]) [score_indexes, score_inds + i * anchor_num])
tmp_tgt_labels = np.concatenate([tmp_tgt_labels, tgt_lbl]) tgt_labels = np.concatenate([tgt_labels, labels])
tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
tgt_labels = tmp_tgt_labels[score_indexes]
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels return loc_indexes, score_indexes, tgt_bboxes, tgt_labels
class TestRpnTargetAssignOp(OpTest): class TestRpnTargetAssignOp(OpTest):
def setUp(self): def setUp(self):
n, c, h, w = 2, 4, 14, 14 n, c, h, w = 2, 4, 14, 14
anchor = get_anchor(n, c, h, w) all_anchors = get_anchor(n, c, h, w)
gt_num = 10 gt_num = 10
anchor = anchor.reshape(-1, 4) all_anchors = all_anchors.reshape(-1, 4)
anchor_num = anchor.shape[0] anchor_num = all_anchors.shape[0]
im_shapes = [[64, 64], [64, 64]] images_shape = [[64, 64], [64, 64]]
gt_box, lod = _generate_groundtruth(im_shapes, 3, 4) #images_shape = [[64, 64]]
bbox = np.vstack([v['boxes'] for v in gt_box]) groundtruth, lod = _generate_groundtruth(images_shape, 3, 4)
lod = [0, 4, 8]
iou = _bbox_overlaps(bbox, anchor) #lod = [0, 4]
anchor = anchor.astype('float32') im_info = np.ones((len(images_shape), 3)).astype(np.float32)
bbox = bbox.astype('float32') for i in range(len(images_shape)):
iou = iou.astype('float32') im_info[i, 0] = images_shape[i][0]
im_info[i, 1] = images_shape[i][1]
loc_index, score_index, tgt_bbox, tgt_lbl = rpn_blob( im_info[i, 2] = 0.8 #scale
anchor, bbox, iou, [0, 4, 8], 25600, 0.95, 0.03, 0.25) gt_boxes = np.vstack([v['boxes'] for v in groundtruth])
is_crowd = np.hstack([v['is_crowd'] for v in groundtruth])
all_anchors = all_anchors.astype('float32')
gt_boxes = gt_boxes.astype('float32')
rpn_straddle_thresh = 0.0
rpn_batch_size_per_im = 256
rpn_positive_overlap = 0.7
rpn_negative_overlap = 0.3
rpn_fg_fraction = 0.5
use_random = False
loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python(
all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh,
rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap,
rpn_fg_fraction, use_random)
labels = labels[:, np.newaxis]
self.op_type = "rpn_target_assign" self.op_type = "rpn_target_assign"
self.inputs = { self.inputs = {
'Anchor': anchor, 'Anchor': all_anchors,
'GtBox': (bbox, [[4, 4]]), 'GtBoxes': (gt_boxes, [[4, 4]]),
'DistMat': (iou, [[4, 4]]), 'IsCrowd': (is_crowd, [[4, 4]]),
'ImInfo': (im_info, [[1, 1]])
} }
self.attrs = { self.attrs = {
'rpn_batch_size_per_im': 25600, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
'rpn_positive_overlap': 0.95, 'rpn_straddle_thresh': rpn_straddle_thresh,
'rpn_negative_overlap': 0.03, 'rpn_positive_overlap': rpn_positive_overlap,
'fg_fraction': 0.25, 'rpn_negative_overlap': rpn_negative_overlap,
'fix_seed': True 'rpn_fg_fraction': rpn_fg_fraction,
'use_random': use_random
} }
self.outputs = { self.outputs = {
'LocationIndex': loc_index.astype('int32'), 'LocationIndex': loc_index.astype('int32'),
'ScoreIndex': score_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'), 'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': tgt_lbl.astype('int64'), 'TargetLabel': labels.astype('int32')
} }
def test_check_output(self): def test_check_output(self):
......
...@@ -65,8 +65,43 @@ class InferenceTranspiler(object): ...@@ -65,8 +65,43 @@ class InferenceTranspiler(object):
if use_mkldnn: if use_mkldnn:
self._fuse_conv_bias_mkldnn(program) self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program) self._fuse_conv_relu_mkldnn(program)
self._fuse_conv_eltwise_mkldnn(program)
self._fuse_conv_relu_mkldnn(
program) # ResNet residual block merging
self._fuse_bn_relu_mkldnn(program) self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_eltwise_mkldnn(self, program):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
program. Elementwise add following convolution OP can be fused by adding
'fuse_eltwise' attribute to convolution OP and replacing its output
Tensor with second parameter of elementwise_add.
The result of fuse is:
- before:
- conv->elementwise_add->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_conv_relu_mkldnn(self, program): def _fuse_conv_relu_mkldnn(self, program):
''' '''
Transpile the program by fused relu activation for MKLDNN program. Transpile the program by fused relu activation for MKLDNN program.
...@@ -88,9 +123,9 @@ class InferenceTranspiler(object): ...@@ -88,9 +123,9 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']: if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1] next_op = self.block.ops[i + 1]
if next_op.type == 'relu': if next_op.type == 'relu':
# modify conv OP to include relu # modify bnorm OP to include relu
current_op.set_attr("fuse_relu", True) current_op.set_attr("fuse_relu", True)
# remove conv OP # remove relu OP
self.block._remove_op(i + 1) self.block._remove_op(i + 1)
i = i + 1 i = i + 1
...@@ -409,6 +444,20 @@ class InferenceTranspiler(object): ...@@ -409,6 +444,20 @@ class InferenceTranspiler(object):
outputs={"Output": out_var}, outputs={"Output": out_var},
attrs=attrs) attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
:param conv_op: convolution operator
:type conv_op: Operator
:param eltwise_op: operator adding data from skip connection
:type eltwise_op: Operator
'''
conv_op.set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
def _adjust_input(self): def _adjust_input(self):
for i in range(len(self.block.ops)): for i in range(len(self.block.ops)):
current_op = self.block.ops[i] current_op = self.block.ops[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册