diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 9808bd831d1598d03f40cf659e18c7fc6de1fd56..31110428be54a7a0c9198092cabb25873bdb29f0 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -138,22 +138,11 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, ops::kTransposeMKLDNNINT8, ops::TransposeMKLDNNOpKernel); -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, - ::paddle::platform::CPUPlace, FP32, - ops::kTransposeMKLDNNFP32, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, - ::paddle::platform::CPUPlace, U8, - ops::kTransposeMKLDNNINT8, - ops::TransposeMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose, MKLDNN, - ::paddle::platform::CPUPlace, S8, - ops::kTransposeMKLDNNINT8, - ops::TransposeMKLDNNOpKernel); +REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, + ops::TransposeMKLDNNOpKernel); REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNGradOpKernel); + REGISTER_OP_KERNEL(transpose2_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 9444bb44ad76a314fcf0b873ed2121c751e3d4d8..fdc95c6971eef6a081f9a6d88a893f0ba3f9bd38 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -71,24 +71,16 @@ class TransposeOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; - using framework::proto::VarType; - auto input_data_type = ctx.Input("X")->type(); - customized_type_value = (input_data_type == VarType::INT8 || - input_data_type == VarType::UINT8) - ? kTransposeMKLDNNINT8 - : kTransposeMKLDNNFP32; } #endif return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_, customized_type_value); + layout_, library_); } };