未验证 提交 eded6013 编写于 作者: C Chen Weihang 提交者: GitHub

Simplify conv_mkldnn op registration (#46907)

* simplify conv_mkldnn op registration

* remove custom type value in conv grad op
上级 2010bdc3
......@@ -511,13 +511,6 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
......
......@@ -229,31 +229,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
(input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
framework::DataTypeTrait<int8_t>::DataType()
? kConvMKLDNNINT8WS8
: kConvMKLDNNINT8
: kConvMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
framework::LibraryType::kMKLDNN);
}
#endif
// #ifndef PADDLE_WITH_ASCEND_CL
// if (input_data_type == framework::proto::VarType::FP16) {
// PADDLE_ENFORCE_EQ(
// library, framework::LibraryType::kCUDNN,
// platform::errors::InvalidArgument(
// "float16 can only be used when CUDNN or NPU is used"));
// }
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -517,8 +499,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
kConvMKLDNNFP32);
framework::LibraryType::kMKLDNN);
}
#endif
......
......@@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(conv2d);
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);
namespace paddle {
namespace operators {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册