From 65b641bf660a8ecbcb831dc0b35a1e58bc15174a Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Mon, 11 Dec 2017 22:56:41 +0800 Subject: [PATCH] add detection_output op --- paddle/operators/detection_output_op.cc | 63 +++++------ paddle/operators/detection_output_op.h | 102 +++++++++--------- paddle/operators/math/detection_util.h | 70 +++++++----- .../fluid/tests/test_detection_output_op.py | 24 +++-- 4 files changed, 133 insertions(+), 126 deletions(-) diff --git a/paddle/operators/detection_output_op.cc b/paddle/operators/detection_output_op.cc index a04d6e5758..ced9caf992 100644 --- a/paddle/operators/detection_output_op.cc +++ b/paddle/operators/detection_output_op.cc @@ -21,42 +21,37 @@ class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker { Detection_output_OpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Loc", - "(Tensor) The input tensor of detection_output operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of feature."); - AddInput( - "Conf", - "(Tensor) The input tensor of detection_output operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of feature."); - AddInput( - "PriorBox", - "(Tensor) The input tensor of detection_output operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of feature."); + AddInput("Loc", + "(Tensor) The input tensor of detection_output operator. " + "The format of input tensor is kNCHW. Where K is priorbox point " + "numbers," + "N is How many boxes are there on each point, " + "C is 4, H and W both are 1."); + AddInput("Conf", + "(Tensor) The input tensor of detection_output operator. " + "The format of input tensor is kNCHW. Where K is priorbox point " + "numbers," + "N is How many boxes are there on each point, " + "C is the number of classes, H and W both are 1."); + AddInput("PriorBox", + "(Tensor) The input tensor of detection_output operator. " + "The format of input tensor is the position and variance " + "of the boxes"); AddOutput("Out", - "(Tensor) The output tensor of detection_output operator." - "N * M." - "M = C * H * W"); - AddAttr("background_label_id", "(int), multi level pooling"); - AddAttr("num_classes", "(int), multi level pooling"); - AddAttr("nms_threshold", "(int), multi level pooling"); - AddAttr("confidence_threshold", "(int), multi level pooling"); - AddAttr("top_k", "(int), multi level pooling"); - AddAttr("nms_top_k", "(int), multi level pooling"); + "(Tensor) The output tensor of detection_output operator."); + AddAttr("background_label_id", + "(int), the attr of detection_output operator"); + AddAttr("num_classes", + "(int), the attr of detection_output operator"); + AddAttr("nms_threshold", + "(float), the attr of detection_output operator"); + AddAttr("confidence_threshold", + "(float), the attr of detection_output operator"); + AddAttr("top_k", "(int), the attr of detection_output operator"); + AddAttr("nms_top_k", "(int), the attr of detection_output operator"); AddComment(R"DOC( - "Does spatial pyramid pooling on the input image by taking the max, - etc. within regions so that the result vector of different sized - images are of the same size - Input shape: $(N, C_{in}, H_{in}, W_{in})$ - Output shape: $(H_{out}, W_{out})$ - Where - $$ - H_{out} = N \\ - W_{out} = (((4^pyramid_height) - 1) / (4 - 1))$ * C_{in} - $$ + detection output for SSD(single shot multibox detector) + )DOC"); } }; diff --git a/paddle/operators/detection_output_op.h b/paddle/operators/detection_output_op.h index d03452ff8d..508e3d6939 100644 --- a/paddle/operators/detection_output_op.h +++ b/paddle/operators/detection_output_op.h @@ -18,10 +18,34 @@ limitations under the License. */ #include "paddle/operators/math/detection_util.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/softmax.h" - +#include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { template +void transpose_fun(const platform::DeviceContext& context, + const framework::Tensor& src, framework::Tensor* dst) { + int input_nums = src.dims()[0]; + int offset = 0; + for (int j = 0; j < input_nums; ++j) { + framework::Tensor in_p_tensor = src.Slice(j, j + 1); + std::vector shape_vec( + {in_p_tensor.dims()[0], in_p_tensor.dims()[1], in_p_tensor.dims()[3], + in_p_tensor.dims()[4], in_p_tensor.dims()[2]}); + framework::DDim shape(framework::make_ddim(shape_vec)); + framework::Tensor in_p_tensor_transpose; + in_p_tensor_transpose.mutable_data(shape, context.GetPlace()); + std::vector shape_axis({0, 1, 3, 4, 2}); + math::Transpose trans5; + trans5(context, in_p_tensor, &in_p_tensor_transpose, shape_axis); + auto dst_stride = framework::stride(dst->dims()); + auto src_stride = framework::stride(in_p_tensor_transpose.dims()); + StridedMemcpy(context, in_p_tensor_transpose.data(), src_stride, + in_p_tensor_transpose.dims(), dst_stride, + dst->data() + offset); + offset += in_p_tensor_transpose.dims()[4] * src_stride[4]; + } +} +template class Detection_output_Kernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -37,77 +61,51 @@ class Detection_output_Kernel : public framework::OpKernel { float nms_threshold = context.template Attr("nms_threshold"); float confidence_threshold = context.template Attr("confidence_threshold"); - - int input_num = in_loc->dims()[0]; - int batch_size = in_loc->dims()[1]; - int channels = in_loc->dims()[2]; - int height = in_loc->dims()[3]; - int weight = in_loc->dims()[4]; - int loc_sum_size = in_loc->numel(); + int batch_size = in_conf->dims()[1]; int conf_sum_size = in_conf->numel(); - std::vector loc_shape_vec({1, loc_sum_size}); - std::vector conf_shape_vec( + // for softmax + std::vector conf_shape_softmax_vec( {conf_sum_size / num_classes, num_classes}); + framework::DDim conf_shape_softmax( + framework::make_ddim(conf_shape_softmax_vec)); + // for knchw => nhwc + std::vector loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3], + in_loc->dims()[4], in_loc->dims()[2]}); + std::vector conf_shape_vec({1, in_conf->dims()[1], + in_conf->dims()[3], in_conf->dims()[4], + in_conf->dims()[2]}); framework::DDim loc_shape(framework::make_ddim(loc_shape_vec)); framework::DDim conf_shape(framework::make_ddim(conf_shape_vec)); framework::Tensor loc_tensor; framework::Tensor conf_tensor; - loc_tensor.Resize(loc_shape); - conf_tensor.Resize(conf_shape); loc_tensor.mutable_data(loc_shape, context.GetPlace()); conf_tensor.mutable_data(conf_shape, context.GetPlace()); + // for cpu framework::Tensor loc_cpu; framework::Tensor conf_cpu; framework::Tensor priorbox_cpu; - const T* in_loc_data = in_loc->data(); - const T* in_conf_data = in_conf->data(); - T* loc_data; - T* conf_data; const T* priorbox_data = in_priorbox->data(); - + transpose_fun(context.device_context(), *in_loc, &loc_tensor); + transpose_fun(context.device_context(), *in_conf, &conf_tensor); + conf_tensor.Resize(conf_shape_softmax); + math::SoftmaxFunctor()(context.device_context(), &conf_tensor, + &conf_tensor); + T* loc_data = loc_tensor.data(); + T* conf_data = conf_tensor.data(); if (platform::is_gpu_place(context.GetPlace())) { - loc_cpu.mutable_data(in_loc->dims(), platform::CPUPlace()); - framework::CopyFrom(*in_loc, platform::CPUPlace(), + loc_cpu.mutable_data(loc_tensor.dims(), platform::CPUPlace()); + framework::CopyFrom(loc_tensor, platform::CPUPlace(), context.device_context(), &loc_cpu); - in_loc_data = loc_cpu.data(); - conf_cpu.mutable_data(in_conf->dims(), platform::CPUPlace()); - framework::CopyFrom(*in_conf, platform::CPUPlace(), + loc_data = loc_cpu.data(); + conf_cpu.mutable_data(conf_tensor.dims(), platform::CPUPlace()); + framework::CopyFrom(conf_tensor, platform::CPUPlace(), context.device_context(), &conf_cpu); - in_conf_data = conf_cpu.data(); + conf_data = conf_cpu.data(); priorbox_cpu.mutable_data(in_priorbox->dims(), platform::CPUPlace()); framework::CopyFrom(*in_priorbox, platform::CPUPlace(), context.device_context(), &priorbox_cpu); priorbox_data = priorbox_cpu.data(); - loc_tensor.mutable_data(loc_shape, platform::CPUPlace()); - conf_tensor.mutable_data(conf_shape, platform::CPUPlace()); - } - T* loc_tensor_data = loc_tensor.data(); - T* conf_tensor_data = conf_tensor.data(); - for (int i = 0; i < input_num; ++i) { - math::appendWithPermute(in_loc_data, input_num, batch_size, channels, - height, weight, loc_tensor_data); - math::appendWithPermute(in_conf_data, input_num, batch_size, channels, - height, weight, conf_tensor_data); - } - loc_data = loc_tensor.data(); - if (platform::is_gpu_place(context.GetPlace())) { - framework::Tensor conf_gpu; - conf_gpu.Resize(conf_shape); - conf_gpu.mutable_data(conf_shape, context.GetPlace()); - framework::CopyFrom(conf_tensor, platform::GPUPlace(), - context.device_context(), &conf_gpu); - // softmax - math::SoftmaxFunctor()(context.device_context(), &conf_gpu, - &conf_gpu); - conf_tensor.mutable_data(conf_gpu.dims(), platform::CPUPlace()); - framework::CopyFrom(conf_gpu, platform::CPUPlace(), - context.device_context(), &conf_tensor); - } else { - // softmax - math::SoftmaxFunctor()(context.device_context(), &conf_tensor, - &conf_tensor); } - conf_data = conf_tensor.data(); // get decode bboxes size_t num_priors = in_priorbox->numel() / 8; std::vector>> all_decoded_bboxes; diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h index 12d9ca9da8..b671f7b517 100644 --- a/paddle/operators/math/detection_util.h +++ b/paddle/operators/math/detection_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" namespace paddle { namespace operators { namespace math { - template struct BBox { BBox(T x_min, T y_min, T x_max, T y_max) @@ -49,31 +49,47 @@ struct BBox { bool is_difficult; }; // KNCHW ==> NHWC +// template template -int appendWithPermute(const T* input_data, int input_nums, int batch_size, - int channels, int height, int weight, T* output_data) { - int image_size = height * weight; - int numel = input_nums * batch_size * channels * height * weight; - int offset = 0; - for (int p = 0; p < input_nums; ++p) { - int in_p_offset = p * batch_size * channels * image_size; - for (int n = 0; n < batch_size; ++n) { - int in_n_offset = n * channels * image_size; - int out_n_offset = n * numel / batch_size + offset; - int in_stride = image_size; - int out_stride = channels; - const T* in_data = input_data + in_p_offset + in_n_offset; - T* out_data = output_data + out_n_offset; - for (int c = 0; c < channels; ++c) { - for (int i = 0; i < image_size; ++i) { - out_data[out_stride * i + c] = in_data[c * in_stride + i]; - } - } - } - offset += image_size * channels; - } - return 0; -} +void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes, + std::vector>& bbox_vec); +template +void getBBoxVarFromPriorData(const T* prior_data, const size_t num, + std::vector>& var_vec); +template +BBox decodeBBoxWithVar(BBox& prior_bbox, + const std::vector& prior_bbox_var, + const std::vector& loc_pred_data); +template +bool sortScorePairDescend(const std::pair& pair1, + const std::pair& pair2); +template +bool sortScorePairDescend(const std::pair>& pair1, + const std::pair>& pair2); +template +T jaccardOverlap(const BBox& bbox1, const BBox& bbox2); + +template +void applyNMSFast(const std::vector>& bboxes, const T* conf_score_data, + size_t class_idx, size_t top_k, T conf_threshold, + T nms_threshold, size_t num_priors, size_t num_classes, + std::vector* indices); +template +int getDetectionIndices( + const T* conf_data, const size_t num_priors, const size_t num_classes, + const size_t background_label_id, const size_t batch_size, + const T conf_threshold, const size_t nms_top_k, const T nms_threshold, + const size_t top_k, + const std::vector>>& all_decoded_bboxes, + std::vector>>* all_detection_indices); +template +BBox clipBBox(const BBox& bbox); +template +void getDetectionOutput( + const T* conf_data, const size_t num_kept, const size_t num_priors, + const size_t num_classes, const size_t batch_size, + const std::vector>>& all_indices, + const std::vector>>& all_decoded_bboxes, T* out_data); template void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes, std::vector>& bbox_vec) { @@ -136,9 +152,6 @@ bool sortScorePairDescend(const std::pair& pair1, return pair1.first > pair2.first; } template -bool sortScorePairDescend(const std::pair>& pair1, - const std::pair>& pair2); -template T jaccardOverlap(const BBox& bbox1, const BBox& bbox2) { if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { @@ -281,7 +294,6 @@ void getDetectionOutput( } } } - // out.copyFrom(out_data, num_kept * 7); } } // namespace math } // namespace operators diff --git a/python/paddle/v2/fluid/tests/test_detection_output_op.py b/python/paddle/v2/fluid/tests/test_detection_output_op.py index 56cd5dde9f..080a9743b0 100644 --- a/python/paddle/v2/fluid/tests/test_detection_output_op.py +++ b/python/paddle/v2/fluid/tests/test_detection_output_op.py @@ -8,22 +8,24 @@ class TestUnpoolOp(OpTest): self.op_type = "detection_output" self.init_test_case() - #loc = np.zeros((1, 4, 4, 1, 1)) - #conf = np.zero((1, 4, 2, 1, 1)) + #loc.shape ((1, 4, 4, 1, 1)) + #conf.shape ((1, 4, 2, 1, 1)) loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]], [[[0.1]], [[0.1]], [[0.1]], [[0.1]]], [[[0.1]], [[0.1]], [[0.1]], [[0.1]]], [[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]]) - conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]]], - [[[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]]) - priorbox = np.array([0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2,\ - 0.2, 0.2, 0.6, 0.6, 0.1, 0.1, 0.2, 0.2,\ - 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2,\ - 0.4, 0.4, 0.8, 0.8, 0.1, 0.1, 0.2, 0.2]) - - output = np.array([0, 1, 0.68997443, 0.099959746, 0.099959746,\ - 0.50804031, 0.50804031]) + conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]], + [[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]]) + priorbox = np.array([ + 0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1, + 0.1, 0.2, 0.2, 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2, 0.4, 0.4, + 0.8, 0.8, 0.1, 0.1, 0.2, 0.2 + ]) + + output = np.array([ + 0, 1, 0.68997443, 0.099959746, 0.099959746, 0.50804031, 0.50804031 + ]) self.inputs = { 'Loc': loc.astype('float32'), 'Conf': conf.astype('float32'), -- GitLab