From d78dd7ea82fdd9572ad277dd8fcaa91e6b0ee35a Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Wed, 26 Oct 2022 19:07:13 +0800 Subject: [PATCH] [MKLDNN] Delete mkldnn hard code of prior_box (#47068) * remove prior_box mkldnn hard code * add header file * simplify PD_VISIT_TYPE * decouple dependency between prior_box and density_prior_box * fix pragma omp parallel error * bypass #pragma omp_parallel_for error * polish code * remove visit_type headerfile * polish codestyle * polish codestyle * try fix CI error * add testcase, datatype=float64 * reset test_prior_box testcase * add datacheck to DenseTensor * update template name * call prior_box with macro expand --- .../fluid/operators/detection/prior_box_op.cc | 68 ++----------------- .../fluid/operators/detection/prior_box_op.h | 17 +++-- paddle/fluid/platform/mkldnn_op_list.h | 5 +- paddle/phi/core/dense_tensor.cc | 12 ++-- 4 files changed, 28 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 33ad494b29..16a4a35f66 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -34,26 +34,6 @@ class PriorBoxOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { auto input_input_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - -#ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx, input_input_type)) { - auto input_image_type = framework::TransToProtoVarType( - ctx.Input("Image")->dtype()); - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - if (input_image_type == framework::DataTypeTrait::DataType()) { - customized_type_value = kPriorBoxFLOAT; - } else if (input_image_type == - framework::DataTypeTrait::DataType()) { - customized_type_value = kPriorBoxDOUBLE; - } - return framework::OpKernelType(input_input_type, - ctx.GetPlace(), - phi::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); - } -#endif return framework::OpKernelType(input_input_type, ctx.GetPlace()); } }; @@ -209,44 +189,10 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, PriorBoxInferShapeFunctor); -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - FF, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - DD, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - U8F, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - S8F, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - U8D, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - S8D, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); +REGISTER_OP_KERNEL(prior_box, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h index 3adbfda50a..51e8bbbcae 100644 --- a/paddle/fluid/operators/detection/prior_box_op.h +++ b/paddle/fluid/operators/detection/prior_box_op.h @@ -18,14 +18,12 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { -constexpr int kPriorBoxFLOAT = 1; -constexpr int kPriorBoxDOUBLE = 2; - inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, bool flip, std::vector* output_aspect_ratior) { @@ -50,10 +48,19 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, } } -template +template class PriorBoxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto* image = ctx.Input("Image"); + + PD_VISIT_FLOATING_TYPES(image->dtype(), "PriorBoxOpHandler", ([&] { + PriorBoxOpHandler(ctx); + })); + } + + template + void PriorBoxOpHandler(const framework::ExecutionContext& ctx) const { auto* input = ctx.Input("Input"); auto* image = ctx.Input("Image"); auto* boxes = ctx.Output("Boxes"); @@ -200,7 +207,7 @@ class PriorBoxOpKernel : public framework::OpKernel { } vars->Resize(var_dim); } -}; // namespace operators +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/mkldnn_op_list.h b/paddle/fluid/platform/mkldnn_op_list.h index 686a70b1cf..a5faac2cbd 100644 --- a/paddle/fluid/platform/mkldnn_op_list.h +++ b/paddle/fluid/platform/mkldnn_op_list.h @@ -54,10 +54,7 @@ static const std::unordered_set mkldnn_white_list = { "flatten", "flatten_grad", "flatten2", - "flatten2_grad", - // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we - // need to analysis and solve them one-by-one. - "prior_box"}; + "flatten2_grad"}; inline bool in_mkldnn_white_list(const std::string& op_name) { return mkldnn_white_list.find(op_name) != mkldnn_white_list.end(); diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 6c9291f816..02f0fbb895 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -139,8 +139,10 @@ const T* DenseTensor::data() const { dtype(), paddle::experimental::CppTypeToDataType::Type(), phi::errors::InvalidArgument( - "The type of data we are trying to retrieve does not match the " - "type of data currently contained in the container.")); + "The type of data we are trying to retrieve (%s) does not match the " + "type of data (%s) currently contained in the container.", + paddle::experimental::CppTypeToDataType::Type(), + dtype())); return static_cast(data()); } @@ -150,8 +152,10 @@ T* DenseTensor::data() { PADDLE_ENFORCE( (dtype() == paddle::experimental::CppTypeToDataType::Type()), phi::errors::InvalidArgument( - "The type of data we are trying to retrieve does not match the " - "type of data currently contained in the container.")); + "The type of data we are trying to retrieve (%s) does not match the " + "type of data (%s) currently contained in the container.", + paddle::experimental::CppTypeToDataType::Type(), + dtype())); return ret; } -- GitLab