From 4dc4d5fc45efd20d62ca9aebc94343ee7c5f8f30 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 20 Oct 2022 19:54:01 +0800 Subject: [PATCH] [MKLDNN] Delete mkldnn hard code of fc (#47138) * remove fc mkldnn hardcode * remove useless enum of kFCMKLDNN * fix macro error * update operators.cmake --- cmake/operators.cmake | 7 -- paddle/fluid/operators/fc_op.cc | 12 --- paddle/fluid/operators/fc_op.h | 2 - paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 74 ++++++++----------- paddle/fluid/platform/mkldnn_op_list.h | 1 - 5 files changed, 32 insertions(+), 64 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 650bcc40259..0cd21c942e0 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -511,13 +511,6 @@ function(op_library TARGET) # Append first implemented MKLDNN activation operator if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") - elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n") else() foreach(mkldnn_src ${mkldnn_cc_srcs}) set(op_name "") diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 7c8b7a6544f..d4d160d315d 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -128,18 +128,6 @@ class FCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - if (ctx.Attr("use_mkldnn")) { - using framework::proto::VarType; - int customized_type_value = (input_data_type == VarType::INT8 || - input_data_type == VarType::UINT8) - ? kFCMKLDNNINT8 - : kFCMKLDNNFP32; - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - phi::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); - } return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h index 73f9559107b..433288e885d 100644 --- a/paddle/fluid/operators/fc_op.h +++ b/paddle/fluid/operators/fc_op.h @@ -22,8 +22,6 @@ limitations under the License. */ namespace paddle { namespace operators { -enum { kFCMKLDNNFP32 = 1, kFCMKLDNNINT8 = 2 }; - using Tensor = phi::DenseTensor; inline void FCOutputSize(const framework::DDim& in_dims, diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 12d2bdef791..a831c64aa8f 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -320,27 +320,38 @@ class FCMKLDNNHandler } // namespace operators }; // namespace paddle -template +#define IF_CHANGE_FC_TW_TYPENAME(condition, ...) \ + if (condition) { \ + using T_w = int8_t; \ + __VA_ARGS__(); \ + } else { \ + using T_w = T_in; \ + __VA_ARGS__(); \ + } + +template class FCMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { bool force_fp32_output = ctx.Attr("force_fp32_output"); bool fuse_relu = ctx.Attr("activation_type") == "relu"; - if (force_fp32_output) { - this->RunKernel(ctx); - } else if (IsInt8()) { - if (fuse_relu) { - this->RunKernel(ctx); - } else { - this->RunKernel(ctx); - } - } else { - this->RunKernel(ctx); - } + IF_CHANGE_FC_TW_TYPENAME((std::is_same::value), ([&] { + if (force_fp32_output) { + this->RunKernel(ctx); + } else if (IsInt8()) { + if (fuse_relu) { + this->RunKernel(ctx); + } else { + this->RunKernel(ctx); + } + } else { + this->RunKernel(ctx); + } + })); } - template + template void RunKernel(const framework::ExecutionContext& ctx) const { const auto& dev_ctx = ctx.template device_context(); @@ -422,32 +433,11 @@ class FCMKLDNNKernel : public framework::OpKernel { // data type implies their destination data type. (What's eventually going to // be used during computations of kernel). namespace ops = paddle::operators; -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kFCMKLDNNFP32, - ops::FCMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - fc, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kFCMKLDNNFP32, - ops::FCMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - ops::kFCMKLDNNINT8, - ops::FCMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - ops::kFCMKLDNNINT8, - ops::FCMKLDNNKernel); + +REGISTER_OP_KERNEL(fc, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::FCMKLDNNKernel, + ops::FCMKLDNNKernel, + ops::FCMKLDNNKernel, + ops::FCMKLDNNKernel); diff --git a/paddle/fluid/platform/mkldnn_op_list.h b/paddle/fluid/platform/mkldnn_op_list.h index b9fab4a699c..499060fb1fb 100644 --- a/paddle/fluid/platform/mkldnn_op_list.h +++ b/paddle/fluid/platform/mkldnn_op_list.h @@ -70,7 +70,6 @@ static const std::unordered_set mkldnn_white_list = { // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we // need to analysis and solve them one-by-one. "prior_box", - "fc", "mul", "mul_grad"}; -- GitLab