From 4acc52208716da3a51ae2253663c1d1f28f258d6 Mon Sep 17 00:00:00 2001 From: liangan1 Date: Mon, 25 Feb 2019 08:46:02 +0000 Subject: [PATCH] Enable function coverage for U8/S8 ConvMKLDNNOpKernel test=develop --- cmake/operators.cmake | 3 +++ paddle/fluid/framework/op_registry.h | 2 +- paddle/fluid/operators/conv_op.cc | 8 ++++++-- paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc | 4 ++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c2d0482856..4e8c49e62b 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -168,6 +168,9 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") + else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") endif() diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 2c1648c81f..a53a81c270 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -290,7 +290,7 @@ struct OpKernelRegistrarFunctorEx("Input")->type(); std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout = framework::StringToDataLayout(data_format); @@ -94,11 +95,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; - customized_type_value = kConvMKLDNNFP32; + customized_type_value = + (input_data_type == framework::DataTypeTrait::DataType || + input_data_type == framework::DataTypeTrait::DataType) + ? kConvMKLDNNINT8 + : kConvMKLDNNFP32; } #endif - auto input_data_type = ctx.Input("Input")->type(); if (input_data_type != framework::proto::VarType::INT8 && input_data_type != framework::proto::VarType::UINT8) { auto filter_data_type = ctx.Input("Filter")->type(); diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 7ac64e6ba1..7994adb7c8 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -991,12 +991,12 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, U8, - ops::kConvMKLDNNFP32, + ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, S8, - ops::kConvMKLDNNFP32, + ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, -- GitLab