diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index a2b04121b7ae8bfba419f9e96b50ba994c4fc3d9..0dcfe4d61cbb67d13b44fd4cedd85143d952c19e 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -37,9 +37,6 @@ using dnnl::memory; using dnnl::prop_kind; using dnnl::stream; -constexpr int kMULMKLDNNINT8 = 1; -constexpr int kMULMKLDNNFP32 = 2; - template class MulPrimitiveFactory { public: @@ -340,6 +337,7 @@ std::shared_ptr> GetPrimitiveFactory( return prim_creator; } +/* XT: input x data type, YT: input y data type */ template inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx, @@ -363,8 +361,8 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx, } } -/* XT: input x data type, YT: input y data type */ -template +/* XT: input x data type */ +template class MulMKLDNNINT8Kernel : public framework::OpKernel { public: void Compute(const ExecutionContext &ctx) const override { @@ -381,7 +379,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel { Tensor *out = ctx.Output("Out"); auto out_dims = out->dims(); - auto mul = GetMulPrimitive(dev_ctx, ctx, x, y, out, mkldnn_engine); + auto mul = + GetMulPrimitive(dev_ctx, ctx, x, y, out, mkldnn_engine); if (out_dims.size() != 2) { out->Resize(out_dims); @@ -393,7 +392,7 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel { } }; -template +template class MulMKLDNNKernel : public framework::OpKernel { public: void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } @@ -411,7 +410,7 @@ class MulMKLDNNKernel : public framework::OpKernel { bool trans_y, Tensor *out) const { static const std::vector vec_placeholder; - MatMulV2MKLDNNHandler handler(ctx, + MatMulV2MKLDNNHandler handler(ctx, onednn_engine, ctx.GetPlace(), x_dims, @@ -487,13 +486,12 @@ class MulMKLDNNKernel : public framework::OpKernel { } }; -template -class MulGradMKLDNNKernel : public MulMKLDNNKernel { +template +class MulGradMKLDNNKernel : public MulMKLDNNKernel { public: void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } private: - template void RunKernel(const ExecutionContext &ctx) const { const auto &dev_ctx = ctx.template device_context(); const auto &onednn_engine = dev_ctx.GetEngine(); @@ -569,57 +567,17 @@ class MulGradMKLDNNKernel : public MulMKLDNNKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - ops::kMULMKLDNNINT8, - ops::MulMKLDNNINT8Kernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - ops::kMULMKLDNNINT8, - ops::MulMKLDNNINT8Kernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kMULMKLDNNFP32, - ops::MulMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - mul, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kMULMKLDNNFP32, - ops::MulMKLDNNKernel); REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace, - ops::MulMKLDNNINT8Kernel, - ops::MulMKLDNNKernel, - ops::MulMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kMULMKLDNNFP32, - ops::MulGradMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - mul_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kMULMKLDNNFP32, - ops::MulGradMKLDNNKernel, - ops::MulGradMKLDNNKernel); + ops::MulMKLDNNINT8Kernel, + ops::MulMKLDNNINT8Kernel, + ops::MulMKLDNNKernel, + ops::MulMKLDNNKernel); + +REGISTER_OP_KERNEL(mul_grad, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::MulGradMKLDNNKernel, + ops::MulGradMKLDNNKernel); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 85be9df180a34275c1701d575c42c5cc8da88416..1c0cee0de62d48ae93422f31bb6f514fc74236dc 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -31,9 +31,6 @@ namespace operators { using framework::OpKernelType; -constexpr int kMULMKLDNNINT8 = 1; -constexpr int kMULMKLDNNFP32 = 2; - class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -41,29 +38,6 @@ class MulOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - -#ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - if (input_data_type == framework::DataTypeTrait::DataType() || - input_data_type == framework::DataTypeTrait::DataType()) { - customized_type_value = kMULMKLDNNINT8; - } else if (input_data_type == - framework::DataTypeTrait< - paddle::platform::bfloat16>::DataType() || - input_data_type == - framework::DataTypeTrait::DataType()) { - customized_type_value = kMULMKLDNNFP32; - } - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); - } -#endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -136,29 +110,6 @@ class MulGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - -#ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - if (input_data_type == framework::DataTypeTrait::DataType() || - input_data_type == framework::DataTypeTrait::DataType()) { - customized_type_value = kMULMKLDNNINT8; - } else if (input_data_type == - framework::DataTypeTrait< - paddle::platform::bfloat16>::DataType() || - input_data_type == - framework::DataTypeTrait::DataType()) { - customized_type_value = kMULMKLDNNFP32; - } - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); - } -#endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/platform/mkldnn_op_list.h b/paddle/fluid/platform/mkldnn_op_list.h index 499060fb1fbeae18a9193fa5cd42b8a4ecc28f8f..35046bcd9c1911160154896495fbf9d533295e4a 100644 --- a/paddle/fluid/platform/mkldnn_op_list.h +++ b/paddle/fluid/platform/mkldnn_op_list.h @@ -69,9 +69,7 @@ static const std::unordered_set mkldnn_white_list = { "reduce_sum_grad", // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we // need to analysis and solve them one-by-one. - "prior_box", - "mul", - "mul_grad"}; + "prior_box"}; inline bool in_mkldnn_white_list(const std::string& op_name) { return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();