提交 4acc5220 编写于 作者: L liangan1

Enable function coverage for U8/S8 ConvMKLDNNOpKernel

test=develop
上级 3ccd8964
...@@ -168,6 +168,9 @@ function(op_library TARGET) ...@@ -168,6 +168,9 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif() endif()
......
...@@ -290,7 +290,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -290,7 +290,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
"USE_OP_DEVICE_KERNEL must be in global namespace"); \ "USE_OP_DEVICE_KERNEL must be in global namespace"); \
extern int \ extern int \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \ TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ = /* NOLINT */ \ UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name() TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \ #define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
......
...@@ -81,6 +81,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -81,6 +81,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
framework::OpKernelType::kDefaultCustomizedTypeValue; framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto input_data_type = ctx.Input<Tensor>("Input")->type();
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout = framework::StringToDataLayout(data_format); framework::DataLayout layout = framework::StringToDataLayout(data_format);
...@@ -94,11 +95,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -94,11 +95,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32; customized_type_value =
(input_data_type == framework::DataTypeTrait<int8_t>::DataType ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType)
? kConvMKLDNNINT8
: kConvMKLDNNFP32;
} }
#endif #endif
auto input_data_type = ctx.Input<Tensor>("Input")->type();
if (input_data_type != framework::proto::VarType::INT8 && if (input_data_type != framework::proto::VarType::INT8 &&
input_data_type != framework::proto::VarType::UINT8) { input_data_type != framework::proto::VarType::UINT8) {
auto filter_data_type = ctx.Input<Tensor>("Filter")->type(); auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
......
...@@ -991,12 +991,12 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ...@@ -991,12 +991,12 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, U8, ::paddle::platform::CPUPlace, U8,
ops::kConvMKLDNNFP32, ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<uint8_t, float>); ops::ConvMKLDNNOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, S8, ::paddle::platform::CPUPlace, S8,
ops::kConvMKLDNNFP32, ops::kConvMKLDNNINT8,
ops::ConvMKLDNNOpKernel<int8_t, float>); ops::ConvMKLDNNOpKernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册