diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 33ad494b2918f324844bd4b3f7ce3e2f433c3f24..16a4a35f6698d46054b2edec82ebac6799ec4892 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -34,26 +34,6 @@ class PriorBoxOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { auto input_input_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - -#ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx, input_input_type)) { - auto input_image_type = framework::TransToProtoVarType( - ctx.Input("Image")->dtype()); - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - if (input_image_type == framework::DataTypeTrait::DataType()) { - customized_type_value = kPriorBoxFLOAT; - } else if (input_image_type == - framework::DataTypeTrait::DataType()) { - customized_type_value = kPriorBoxDOUBLE; - } - return framework::OpKernelType(input_input_type, - ctx.GetPlace(), - phi::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); - } -#endif return framework::OpKernelType(input_input_type, ctx.GetPlace()); } }; @@ -209,44 +189,10 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, PriorBoxInferShapeFunctor); -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - FF, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - DD, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - U8F, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - S8F, - ops::kPriorBoxFLOAT, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - U8D, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(prior_box, - MKLDNN, - ::paddle::platform::CPUPlace, - S8D, - ops::kPriorBoxDOUBLE, - ops::PriorBoxOpKernel); +REGISTER_OP_KERNEL(prior_box, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h index 3adbfda50a779def299d7b06b4980f94b9b0254a..51e8bbbcae530c13e9ea7a77042166f4897ffa27 100644 --- a/paddle/fluid/operators/detection/prior_box_op.h +++ b/paddle/fluid/operators/detection/prior_box_op.h @@ -18,14 +18,12 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { -constexpr int kPriorBoxFLOAT = 1; -constexpr int kPriorBoxDOUBLE = 2; - inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, bool flip, std::vector* output_aspect_ratior) { @@ -50,10 +48,19 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, } } -template +template class PriorBoxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto* image = ctx.Input("Image"); + + PD_VISIT_FLOATING_TYPES(image->dtype(), "PriorBoxOpHandler", ([&] { + PriorBoxOpHandler(ctx); + })); + } + + template + void PriorBoxOpHandler(const framework::ExecutionContext& ctx) const { auto* input = ctx.Input("Input"); auto* image = ctx.Input("Image"); auto* boxes = ctx.Output("Boxes"); @@ -200,7 +207,7 @@ class PriorBoxOpKernel : public framework::OpKernel { } vars->Resize(var_dim); } -}; // namespace operators +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/mkldnn_op_list.h b/paddle/fluid/platform/mkldnn_op_list.h index 686a70b1cf5f9739fbb4d3a47cb0acb99d0bbd8f..a5faac2cbd53f94f80a6413524467051a1bcb077 100644 --- a/paddle/fluid/platform/mkldnn_op_list.h +++ b/paddle/fluid/platform/mkldnn_op_list.h @@ -54,10 +54,7 @@ static const std::unordered_set mkldnn_white_list = { "flatten", "flatten_grad", "flatten2", - "flatten2_grad", - // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we - // need to analysis and solve them one-by-one. - "prior_box"}; + "flatten2_grad"}; inline bool in_mkldnn_white_list(const std::string& op_name) { return mkldnn_white_list.find(op_name) != mkldnn_white_list.end(); diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 6c9291f816f7a14afba35a798cf2e5af926a1ab0..02f0fbb895215ba38e914851da9849c49bf64172 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -139,8 +139,10 @@ const T* DenseTensor::data() const { dtype(), paddle::experimental::CppTypeToDataType::Type(), phi::errors::InvalidArgument( - "The type of data we are trying to retrieve does not match the " - "type of data currently contained in the container.")); + "The type of data we are trying to retrieve (%s) does not match the " + "type of data (%s) currently contained in the container.", + paddle::experimental::CppTypeToDataType::Type(), + dtype())); return static_cast(data()); } @@ -150,8 +152,10 @@ T* DenseTensor::data() { PADDLE_ENFORCE( (dtype() == paddle::experimental::CppTypeToDataType::Type()), phi::errors::InvalidArgument( - "The type of data we are trying to retrieve does not match the " - "type of data currently contained in the container.")); + "The type of data we are trying to retrieve (%s) does not match the " + "type of data (%s) currently contained in the container.", + paddle::experimental::CppTypeToDataType::Type(), + dtype())); return ret; }