From 98e0079311029e5dc5972dff459d45604b794c91 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 11 Oct 2022 13:01:57 +0800 Subject: [PATCH] [Opt transpose2] Opt GetExpectedKernelType code of transpose2 (#46692) * solve transpose2, follow #22402 * fix CI cmake * update REGISTER_OP_KERNEL of transpose2 --- cmake/operators.cmake | 7 ---- .../operators/mkldnn/transpose_mkldnn_op.cc | 37 ++++--------------- paddle/fluid/operators/transpose_op.cc | 10 +---- 3 files changed, 9 insertions(+), 45 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index bbf77b6615d..c3c8474b69f 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -518,13 +518,6 @@ function(op_library TARGET) "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") - elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n") elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n") diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 836a71f43b2..0527041eb38 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -172,35 +172,6 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kTransposeMKLDNNFP32, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - ops::kTransposeMKLDNNINT8, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - ops::kTransposeMKLDNNINT8, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - transpose2, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kTransposeMKLDNNFP32, - ops::TransposeMKLDNNOpKernel); - REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, @@ -210,3 +181,11 @@ REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNGradOpKernel); + +REGISTER_OP_KERNEL(transpose2, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::TransposeMKLDNNOpKernel, + ops::TransposeMKLDNNOpKernel, + ops::TransposeMKLDNNOpKernel, + ops::TransposeMKLDNNOpKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 334ce8983c5..df62da0b565 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -249,18 +249,10 @@ class Transpose2Op : public TransposeOp { OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, data_type)) { - using framework::proto::VarType; - auto input_data_type = framework::TransToProtoVarType( - ctx.Input("X")->dtype()); - int customized_type_value = (input_data_type == VarType::INT8 || - input_data_type == VarType::UINT8) - ? kTransposeMKLDNNINT8 - : kTransposeMKLDNNFP32; return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); + framework::LibraryType::kMKLDNN); } #endif std::string data_format = ctx.Attr("data_format"); -- GitLab