diff --git a/paddle/operators/detection_output_op.cc b/paddle/operators/detection_output_op.cc
index c018795fd4bfa9ad5644a2b8fac1e4b80aa6c461..a04d6e57583376c82a5a8f445275e3d33a9128cf 100644
--- a/paddle/operators/detection_output_op.cc
+++ b/paddle/operators/detection_output_op.cc
@@ -65,17 +65,18 @@ class Detection_output_Op : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
+    PADDLE_ENFORCE(ctx->HasInput("Loc"),
+                   "Input(Loc) of Detection_output_Op"
+                   " should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Conf"),
+                   "Input(Conf) of Detection_output_Op"
+                   " should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
                    "Input(X) of Detection_output_Op"
                    "should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of Detection_output_Op should not be null.");
-    auto in_x_dims = ctx->GetInputDim("X");
-    int pyramid_height = ctx->Attrs().Get<int>("pyramid_height");
-    PADDLE_ENFORCE(in_x_dims.size() == 4,
-                   "Detection_output_ing intput must be of 4-dimensional.");
-    int outlen = ((std::pow(4, pyramid_height) - 1) / (4 - 1)) * in_x_dims[1];
-    std::vector<int64_t> output_shape({in_x_dims[0], outlen});
+    std::vector<int64_t> output_shape({1, 7});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   }
 };
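Note on the InferShape change above: the first dimension of Out is only known at run time (it is the number of detections kept after NMS), so {1, 7} is a compile-time placeholder and the kernel later allocates Out with the real {num_kept, 7} shape. The 7 columns appear to follow the usual SSD detection layout, [image_id, label, confidence, xmin, ymin, xmax, ymax]; the unit test at the end of this patch is consistent with that reading. A minimal sketch of unpacking one output row under that assumption (unpack_detection_row is a hypothetical helper, not part of the patch):

def unpack_detection_row(row):
    # Assumed column order: image_id, label, confidence, xmin, ymin, xmax, ymax.
    image_id, label = int(row[0]), int(row[1])
    confidence = float(row[2])
    box = tuple(float(v) for v in row[3:7])
    return image_id, label, confidence, box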
diff --git a/paddle/operators/detection_output_op.h b/paddle/operators/detection_output_op.h
index 184b864974d514c201af18f52a38a8311edbb539..d03452ff8d069220a93bc35dcef464d727572e4e 100644
--- a/paddle/operators/detection_output_op.h
+++ b/paddle/operators/detection_output_op.h
@@ -40,6 +40,9 @@ class Detection_output_Kernel : public framework::OpKernel<T> {
     int input_num = in_loc->dims()[0];
     int batch_size = in_loc->dims()[1];
+    int channels = in_loc->dims()[2];
+    int height = in_loc->dims()[3];
+    int weight = in_loc->dims()[4];
     int loc_sum_size = in_loc->numel();
     int conf_sum_size = in_conf->numel();
     std::vector<int64_t> loc_shape_vec({1, loc_sum_size});
@@ -49,17 +52,62 @@
     framework::DDim conf_shape(framework::make_ddim(conf_shape_vec));
     framework::Tensor loc_tensor;
     framework::Tensor conf_tensor;
+    loc_tensor.Resize(loc_shape);
+    conf_tensor.Resize(conf_shape);
     loc_tensor.mutable_data<T>(loc_shape, context.GetPlace());
     conf_tensor.mutable_data<T>(conf_shape, context.GetPlace());
+    framework::Tensor loc_cpu;
+    framework::Tensor conf_cpu;
+    framework::Tensor priorbox_cpu;
+    const T* in_loc_data = in_loc->data<T>();
+    const T* in_conf_data = in_conf->data<T>();
+    T* loc_data;
+    T* conf_data;
+    const T* priorbox_data = in_priorbox->data<T>();
-    // KNCHW ==> NHWC
+    if (platform::is_gpu_place(context.GetPlace())) {
+      loc_cpu.mutable_data<T>(in_loc->dims(), platform::CPUPlace());
+      framework::CopyFrom(*in_loc, platform::CPUPlace(),
+                          context.device_context(), &loc_cpu);
+      in_loc_data = loc_cpu.data<T>();
+      conf_cpu.mutable_data<T>(in_conf->dims(), platform::CPUPlace());
+      framework::CopyFrom(*in_conf, platform::CPUPlace(),
+                          context.device_context(), &conf_cpu);
+      in_conf_data = conf_cpu.data<T>();
+      priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
+      framework::CopyFrom(*in_priorbox, platform::CPUPlace(),
+                          context.device_context(), &priorbox_cpu);
+      priorbox_data = priorbox_cpu.data<T>();
+      loc_tensor.mutable_data<T>(loc_shape, platform::CPUPlace());
+      conf_tensor.mutable_data<T>(conf_shape, platform::CPUPlace());
+    }
+    T* loc_tensor_data = loc_tensor.data<T>();
+    T* conf_tensor_data = conf_tensor.data<T>();
     for (int i = 0; i < input_num; ++i) {
-      math::appendWithPermute<T>(*in_loc, &loc_tensor);
-      math::appendWithPermute<T>(*in_conf, &conf_tensor);
+      math::appendWithPermute(in_loc_data, input_num, batch_size, channels,
+                              height, weight, loc_tensor_data);
+      math::appendWithPermute(in_conf_data, input_num, batch_size, channels,
+                              height, weight, conf_tensor_data);
+    }
+    loc_data = loc_tensor.data<T>();
+    if (platform::is_gpu_place(context.GetPlace())) {
+      framework::Tensor conf_gpu;
+      conf_gpu.Resize(conf_shape);
+      conf_gpu.mutable_data<T>(conf_shape, context.GetPlace());
+      framework::CopyFrom(conf_tensor, platform::GPUPlace(),
+                          context.device_context(), &conf_gpu);
+      // softmax
+      math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_gpu,
+                                       &conf_gpu);
+      conf_tensor.mutable_data<T>(conf_gpu.dims(), platform::CPUPlace());
+      framework::CopyFrom(conf_gpu, platform::CPUPlace(),
+                          context.device_context(), &conf_tensor);
+    } else {
+      // softmax
+      math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_tensor,
+                                       &conf_tensor);
     }
-    // softmax
-    math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_tensor,
-                                     &conf_tensor);
+    conf_data = conf_tensor.data<T>();
     // get decode bboxes
     size_t num_priors = in_priorbox->numel() / 8;
     std::vector<std::vector<math::BBox<T>>> all_decoded_bboxes;
@@ -69,29 +117,26 @@ class Detection_output_Kernel : public framework::OpKernel<T> {
         size_t prior_offset = i * 8;
         size_t loc_pred_offset = n * num_priors * 4 + i * 4;
         std::vector<math::BBox<T>> prior_bbox_vec;
-        math::getBBoxFromPriorData(in_priorbox->data<T>() + prior_offset, 1,
+        math::getBBoxFromPriorData(priorbox_data + prior_offset, 1,
                                    prior_bbox_vec);
         std::vector<std::vector<T>> prior_bbox_var;
-        math::getBBoxVarFromPriorData(in_priorbox->data<T>() + prior_offset,
-                                      1, prior_bbox_var);
+        math::getBBoxVarFromPriorData(priorbox_data + prior_offset, 1,
+                                      prior_bbox_var);
        std::vector<T> loc_pred_data;
        for (size_t j = 0; j < 4; ++j)
-          loc_pred_data.push_back(
-              *(loc_tensor.data<T>() + loc_pred_offset + j));
+          loc_pred_data.push_back(*(loc_data + loc_pred_offset + j));
         math::BBox<T> bbox = math::decodeBBoxWithVar(
             prior_bbox_vec[0], prior_bbox_var[0], loc_pred_data);
         decoded_bboxes.push_back(bbox);
       }
       all_decoded_bboxes.push_back(decoded_bboxes);
     }
-
     std::vector<std::map<int, std::vector<int>>> all_indices;
     int num_kept = math::getDetectionIndices(
-        conf_tensor.data<T>(), num_priors, num_classes, background_label_id,
-        batch_size, confidence_threshold, nms_top_k, nms_threshold, top_k,
+        conf_data, num_priors, num_classes, background_label_id, batch_size,
+        confidence_threshold, nms_top_k, nms_threshold, top_k,
         all_decoded_bboxes, &all_indices);
-    framework::Tensor out_tmp;
     if (num_kept <= 0) {
       std::vector<int64_t> out_shape_vec({0, 0});
       framework::DDim out_shape(framework::make_ddim(out_shape_vec));
@@ -100,14 +145,20 @@ class Detection_output_Kernel : public framework::OpKernel<T> {
     }
     std::vector<int64_t> out_shape_vec({num_kept, 7});
     framework::DDim out_shape(framework::make_ddim(out_shape_vec));
-    out_tmp.mutable_data<T>(out_shape, context.GetPlace());
-
-    T* out_data = out_tmp.data<T>();
-    math::getDetectionOutput(conf_tensor.data<T>(), num_kept, num_priors,
-                             num_classes, batch_size, all_indices,
-                             all_decoded_bboxes, out_data);
     out->mutable_data<T>(out_shape, context.GetPlace());
-    out->ShareDataWith(out_tmp);
+    framework::Tensor out_cpu;
+    T* out_data = out->data<T>();
+    if (platform::is_gpu_place(context.GetPlace())) {
+      out_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
+      out_data = out_cpu.data<T>();
+    }
+    math::getDetectionOutput(conf_data, num_kept, num_priors, num_classes,
+                             batch_size, all_indices, all_decoded_bboxes,
+                             out_data);
+    if (platform::is_gpu_place(context.GetPlace())) {
+      framework::CopyFrom(out_cpu, platform::GPUPlace(),
+                          context.device_context(), out);
+    }
   }
 };
 }  // namespace operators
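The kernel above decodes each predicted location against its prior box via math::decodeBBoxWithVar. The arithmetic this corresponds to is the standard SSD center-size decoding scaled by the per-prior variances; the sketch below is an assumption about that helper's behaviour, but it reproduces the expected box in the unit test at the end of this patch (prior (0.1, 0.1, 0.5, 0.5), variance (0.1, 0.1, 0.2, 0.2) and offsets (0.1, 0.1, 0.1, 0.1) give roughly (0.09996, 0.09996, 0.50804, 0.50804)):

import math

def decode_bbox_with_var(prior, var, loc):
    # Sketch of the assumed behaviour of math::decodeBBoxWithVar.
    # prior = (xmin, ymin, xmax, ymax); var and loc are 4-element vectors.
    prior_w = prior[2] - prior[0]
    prior_h = prior[3] - prior[1]
    prior_cx = (prior[0] + prior[2]) / 2.0
    prior_cy = (prior[1] + prior[3]) / 2.0
    # Center-size decoding, with the regression targets scaled by the variances.
    cx = var[0] * loc[0] * prior_w + prior_cx
    cy = var[1] * loc[1] * prior_h + prior_cy
    w = math.exp(var[2] * loc[2]) * prior_w
    h = math.exp(var[3] * loc[3]) * prior_h
    return (cx - w / 2.0, cy - h / 2.0, cx + w / 2.0, cy + h / 2.0)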
diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h
index 265fa077010fedcd6ea50430f2bbb62fe55be42c..12d9ca9da87f24ad1edcaf1c03245feac8899337 100644
--- a/paddle/operators/math/detection_util.h
+++ b/paddle/operators/math/detection_util.h
@@ -50,27 +50,23 @@ struct BBox {
 };
 // KNCHW ==> NHWC
 template <typename T>
-int appendWithPermute(const framework::Tensor& input,
-                      framework::Tensor* output) {
-  const int input_nums = input.dims()[0];
-  const int batch_size = input.dims()[1];
-  const int channels = input.dims()[2];
-  const int height = input.dims()[3];
-  const int weight = input.dims()[4];
+int appendWithPermute(const T* input_data, int input_nums, int batch_size,
+                      int channels, int height, int weight, T* output_data) {
   int image_size = height * weight;
+  int numel = input_nums * batch_size * channels * height * weight;
   int offset = 0;
   for (int p = 0; p < input_nums; ++p) {
     int in_p_offset = p * batch_size * channels * image_size;
     for (int n = 0; n < batch_size; ++n) {
       int in_n_offset = n * channels * image_size;
-      int out_n_offset = n * input.numel() / batch_size + offset;
+      int out_n_offset = n * numel / batch_size + offset;
       int in_stride = image_size;
       int out_stride = channels;
-      const T* in_data = input.data<T>() + in_p_offset + in_n_offset;
-      T* out_data = output->data<T>() + out_n_offset;
-      for (int i = 0; i < channels; ++i) {
-        for (int c = 0; c < image_size; ++c) {
-          out_data[out_stride * c + i] = in_data[i * in_stride + c];
+      const T* in_data = input_data + in_p_offset + in_n_offset;
+      T* out_data = output_data + out_n_offset;
+      for (int c = 0; c < channels; ++c) {
+        for (int i = 0; i < image_size; ++i) {
+          out_data[out_stride * i + c] = in_data[c * in_stride + i];
        }
      }
    }
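appendWithPermute now takes raw pointers plus explicit dimensions instead of Tensors, but the data movement is unchanged: for every input k and batch item n, the (channels, height * width) block is transposed to (height * width, channels), so per batch item the output is the concatenation over k of HWC-ordered maps. A NumPy sketch of the equivalent reshuffle (this assumes offset is advanced by channels * height * width per input k, which happens outside the hunk shown above):

import numpy as np

def append_with_permute(x):
    # x has shape (K, N, C, H, W); result has shape (N, K * H * W * C).
    # KNCHW -> per input k: NCHW -> NHWC, flatten, then concatenate per batch item.
    k, n, c, h, w = x.shape
    per_k = [x[i].transpose(0, 2, 3, 1).reshape(n, -1) for i in range(k)]
    return np.concatenate(per_k, axis=1)

The kernel then reads this result through the flat (1, numel) loc_tensor and conf_tensor allocated earlier.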
diff --git a/python/paddle/v2/fluid/tests/test_detection_output_op.py b/python/paddle/v2/fluid/tests/test_detection_output_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..56cd5dde9f8cf33de580fae8ebb2422ca0475ae4
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_detection_output_op.py
@@ -0,0 +1,55 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestDetectionOutputOp(OpTest):
+    def setUp(self):
+        self.op_type = "detection_output"
+        self.init_test_case()
+
+        #loc = np.zeros((1, 4, 4, 1, 1))
+        #conf = np.zero((1, 4, 2, 1, 1))
+
+        loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
+                         [[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
+                         [[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
+                         [[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]])
+        conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]]],
+                         [[[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]])
+        priorbox = np.array([0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2,\
+                             0.2, 0.2, 0.6, 0.6, 0.1, 0.1, 0.2, 0.2,\
+                             0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2,\
+                             0.4, 0.4, 0.8, 0.8, 0.1, 0.1, 0.2, 0.2])
+
+        output = np.array([0, 1, 0.68997443, 0.099959746, 0.099959746,\
+                           0.50804031, 0.50804031])
+        self.inputs = {
+            'Loc': loc.astype('float32'),
+            'Conf': conf.astype('float32'),
+            'PriorBox': priorbox.astype('float32')
+        }
+        self.attrs = {
+            'num_classes': self.num_classes,
+            'top_k': self.top_k,
+            'nms_top_k': self.nms_top_k,
+            'background_label_id': self.background_label_id,
+            'nms_threshold': self.nms_threshold,
+            'confidence_threshold': self.confidence_threshold,
+        }
+        self.outputs = {'Out': output.astype('float32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init_test_case(self):
+        self.num_classes = 2
+        self.top_k = 10
+        self.nms_top_k = 20
+        self.background_label_id = 0
+        self.nms_threshold = 0.01
+        self.confidence_threshold = 0.01
+
+
+if __name__ == '__main__':
+    unittest.main()
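The expected values in the test can be reproduced by hand, which is a useful sanity check on the chosen attributes: the reported confidence is the softmax of the class scores (0.1, 0.9) of the top prior, and the box is that prior decoded with its variance (the decoding formula below is the assumed center-size scheme sketched earlier in this patch):

import numpy as np

scores = np.array([0.1, 0.9])
confidence = np.exp(scores)[1] / np.exp(scores).sum()  # ~= 0.68997443

# First prior of the test: corners, variances, and the predicted offsets.
prior = np.array([0.1, 0.1, 0.5, 0.5])
var = np.array([0.1, 0.1, 0.2, 0.2])
loc = np.array([0.1, 0.1, 0.1, 0.1])

# Assumed SSD center-size decoding with prior-box variances (see sketch above).
pw, ph = prior[2] - prior[0], prior[3] - prior[1]
pcx, pcy = (prior[0] + prior[2]) / 2, (prior[1] + prior[3]) / 2
cx = var[0] * loc[0] * pw + pcx
cy = var[1] * loc[1] * ph + pcy
w = np.exp(var[2] * loc[2]) * pw
h = np.exp(var[3] * loc[3]) * ph
box = [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
# box ~= [0.099959746, 0.099959746, 0.50804031, 0.50804031], which together with
# confidence matches the expected row [0, 1, 0.68997443, 0.099959746, 0.099959746,
# 0.50804031, 0.50804031].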