From eded6013197678a3ff726b2b5c6396c0fa45fd3a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 14 Oct 2022 08:10:59 -0500 Subject: [PATCH] Simplify conv_mkldnn op registration (#46907) * simplify conv_mkldnn op registration * remove custom type value in conv grad op --- cmake/operators.cmake | 7 - paddle/fluid/operators/conv_op.cc | 23 +- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 667 +++++++++--------- .../operators/mkldnn/test_mkldnn_caching.cc | 2 +- 4 files changed, 321 insertions(+), 378 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c3c8474b69f..5a1e4e2619c 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -511,13 +511,6 @@ function(op_library TARGET) # Append first implemented MKLDNN activation operator if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") - elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n") - file(APPEND ${pybind_file} - "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n") elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n") diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index ce335cff52d..9720abcae70 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -229,31 +229,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - int customized_type_value = - (input_data_type == framework::DataTypeTrait::DataType() || - input_data_type == framework::DataTypeTrait::DataType()) - ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") == - framework::DataTypeTrait::DataType() - ? kConvMKLDNNINT8WS8 - : kConvMKLDNNINT8 - : kConvMKLDNNFP32; return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - customized_type_value); + framework::LibraryType::kMKLDNN); } #endif - // #ifndef PADDLE_WITH_ASCEND_CL - // if (input_data_type == framework::proto::VarType::FP16) { - // PADDLE_ENFORCE_EQ( - // library, framework::LibraryType::kCUDNN, - // platform::errors::InvalidArgument( - // "float16 can only be used when CUDNN or NPU is used")); - // } - // #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -517,8 +499,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN, - kConvMKLDNNFP32); + framework::LibraryType::kMKLDNN); } #endif diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index cdd064467eb..2f99a4019a4 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -20,6 +20,8 @@ #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/phi/core/visit_type.h" + namespace paddle { namespace operators { namespace { @@ -774,7 +776,22 @@ class ConvMKLDNNHandlerT } // anonymous namespace -template +#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +template class ConvMKLDNNOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -828,48 +845,52 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.HasInput("Bias") ? ctx.Input("Bias") : nullptr; auto* output = ctx.Output("Output"); - ConvMKLDNNHandlerT handler( - ctx, - dev_ctx, - mkldnn_engine, - ctx.GetPlace(), - input, - filter, - bias, - output, - ctx.InputName("Input") + ctx.InputName("Filter")); - - auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); - - auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, ctx.Attr("groups"), is_conv3d, is_test); - - std::shared_ptr dst_memory_p; - if (fuse_residual_conn) { - auto* residual_param = ctx.Input("ResidualData"); - dst_memory_p = - handler.AcquireDstMemoryWithResidual(output, residual_param); - } else { - dst_memory_p = handler.template AcquireDstMemory(output); - } - - auto conv_p = handler.AcquireForwardPrimitive(); - - std::unordered_map args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - if (bias) { - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); - args.insert({DNNL_ARG_BIAS, *bias_memory_p}); - } - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - conv_p->execute(astream, args); - astream.wait(); + PD_VISIT_FLOAT_AND_INT8_TYPES( + filter->dtype(), "ConvMKLDNNHandlerT", ([&] { + ConvMKLDNNHandlerT handler( + ctx, + dev_ctx, + mkldnn_engine, + ctx.GetPlace(), + input, + filter, + bias, + output, + ctx.InputName("Input") + ctx.InputName("Filter")); + + auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); + + auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( + filter, ctx.Attr("groups"), is_conv3d, is_test); + + std::shared_ptr dst_memory_p; + if (fuse_residual_conn) { + auto* residual_param = ctx.Input("ResidualData"); + dst_memory_p = + handler.AcquireDstMemoryWithResidual(output, residual_param); + } else { + dst_memory_p = handler.template AcquireDstMemory(output); + } + + auto conv_p = handler.AcquireForwardPrimitive(); + + std::unordered_map args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + auto bias_memory_p = + handler.AcquireBiasMemoryWithReorder(bias, is_test); + args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + conv_p->execute(astream, args); + astream.wait(); - output->set_mem_desc(dst_memory_p->get_desc()); + output->set_mem_desc(dst_memory_p->get_desc()); + })); } template @@ -905,90 +926,113 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.HasInput("Bias") ? ctx.Input("Bias") : nullptr; auto* output = ctx.Output("Output"); - ConvMKLDNNHandlerT handler( - ctx, - dev_ctx, - mkldnn_engine, - ctx.GetPlace(), - input, - filter, - bias, - output, - ctx.InputName("Input") + ctx.InputName("Filter")); - - auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); - - const auto& scale_weights_data = - ctx.Attr>("Scale_weights"); - const bool is_multi_channel = scale_weights_data.size() > 1; - const int& groups = ctx.Attr("groups"); - int mask_reorder = - is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; - auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, false, true, scale_weights_data, mask_reorder); - - std::shared_ptr dst_memory_p; - if (fuse_residual_conn) { - auto* residual_param = ctx.Input("ResidualData"); - PADDLE_ENFORCE_EQ( - output->dims(), - residual_param->dims(), - platform::errors::InvalidArgument( - "Output and elementwise parameter need to have the " - "same dimension sizes, but got output's dimension = %d" - " and residual param's dimension =%d .", - output->dims().size(), - residual_param->dims().size())); - dst_memory_p = - handler.AcquireDstMemoryWithResidual(output, residual_param); - need_s8_to_u8 = (platform::MKLDNNGetDataType() == - dnnl::memory::data_type::s8) && - unsigned_output; - } else { - dst_memory_p = handler.template AcquireDstMemory(output); - } - - auto conv_p = handler.AcquireForwardPrimitive(); - - std::unordered_map args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - if (bias) { - std::vector bias_scales; - auto p_scales_tuple = - std::make_shared>>( - std::make_tuple(static_cast(mask_reorder), bias_scales)); - if (ctx.HasAttr("Bias_scales")) { - bias_scales = ctx.Attr>("Bias_scales"); - p_scales_tuple = - std::make_shared>>( - std::make_tuple(static_cast(mask_reorder), bias_scales)); - } else { - p_scales_tuple = handler.get_int8_bias_scales(ctx); - } - auto bias_memory_p = - handler.AcquireBiasMemoryWithReorder(bias, - true, - std::get<1>(*p_scales_tuple), - std::get<0>(*p_scales_tuple)); - args.insert({DNNL_ARG_BIAS, *bias_memory_p}); - } - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - conv_p->execute(astream, args); - astream.wait(); + PD_VISIT_FLOAT_AND_INT8_TYPES( + filter->dtype(), "ConvMKLDNNHandlerT", ([&] { + ConvMKLDNNHandlerT handler( + ctx, + dev_ctx, + mkldnn_engine, + ctx.GetPlace(), + input, + filter, + bias, + output, + ctx.InputName("Input") + ctx.InputName("Filter")); + + auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); + + const auto& scale_weights_data = + ctx.Attr>("Scale_weights"); + const bool is_multi_channel = scale_weights_data.size() > 1; + const int& groups = ctx.Attr("groups"); + int mask_reorder = + is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) + : 0; + auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( + filter, groups, false, true, scale_weights_data, mask_reorder); + + std::shared_ptr dst_memory_p; + if (fuse_residual_conn) { + auto* residual_param = ctx.Input("ResidualData"); + PADDLE_ENFORCE_EQ( + output->dims(), + residual_param->dims(), + platform::errors::InvalidArgument( + "Output and elementwise parameter need to have the " + "same dimension sizes, but got output's dimension = %d" + " and residual param's dimension =%d .", + output->dims().size(), + residual_param->dims().size())); + dst_memory_p = + handler.AcquireDstMemoryWithResidual(output, residual_param); + need_s8_to_u8 = (platform::MKLDNNGetDataType() == + dnnl::memory::data_type::s8) && + unsigned_output; + } else { + dst_memory_p = handler.template AcquireDstMemory(output); + } + + auto conv_p = handler.AcquireForwardPrimitive(); + + std::unordered_map args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + std::vector bias_scales; + auto p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), + bias_scales)); + if (ctx.HasAttr("Bias_scales")) { + bias_scales = ctx.Attr>("Bias_scales"); + p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), + bias_scales)); + } else { + p_scales_tuple = handler.get_int8_bias_scales(ctx); + } + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( + bias, + true, + std::get<1>(*p_scales_tuple), + std::get<0>(*p_scales_tuple)); + args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + conv_p->execute(astream, args); + astream.wait(); - if (need_s8_to_u8) { - output->mutable_data(ctx.GetPlace()); - } + if (need_s8_to_u8) { + output->mutable_data(ctx.GetPlace()); + } - output->set_mem_desc(dst_memory_p->get_desc()); + output->set_mem_desc(dst_memory_p->get_desc()); + })); } }; -template +#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::BFLOAT16, \ + ::phi::dtype::bfloat16, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +template class ConvMKLDNNGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -1013,119 +1057,123 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { if (!input_grad && !filter_grad) return; - // TODO(jczaja): Are all tensors really needed? - ConvMKLDNNHandlerT handler( - ctx, - dev_ctx, - ctx.GetPlace(), - input, - filter, - bias, - output_grad, - filter_grad, - input_grad, - ctx.InputName("Input") + ctx.InputName("Filter")); - - // create mkldnn memory from input tensors (data/weights) - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - if (filter_grad) { - auto src_memory_p = - handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input); - auto diff_dst_memory_p = - handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( - output_grad); - - // For convoluition with groups write filter grad into - // oneDNN buffer and then we reorder it into filter_grad tensor - int g = std::max(ctx.Attr("groups"), 1); - auto diff_weights_memory_p = - g > 1 ? handler.AcquireDiffWeightsMemory() - : handler.AcquireDiffWeightsMemory(filter_grad); - - auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); - - conv_bwd_weights_p->execute( - astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); - astream.wait(); - - // For convolution with groups convert from blocked to NCHW - // otherwise there will be problems in next operators working on this data - if (g > 1) { - // in OneDNN groups in convolution are treated as separate dimension - // which is not the case in paddlepaddle - - dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(filter->dtype())); - // for 3d conv with groups (six dimensional data reorder to goidhw) - // for 2d conv with groups (five dimensional data reorder to goihw) - // auto weights_tz = phi::vectorize(filter->dims()); - - auto weights_tz = diff_weights_memory_p->get_desc().dims(); - dnnl::memory::format_tag out_format = - weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw - : dnnl::memory::format_tag::goihw; - platform::ReorderMKLDNNHandler handler( - weights_tz, - framework::TransToProtoVarType(filter->dtype()), - in_type, - mkldnn_engine); - auto reorder_dst_memory_p = - handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace()); - - auto reorder_p = - handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p); - - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder_p->execute( - astream, *diff_weights_memory_p, *reorder_dst_memory_p); - astream.wait(); - } - - // So here we have a data in goihw , which can be interpreted as OIHW - // (OIDHW for conv3d) - // because filter_grad shape is set for OIHW (OIDHW for conv3d) - dnnl::memory::format_tag target_format = - weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw - : dnnl::memory::format_tag::oihw; - filter_grad->set_mem_desc( - dnnl::memory::desc(phi::vectorize(filter_grad->dims()), - in_type, - target_format)); - } else { - filter_grad->set_mem_desc(diff_weights_memory_p->get_desc()); - } - } - if (input_grad) { - auto weights_memory_p = - handler.AcquireWeightsMemoryWithReorderFromDataPrimitive( + PD_VISIT_FLOAT_AND_BF16_TYPES( + filter->dtype(), "ConvMKLDNNHandlerT", ([&] { + // TODO(jczaja): Are all tensors really needed? + ConvMKLDNNHandlerT handler( + ctx, + dev_ctx, + ctx.GetPlace(), + input, filter, - ctx.Attr("groups"), - ctx.Attr>("strides").size() == 3U); - - auto diff_dst_memory_p = - handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( - output_grad); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad); - - auto conv_bwd_data_p = handler.AcquireBackwardPrimitive(); - - conv_bwd_data_p->execute(astream, - {{DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - input_grad->set_mem_desc(diff_src_memory_p->get_desc()); - } + bias, + output_grad, + filter_grad, + input_grad, + ctx.InputName("Input") + ctx.InputName("Filter")); + + // create mkldnn memory from input tensors (data/weights) + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + if (filter_grad) { + auto src_memory_p = + handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input); + auto diff_dst_memory_p = + handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( + output_grad); + + // For convoluition with groups write filter grad into + // oneDNN buffer and then we reorder it into filter_grad tensor + int g = std::max(ctx.Attr("groups"), 1); + auto diff_weights_memory_p = + g > 1 ? handler.AcquireDiffWeightsMemory() + : handler.AcquireDiffWeightsMemory(filter_grad); + + auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); + + conv_bwd_weights_p->execute( + astream, + {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, + {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); + astream.wait(); + + // For convolution with groups convert from blocked to NCHW + // otherwise there will be problems in next operators working on + // this data + if (g > 1) { + // in OneDNN groups in convolution are treated as separate + // dimension which is not the case in paddlepaddle + + dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( + framework::TransToProtoVarType(filter->dtype())); + // for 3d conv with groups (six dimensional data reorder to + // goidhw) for 2d conv with groups (five dimensional data reorder + // to goihw) auto weights_tz = phi::vectorize(filter->dims()); + + auto weights_tz = diff_weights_memory_p->get_desc().dims(); + dnnl::memory::format_tag out_format = + weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw + : dnnl::memory::format_tag::goihw; + platform::ReorderMKLDNNHandler handler( + weights_tz, + framework::TransToProtoVarType(filter->dtype()), + in_type, + mkldnn_engine); + auto reorder_dst_memory_p = handler.AcquireDstMemory( + filter_grad, out_format, ctx.GetPlace()); + + auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p, + diff_weights_memory_p); + + { + platform::RecordEvent record_reorder( + "int_reorder", + platform::TracerEventType::UserDefined, + 2, + platform::EventRole::kUniqueOp); + reorder_p->execute( + astream, *diff_weights_memory_p, *reorder_dst_memory_p); + astream.wait(); + } + + // So here we have a data in goihw , which can be interpreted as + // OIHW (OIDHW for conv3d) because filter_grad shape is set for + // OIHW (OIDHW for conv3d) + dnnl::memory::format_tag target_format = + weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw + : dnnl::memory::format_tag::oihw; + filter_grad->set_mem_desc(dnnl::memory::desc( + phi::vectorize(filter_grad->dims()), + in_type, + target_format)); + } else { + filter_grad->set_mem_desc(diff_weights_memory_p->get_desc()); + } + } + if (input_grad) { + auto weights_memory_p = + handler.AcquireWeightsMemoryWithReorderFromDataPrimitive( + filter, + ctx.Attr("groups"), + ctx.Attr>("strides").size() == 3U); + + auto diff_dst_memory_p = + handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( + output_grad); + auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad); + + auto conv_bwd_data_p = handler.AcquireBackwardPrimitive(); + + conv_bwd_data_p->execute(astream, + {{DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, + {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); + astream.wait(); + + input_grad->set_mem_desc(diff_src_memory_p->get_desc()); + } + })); } }; @@ -1134,119 +1182,40 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - U8WS8, - ops::kConvMKLDNNINT8WS8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - S8WS8, - ops::kConvMKLDNNINT8WS8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - conv2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - depthwise_conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(depthwise_conv2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - depthwise_conv2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); +REGISTER_OP_KERNEL(conv2d, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel); + +REGISTER_OP_KERNEL(conv2d_grad, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNGradOpKernel, + ops::ConvMKLDNNGradOpKernel); + +REGISTER_OP_KERNEL(depthwise_conv2d, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel, + ops::ConvMKLDNNOpKernel); + +REGISTER_OP_KERNEL(depthwise_conv2d_grad, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNGradOpKernel, + ops::ConvMKLDNNGradOpKernel); + +REGISTER_OP_KERNEL(conv3d, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNOpKernel); + +REGISTER_OP_KERNEL(conv3d_grad, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::ConvMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 85b7fb98d26..60c9b8f2659 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_ITSELF(conv2d); -USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32); +USE_OP_DEVICE_KERNEL(conv2d, MKLDNN); namespace paddle { namespace operators { -- GitLab