/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/phi/core/expect.h" #include "paddle/phi/core/visit_type.h" namespace paddle { namespace operators { namespace { inline MKLDNNMemoryFormat GetWeightsFormat(const int groups, const bool is_conv3d) { if (is_conv3d) { return (groups == 1) ? MKLDNNMemoryFormat::oidhw : MKLDNNMemoryFormat::goidhw; } else { return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw; } } static dnnl::memory::data_type GetDstType( bool is_int8, bool is_bfloat16, bool force_fp32_output, std::string fuse_activation, bool fuse_residual_conn, const phi::DenseTensor* residual_param) { auto dst_dt = dnnl::memory::data_type::f32; if (is_int8) { dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") ? dnnl::memory::data_type::u8 : dnnl::memory::data_type::s8; if (force_fp32_output) { dst_dt = dnnl::memory::data_type::f32; } if (fuse_residual_conn && residual_param) { auto residual_dt = framework::ToMKLDNNDataType( framework::TransToProtoVarType(residual_param->dtype())); if (dst_dt != residual_dt) dst_dt = residual_dt; } } else { if (!force_fp32_output && is_bfloat16) { dst_dt = dnnl::memory::data_type::bf16; if (fuse_residual_conn && residual_param) { dst_dt = framework::ToMKLDNNDataType( framework::TransToProtoVarType(residual_param->dtype())); } } } return dst_dt; } template class ConvMKLDNNHandlerT : public platform::MKLDNNHandlerT { public: ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx, const platform::MKLDNNDeviceContext& dev_ctx, const dnnl::engine mkldnn_engine, platform::Place cpu_place, const phi::DenseTensor* input, const phi::DenseTensor* filter, const phi::DenseTensor* bias, phi::DenseTensor* output, const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey( dev_ctx, phi::vectorize(input->dims()), unique_name)) { if (unlikely(!this->isCached())) { PADDLE_ENFORCE_EQ( input->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The input tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, input->layout())); PADDLE_ENFORCE_EQ( filter->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The Filter tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, filter->layout())); PADDLE_ENFORCE_GE( input->dims().size(), 4, platform::errors::InvalidArgument( "Input must be with 4 or 5 dimensions, i.e. NCHW or " "NCDHW, but got dimension = %d .", input->dims().size())); PADDLE_ENFORCE_LE( input->dims().size(), 5, platform::errors::InvalidArgument( "Input must be with 4 or 5 dimensions, i.e. NCHW or " "NCDHW, but got dimension = %d .", input->dims().size())); PADDLE_ENFORCE_GE( filter->dims().size(), 4, platform::errors::InvalidArgument( "Filter must be with 4 or 5 dimensions, i.e. OIHW or " "OIDHW, but got dimension = %d .", filter->dims().size())); PADDLE_ENFORCE_LE( filter->dims().size(), 5, platform::errors::InvalidArgument( "Filter must be with 4 or 5 dimensions, i.e. OIHW or " "OIDHW, but got dimension = %d .", filter->dims().size())); if (bias) { PADDLE_ENFORCE_EQ( bias->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The Bias tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, bias->layout())); PADDLE_ENFORCE_EQ(bias->dims().size(), 1, platform::errors::InvalidArgument( "Bias must only have 1 dimension, " "i.e. X, but got dimension = %d .", bias->dims().size())); } const int groups = ctx.Attr("groups"); const std::string padding_algorithm = ctx.Attr("padding_algorithm"); const auto input_dims = input->dims(); const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size()); const auto filter_dims = filter->dims(); const auto filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size()); const auto ksize = phi::vectorize(filter_data_dims); const bool is_test = ctx.Attr("is_test"); auto strides_temp = ctx.Attr>("strides"); std::vector strides(begin(strides_temp), end(strides_temp)); auto paddings_temp = ctx.Attr>("paddings"); std::vector paddings(begin(paddings_temp), end(paddings_temp)); auto dilations_temp = ctx.Attr>("dilations"); std::vector dilations(begin(dilations_temp), end(dilations_temp)); UpdatePaddingAndDilation( &paddings, &dilations, padding_algorithm, data_dims, strides, ksize); std::transform( dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) { return i - 1; }); const auto src_tz = phi::vectorize(input->dims()); auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); const auto dst_tz = phi::vectorize(output->dims()); const dnnl::memory::dims stride_dims = strides; const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); const dnnl::memory::dims dilations_dims = dilations; /* create memory descriptor for convolution without specified format * ('any') which lets a primitive (convolution in this case) choose * the memory format preferred for best performance */ auto chosen_memory_format = MKLDNNMemoryFormat::any; auto data_type = dnnl::memory::data_type::f32; if (ctx.Attr("mkldnn_data_type") == "bfloat16" || std::is_same::value) data_type = dnnl::memory::data_type::bf16; dnnl::memory::desc src_md, weights_md; if (platform::is_int8()) { src_md = platform::MKLDNNMemDesc( src_tz, framework::ToMKLDNNDataType( framework::TransToProtoVarType(input->dtype())), chosen_memory_format); weights_md = platform::MKLDNNMemDesc( weights_tz, dnnl::memory::data_type::s8, chosen_memory_format); } else { src_md = platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format); weights_md = platform::MKLDNNMemDesc( weights_tz, data_type, MKLDNNMemoryFormat::any); } const auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference : dnnl::prop_kind::forward_training; const dnnl::primitive_attr conv_attr = CreateConvAttrs(ctx); if (bias) { auto bias_tz = phi::vectorize(bias->dims()); dnnl::memory::desc bias_md; if (platform::is_int8()) { bias_md = platform::MKLDNNMemDesc( bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x); } else { bias_md = platform::MKLDNNMemDesc( bias_tz, data_type, MKLDNNMemoryFormat::x); } this->AcquireForwardPrimitiveDescriptor( conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } else { this->AcquireForwardPrimitiveDescriptor( conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } } } ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx, const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, const phi::DenseTensor* in, const phi::DenseTensor* filter, const phi::DenseTensor* bias, const phi::DenseTensor* out_grad, phi::DenseTensor* filter_grad, phi::DenseTensor* in_x_grad, const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey( dev_ctx, phi::vectorize(in->dims()), unique_name)) { if (unlikely(!this->isBwdCached())) { PADDLE_ENFORCE_EQ( in->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The input tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, in->layout())); PADDLE_ENFORCE_EQ( filter->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The filter tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, filter->layout())); PADDLE_ENFORCE_EQ( out_grad->layout(), phi::DataLayout::kMKLDNN, platform::errors::InvalidArgument( "The output_grad tensor's layout should be %d, but got %d.", phi::DataLayout::kMKLDNN, out_grad->layout())); PADDLE_ENFORCE_EQ( ctx.Attr("is_test"), false, platform::errors::InvalidArgument( "is_test attribute should be set to False in training phase.")); std::vector strides_temp = ctx.Attr>("strides"); std::vector strides(begin(strides_temp), end(strides_temp)); std::vector paddings_temp = ctx.Attr>("paddings"); std::vector paddings(begin(paddings_temp), end(paddings_temp)); std::vector dilations_temp = ctx.Attr>("dilations"); std::vector dilations(begin(dilations_temp), end(dilations_temp)); auto input_dims = in->dims(); auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size()); auto filter_dims = filter->dims(); auto filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size()); auto ksize = phi::vectorize(filter_data_dims); std::string padding_algorithm = ctx.Attr("padding_algorithm"); UpdatePaddingAndDilation( &paddings, &dilations, padding_algorithm, data_dims, strides, ksize); auto src_tz = phi::vectorize(in->dims()); auto weights_tz = phi::vectorize(filter->dims()); int groups = ctx.Attr("groups"); int g = std::max(groups, 1); platform::GetGroupConvWeightsTz(weights_tz, g); auto dst_tz = phi::vectorize(out_grad->dims()); /* create memory descriptor for conv backward without specified format * ('any') which lets a primitive (conv backward in this case) choose * the memory format preferred for best performance */ const auto chosen_memory_format = MKLDNNMemoryFormat::any; const auto weights_format = MKLDNNMemoryFormat::any; auto src_md = platform::MKLDNNMemDesc( src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); const auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); auto diff_src_md = platform::MKLDNNMemDesc( src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); auto weights_md = platform::MKLDNNMemDesc( weights_tz, platform::MKLDNNGetDataType(), weights_format); auto diff_weights_md = platform::MKLDNNMemDesc( weights_tz, platform::MKLDNNGetDataType(), weights_format); auto diff_dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); std::transform( dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) { return i - 1; }); const dnnl::memory::dims dilations_dims = dilations; const dnnl::memory::dims stride_dims = strides; // Recreating FWD PD. For training there are no post ops in convolution dnnl::primitive_attr conv_attr; if (bias) { auto bias_tz = phi::vectorize(bias->dims()); dnnl::memory::desc bias_md; if (platform::is_int8()) { bias_md = platform::MKLDNNMemDesc( bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x); } else { bias_md = platform::MKLDNNMemDesc( bias_tz, dnnl::memory::data_type::f32, MKLDNNMemoryFormat::x); } this->AcquireForwardPrimitiveDescriptor( conv_attr, dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } else { this->AcquireForwardPrimitiveDescriptor( conv_attr, dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } this->AcquireBackwardPrimitiveDescriptor( dnnl::algorithm::convolution_direct, diff_src_md, weights_md, diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); this->AcquireBackwardWeightsPrimitiveDescriptor( dnnl::algorithm::convolution_direct, src_md, diff_weights_md, diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } } std::shared_ptr>> get_int8_bias_scales( const framework::ExecutionContext& ctx) { // Get scales int8 bias key const std::string key_bs = this->key_ + "@bs"; // Scales for int8 bias are to be cached to avoid // computing them each iteration auto bias_scale_tuple = std::static_pointer_cast>>( this->dev_ctx_.GetBlob(key_bs)); if (bias_scale_tuple) return bias_scale_tuple; const auto* filter = ctx.Input("Filter"); const auto& weights_tz = phi::vectorize(filter->dims()); const int groups = std::max(ctx.Attr("groups"), 1); const auto& scale_weights_data = ctx.Attr>("Scale_weights"); const auto& scale_in_data = ctx.Attr("Scale_in"); bool is_multi_channel = scale_weights_data.size() > 1; int mask_reorder = is_multi_channel ? 1 << 0 : 1; int count = 1; if (is_multi_channel) { count *= weights_tz[0]; if (groups > 1) { count *= weights_tz[1]; } } bias_scale_tuple = std::make_shared>>(std::make_tuple( static_cast(mask_reorder), std::vector(count))); for (int i = 0; i < count; i++) { std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i]; } this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple); return bias_scale_tuple; } std::tuple, float> get_int8_scales( const framework::ExecutionContext& ctx) const { const auto* filter = ctx.Input("Filter"); const auto& weights_tz = phi::vectorize(filter->dims()); const bool& force_fp32_output = ctx.Attr("force_fp32_output"); const bool& fuse_residual_conn = ctx.Attr("fuse_residual_connection"); const int groups = std::max(ctx.Attr("groups"), 1); const auto& scale_in_data = ctx.Attr("Scale_in"); const auto& scale_in_eltwise_data = ctx.Attr("Scale_in_eltwise"); auto scale_weights_data = ctx.Attr>("Scale_weights"); bool is_multi_channel = scale_weights_data.size() > 1; bool has_activation = !ctx.Attr("fuse_activation").empty(); float activation_scale = (!force_fp32_output && has_activation) ? ctx.Attr("Scale_out") : 1.0f; float scale_out_data = (force_fp32_output || has_activation) ? 1.0f : ctx.Attr("Scale_out"); float sum_scale = fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; int count = is_multi_channel ? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0]) : 1; std::vector output_shift_scale(count); #pragma omp parallel for if (count > 50) for (int i = 0; i < count; i++) { if (scale_weights_data[i] == 0.0) // weights data will contain 0 in some models, then weights // scale couldn't be calculated output_shift_scale[i] = scale_out_data; else output_shift_scale[i] = static_cast(static_cast(scale_out_data) / (static_cast(scale_in_data) * static_cast(scale_weights_data[i]))); } return std::make_tuple(sum_scale, output_shift_scale, activation_scale); } dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) { dnnl::primitive_attr conv_attr; dnnl::post_ops post_operations; const bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); float sum_scale = 1.0f; float activation_scale = 1.0f; std::vector output_shift_scale; if (platform::is_int8()) { if (ctx.HasAttr("Sum_scale")) { sum_scale = ctx.Attr("Sum_scale"); activation_scale = ctx.Attr("Activation_scale"); output_shift_scale = ctx.Attr>("Output_shift_scale"); } else { std::tie(sum_scale, output_shift_scale, activation_scale) = get_int8_scales(ctx); } if (output_shift_scale.size() > 0) { int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; conv_attr.set_output_scales(mask, output_shift_scale); } } // Fusion with Elementwise layer relies on adding a sum post-operation with // the scale parameter. It is assumed that when fuse_residual_connection is // true, the output tensor contains the data coming from residual // connection. The result of this post_op is: // Output = scale * Output + Conv_Out. if (fuse_residual_conn) { post_operations.append_sum(sum_scale); } platform::AppendActivation(ctx, post_operations, activation_scale); conv_attr.set_post_ops(post_operations); return conv_attr; } std::shared_ptr AcquireWeightsMemoryWithReorderFromDataPrimitive( const phi::DenseTensor* filter, const int groups, const bool is_conv3d) { const K* filter_data = filter->data(); auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); auto user_src_md = platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType(), GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, this->bwd_pd_->weights_desc(), platform::to_void_cast(filter_data), "@weights_mem_d_p", false); } std::shared_ptr AcquireSrcMemoryWithReorder( const phi::DenseTensor* input) { return this->AcquireMemoryWithReorderPrimitive(input, "@src_mem_p_user", "@src_mem_p_target", "@src_mem_p", this->fwd_pd_->src_desc()); } std::shared_ptr AcquireSrcMemoryWithReorderFromWeightsPrimitive( const phi::DenseTensor* input) { return this->AcquireMemoryWithReorderPrimitive(input, "@src_mem_w_p_user", "@src_mem_w_p_target", "@src_mem_w_p", this->bwd_w_pd_->src_desc()); } std::shared_ptr AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( const phi::DenseTensor* out_grad) { return this->AcquireMemoryWithReorderPrimitive( out_grad, "@diff_dst_mem_w_p_user", "@diff_dst_mem_w_p_target", "@diff_dst_mem_w_p", this->bwd_w_pd_->diff_dst_desc()); } std::shared_ptr AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( const phi::DenseTensor* out_grad) { return this->AcquireMemoryWithReorderPrimitive( out_grad, "@diff_dst_mem_p_user", "@diff_dst_mem_p_target", "@diff_dst_mem_p", this->bwd_pd_->diff_dst_desc()); } std::shared_ptr AcquireMemoryWithReorderPrimitive( const phi::DenseTensor* in_mem, const char* key_mem_user, const char* key_mem_target, const char* key_mem, const dnnl::memory::desc& mem_md) { const T* in_mem_data = in_mem->data(); const std::string user_key_suffix{key_mem_user}; auto user_mem_p = this->AcquireMemory(user_key_suffix); if (!user_mem_p) { return this->AcquireMemoryWithReorder( in_mem->mem_desc(), mem_md, platform::to_void_cast(in_mem_data), key_mem); } else { const std::string target_key_suffix{key_mem_target}; const auto target_mem_p = this->AcquireMemory(target_key_suffix); user_mem_p->set_data_handle(platform::to_void_cast(in_mem_data)); if (user_mem_p != target_mem_p) { this->AcquireReorder(user_mem_p, target_mem_p); } return target_mem_p; } } std::shared_ptr AcquireWeightsMemoryWithReorder( const phi::DenseTensor* filter, const int groups, const bool is_conv3d, const bool is_test, const std::vector& scale_data = {1.0f}, int mask = 0) { // This is workaround to make execution faster, delete // if statement after including md inside Tensor auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); if (is_test && weights_mem_p) { return weights_mem_p; } else if (is_test) { const K* filter_data = filter->data(); auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); auto user_src_md = platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType(), GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), platform::to_void_cast(filter_data), "@weights_mem_p", is_test, {}, scale_data, mask); } else { const T* filter_data = filter->data(); auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); auto user_src_md = platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType(), GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), platform::to_void_cast(filter_data), "@weights_mem_p", is_test, {}, scale_data, mask); } } std::shared_ptr AcquireBiasMemoryWithReorder( const phi::DenseTensor* bias, const bool is_test, const std::vector& scale_data = {1.0f}, int mask = 0) { auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); if (is_test && bias_mem_p) { return bias_mem_p; } else { // if K is int8 (weights are int8) then biases are int32 using K_Bias = typename std:: conditional::value, int32_t, K>::type; if (std::is_same::value && bias->dtype() != phi::DataType::INT32) { LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype(); } const K_Bias* bias_data = bias->data(); return this->AcquireMemoryWithReorder( bias->mem_desc(), this->fwd_pd_->bias_desc(), platform::to_void_cast(bias_data), "@bias_mem_p", is_test, {}, scale_data, mask); } } std::shared_ptr AcquireResidualMemory( const phi::DenseTensor* residual_param) { void* residual_data = framework::TransToProtoVarType(residual_param->dtype()) == framework::DataTypeTrait::DataType() ? platform::to_void_cast(residual_param->data()) : platform::to_void_cast(residual_param->data()); auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p"); if (residual_mem_p) { residual_mem_p->set_data_handle(residual_data); return residual_mem_p; } else { return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(), residual_data, "@user_residual_data_mem_p"); } } std::shared_ptr AcquireDstMemoryWithResidual( phi::DenseTensor* output, const phi::DenseTensor* residual_param) { std::shared_ptr dst_memory_p; if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) { auto residual_memory_p = this->AcquireResidualMemory(residual_param); dst_memory_p = this->template AcquireDstMemory(output); this->AcquireReorder(residual_memory_p, dst_memory_p); } else { // Changing ShareDataWith to TensorCopy results in performance drop // on ResNet architectures // (https://github.com/PaddlePaddle/Paddle/issues/22964) output->ShareDataWith(*residual_param); dst_memory_p = this->template AcquireDstMemory(output); } return dst_memory_p; } }; } // anonymous namespace #define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \ [&] { \ const auto& __dtype__ = TYPE; \ switch (__dtype__) { \ PD_PRIVATE_CASE_TYPE( \ NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ PD_PRIVATE_CASE_TYPE(NAME, \ ::paddle::DataType::BFLOAT16, \ ::phi::dtype::bfloat16, \ __VA_ARGS__) \ default: \ PD_THROW("function " #NAME " is not implemented for data type `", \ __dtype__, \ "`"); \ } \ }() template class ConvMKLDNNGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, platform::errors::PreconditionNotMet( "Operator DNNL ConvGrad must use CPUPlace")); auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); const phi::DenseTensor* input = ctx.Input("Input"); const phi::DenseTensor* filter = ctx.Input("Filter"); const phi::DenseTensor* bias = ctx.HasInput("Bias") ? ctx.Input("Bias") : nullptr; const phi::DenseTensor* output_grad = ctx.Input(framework::GradVarName("Output")); phi::DenseTensor* input_grad = ctx.Output(framework::GradVarName("Input")); phi::DenseTensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); if (!input_grad && !filter_grad) return; PD_VISIT_FLOAT_AND_BF16_TYPES( filter->dtype(), "ConvMKLDNNHandlerT", ([&] { // TODO(jczaja): Are all tensors really needed? ConvMKLDNNHandlerT handler( ctx, dev_ctx, ctx.GetPlace(), input, filter, bias, output_grad, filter_grad, input_grad, ctx.InputName("Input") + ctx.InputName("Filter")); // create mkldnn memory from input tensors (data/weights) auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); if (filter_grad) { auto src_memory_p = handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input); auto diff_dst_memory_p = handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( output_grad); // For convoluition with groups write filter grad into // oneDNN buffer and then we reorder it into filter_grad tensor int g = std::max(ctx.Attr("groups"), 1); auto diff_weights_memory_p = g > 1 ? handler.AcquireDiffWeightsMemory() : handler.AcquireDiffWeightsMemory(filter_grad); auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); conv_bwd_weights_p->execute( astream, {{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); astream.wait(); // For convolution with groups convert from blocked to NCHW // otherwise there will be problems in next operators working on // this data if (g > 1) { // in OneDNN groups in convolution are treated as separate // dimension which is not the case in paddlepaddle dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( framework::TransToProtoVarType(filter->dtype())); // for 3d conv with groups (six dimensional data reorder to // goidhw) for 2d conv with groups (five dimensional data reorder // to goihw) auto weights_tz = phi::vectorize(filter->dims()); auto weights_tz = diff_weights_memory_p->get_desc().dims(); dnnl::memory::format_tag out_format = weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw : dnnl::memory::format_tag::goihw; platform::ReorderMKLDNNHandler handler( weights_tz, framework::TransToProtoVarType(filter->dtype()), in_type, mkldnn_engine); auto reorder_dst_memory_p = handler.AcquireDstMemory( filter_grad, out_format, ctx.GetPlace()); auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p); { platform::RecordEvent record_reorder( "int_reorder", platform::TracerEventType::UserDefined, 2, platform::EventRole::kUniqueOp); reorder_p->execute( astream, *diff_weights_memory_p, *reorder_dst_memory_p); astream.wait(); } // So here we have a data in goihw , which can be interpreted as // OIHW (OIDHW for conv3d) because filter_grad shape is set for // OIHW (OIDHW for conv3d) dnnl::memory::format_tag target_format = weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw : dnnl::memory::format_tag::oihw; filter_grad->set_mem_desc(dnnl::memory::desc( phi::vectorize(filter_grad->dims()), in_type, target_format)); } else { filter_grad->set_mem_desc(diff_weights_memory_p->get_desc()); } } if (input_grad) { auto weights_memory_p = handler.AcquireWeightsMemoryWithReorderFromDataPrimitive( filter, ctx.Attr("groups"), ctx.Attr>("strides").size() == 3U); auto diff_dst_memory_p = handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( output_grad); auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad); auto conv_bwd_data_p = handler.AcquireBackwardPrimitive(); conv_bwd_data_p->execute(astream, {{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); astream.wait(); input_grad->set_mem_desc(diff_src_memory_p->get_desc()); } })); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_KERNEL(depthwise_conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::ConvMKLDNNGradOpKernel, ops::ConvMKLDNNGradOpKernel); REGISTER_OP_KERNEL(conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::ConvMKLDNNGradOpKernel);