diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c2d04828564e69d7ac965881057f185194aa0475..4e8c49e62b5b04eda45a9a5e8320e7dd06f3225a 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -168,6 +168,9 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") + else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") endif() diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 2c1648c81fc999c6306d5b08bc243f3ad21fec04..a53a81c270aeec1b6ee4ed30e77526f4ea2e7977 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -290,7 +290,7 @@ struct OpKernelRegistrarFunctorEx("Input")->type(); std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout = framework::StringToDataLayout(data_format); @@ -94,11 +95,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; - customized_type_value = kConvMKLDNNFP32; + customized_type_value = + (input_data_type == framework::DataTypeTrait::DataType || + input_data_type == framework::DataTypeTrait::DataType) + ? kConvMKLDNNINT8 + : kConvMKLDNNFP32; } #endif - auto input_data_type = ctx.Input("Input")->type(); if (input_data_type != framework::proto::VarType::INT8 && input_data_type != framework::proto::VarType::UINT8) { auto filter_data_type = ctx.Input("Filter")->type(); diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 7ac64e6ba134c034acc58c7310cd51da0f03d16d..7994adb7c847eff0c1a011b975211ce5e71a459e 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -991,12 +991,12 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, U8, - ops::kConvMKLDNNFP32, + ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, S8, - ops::kConvMKLDNNFP32, + ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,