From b0c3856842c0563df3fff4e52b6b6186335b3f86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Mon, 7 Nov 2022 03:04:51 +0100 Subject: [PATCH] [PHI] Migrate depthwise_conv2d_grad and conv3d_grad kernels (#47686) * remove fwd funcs * migrate conv grads --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 955 ------------------ paddle/phi/kernels/onednn/conv_grad_kernel.cc | 72 +- 2 files changed, 70 insertions(+), 957 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc deleted file mode 100644 index a0defac03f..0000000000 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ /dev/null @@ -1,955 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include - -#include "paddle/fluid/operators/conv_op.h" -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/phi/core/expect.h" - -#include "paddle/phi/core/visit_type.h" - -namespace paddle { -namespace operators { -namespace { - -inline MKLDNNMemoryFormat GetWeightsFormat(const int groups, - const bool is_conv3d) { - if (is_conv3d) { - return (groups == 1) ? MKLDNNMemoryFormat::oidhw - : MKLDNNMemoryFormat::goidhw; - } else { - return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw; - } -} - -static dnnl::memory::data_type GetDstType( - bool is_int8, - bool is_bfloat16, - bool force_fp32_output, - std::string fuse_activation, - bool fuse_residual_conn, - const phi::DenseTensor* residual_param) { - auto dst_dt = dnnl::memory::data_type::f32; - if (is_int8) { - dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") - ? dnnl::memory::data_type::u8 - : dnnl::memory::data_type::s8; - if (force_fp32_output) { - dst_dt = dnnl::memory::data_type::f32; - } - if (fuse_residual_conn && residual_param) { - auto residual_dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(residual_param->dtype())); - if (dst_dt != residual_dt) dst_dt = residual_dt; - } - } else { - if (!force_fp32_output && is_bfloat16) { - dst_dt = dnnl::memory::data_type::bf16; - if (fuse_residual_conn && residual_param) { - dst_dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(residual_param->dtype())); - } - } - } - return dst_dt; -} - -template -class ConvMKLDNNHandlerT - : public platform::MKLDNNHandlerT { - public: - ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx, - const platform::MKLDNNDeviceContext& dev_ctx, - const dnnl::engine mkldnn_engine, - platform::Place cpu_place, - const phi::DenseTensor* input, - const phi::DenseTensor* filter, - const phi::DenseTensor* bias, - phi::DenseTensor* output, - const std::string& unique_name) - : platform::MKLDNNHandlerT( - dev_ctx, - mkldnn_engine, - cpu_place, - platform::CreateKey( - dev_ctx, phi::vectorize(input->dims()), unique_name)) { - if (unlikely(!this->isCached())) { - PADDLE_ENFORCE_EQ( - input->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The input tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - input->layout())); - - PADDLE_ENFORCE_EQ( - filter->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The Filter tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - filter->layout())); - - PADDLE_ENFORCE_GE( - input->dims().size(), - 4, - platform::errors::InvalidArgument( - "Input must be with 4 or 5 dimensions, i.e. NCHW or " - "NCDHW, but got dimension = %d .", - input->dims().size())); - PADDLE_ENFORCE_LE( - input->dims().size(), - 5, - platform::errors::InvalidArgument( - "Input must be with 4 or 5 dimensions, i.e. NCHW or " - "NCDHW, but got dimension = %d .", - input->dims().size())); - - PADDLE_ENFORCE_GE( - filter->dims().size(), - 4, - platform::errors::InvalidArgument( - "Filter must be with 4 or 5 dimensions, i.e. OIHW or " - "OIDHW, but got dimension = %d .", - filter->dims().size())); - PADDLE_ENFORCE_LE( - filter->dims().size(), - 5, - platform::errors::InvalidArgument( - "Filter must be with 4 or 5 dimensions, i.e. OIHW or " - "OIDHW, but got dimension = %d .", - filter->dims().size())); - - if (bias) { - PADDLE_ENFORCE_EQ( - bias->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The Bias tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - bias->layout())); - - PADDLE_ENFORCE_EQ(bias->dims().size(), - 1, - platform::errors::InvalidArgument( - "Bias must only have 1 dimension, " - "i.e. X, but got dimension = %d .", - bias->dims().size())); - } - - const int groups = ctx.Attr("groups"); - const std::string padding_algorithm = - ctx.Attr("padding_algorithm"); - - const auto input_dims = input->dims(); - const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size()); - const auto filter_dims = filter->dims(); - const auto filter_data_dims = - phi::slice_ddim(filter_dims, 2, filter_dims.size()); - - const auto ksize = phi::vectorize(filter_data_dims); - const bool is_test = ctx.Attr("is_test"); - - auto strides_temp = ctx.Attr>("strides"); - std::vector strides(begin(strides_temp), end(strides_temp)); - - auto paddings_temp = ctx.Attr>("paddings"); - std::vector paddings(begin(paddings_temp), end(paddings_temp)); - - auto dilations_temp = ctx.Attr>("dilations"); - std::vector dilations(begin(dilations_temp), - end(dilations_temp)); - - UpdatePaddingAndDilation( - &paddings, &dilations, padding_algorithm, data_dims, strides, ksize); - - std::transform( - dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) { - return i - 1; - }); - - const auto src_tz = phi::vectorize(input->dims()); - - auto weights_tz = phi::vectorize(filter->dims()); - platform::GetGroupConvWeightsTz(weights_tz, groups); - - const auto dst_tz = phi::vectorize(output->dims()); - - const dnnl::memory::dims stride_dims = strides; - const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); - const dnnl::memory::dims dilations_dims = dilations; - - /* create memory descriptor for convolution without specified format - * ('any') which lets a primitive (convolution in this case) choose - * the memory format preferred for best performance - */ - auto chosen_memory_format = MKLDNNMemoryFormat::any; - auto data_type = dnnl::memory::data_type::f32; - if (ctx.Attr("mkldnn_data_type") == "bfloat16" || - std::is_same::value) - data_type = dnnl::memory::data_type::bf16; - - dnnl::memory::desc src_md, weights_md; - if (platform::is_int8()) { - src_md = platform::MKLDNNMemDesc( - src_tz, - framework::ToMKLDNNDataType( - framework::TransToProtoVarType(input->dtype())), - chosen_memory_format); - weights_md = platform::MKLDNNMemDesc( - weights_tz, dnnl::memory::data_type::s8, chosen_memory_format); - } else { - src_md = - platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format); - weights_md = platform::MKLDNNMemDesc( - weights_tz, data_type, MKLDNNMemoryFormat::any); - } - - const auto dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference - : dnnl::prop_kind::forward_training; - - const dnnl::primitive_attr conv_attr = CreateConvAttrs(ctx); - - if (bias) { - auto bias_tz = phi::vectorize(bias->dims()); - dnnl::memory::desc bias_md; - if (platform::is_int8()) { - bias_md = platform::MKLDNNMemDesc( - bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x); - } else { - bias_md = platform::MKLDNNMemDesc( - bias_tz, data_type, MKLDNNMemoryFormat::x); - } - - this->AcquireForwardPrimitiveDescriptor( - conv_attr, - fwd_prop_kind, - dnnl::algorithm::convolution_direct, - src_md, - weights_md, - bias_md, - dst_md, - stride_dims, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } else { - this->AcquireForwardPrimitiveDescriptor( - conv_attr, - fwd_prop_kind, - dnnl::algorithm::convolution_direct, - src_md, - weights_md, - dst_md, - stride_dims, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } - } - } - - ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx, - const platform::MKLDNNDeviceContext& dev_ctx, - platform::Place cpu_place, - const phi::DenseTensor* in, - const phi::DenseTensor* filter, - const phi::DenseTensor* bias, - const phi::DenseTensor* out_grad, - phi::DenseTensor* filter_grad, - phi::DenseTensor* in_x_grad, - const std::string& unique_name) - : platform::MKLDNNHandlerT( - dev_ctx, - dev_ctx.GetEngine(), - cpu_place, - platform::CreateKey( - dev_ctx, phi::vectorize(in->dims()), unique_name)) { - if (unlikely(!this->isBwdCached())) { - PADDLE_ENFORCE_EQ( - in->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The input tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - in->layout())); - - PADDLE_ENFORCE_EQ( - filter->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The filter tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - filter->layout())); - - PADDLE_ENFORCE_EQ( - out_grad->layout(), - phi::DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The output_grad tensor's layout should be %d, but got %d.", - phi::DataLayout::kMKLDNN, - out_grad->layout())); - - PADDLE_ENFORCE_EQ( - ctx.Attr("is_test"), - false, - platform::errors::InvalidArgument( - "is_test attribute should be set to False in training phase.")); - - std::vector strides_temp = ctx.Attr>("strides"); - std::vector strides(begin(strides_temp), end(strides_temp)); - - std::vector paddings_temp = ctx.Attr>("paddings"); - std::vector paddings(begin(paddings_temp), end(paddings_temp)); - - std::vector dilations_temp = ctx.Attr>("dilations"); - std::vector dilations(begin(dilations_temp), - end(dilations_temp)); - - auto input_dims = in->dims(); - auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size()); - auto filter_dims = filter->dims(); - auto filter_data_dims = - phi::slice_ddim(filter_dims, 2, filter_dims.size()); - auto ksize = phi::vectorize(filter_data_dims); - - std::string padding_algorithm = - ctx.Attr("padding_algorithm"); - UpdatePaddingAndDilation( - &paddings, &dilations, padding_algorithm, data_dims, strides, ksize); - - auto src_tz = phi::vectorize(in->dims()); - auto weights_tz = phi::vectorize(filter->dims()); - - int groups = ctx.Attr("groups"); - int g = std::max(groups, 1); - platform::GetGroupConvWeightsTz(weights_tz, g); - auto dst_tz = phi::vectorize(out_grad->dims()); - - /* create memory descriptor for conv backward without specified format - * ('any') which lets a primitive (conv backward in this case) choose - * the memory format preferred for best performance - */ - const auto chosen_memory_format = MKLDNNMemoryFormat::any; - const auto weights_format = MKLDNNMemoryFormat::any; - - auto src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - const auto dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - auto diff_src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - auto weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), weights_format); - auto diff_weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), weights_format); - auto diff_dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - - auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); - std::transform( - dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) { - return i - 1; - }); - const dnnl::memory::dims dilations_dims = dilations; - - const dnnl::memory::dims stride_dims = strides; - // Recreating FWD PD. For training there are no post ops in convolution - dnnl::primitive_attr conv_attr; - if (bias) { - auto bias_tz = phi::vectorize(bias->dims()); - dnnl::memory::desc bias_md; - if (platform::is_int8()) { - bias_md = platform::MKLDNNMemDesc( - bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x); - } else { - bias_md = platform::MKLDNNMemDesc( - bias_tz, dnnl::memory::data_type::f32, MKLDNNMemoryFormat::x); - } - - this->AcquireForwardPrimitiveDescriptor( - conv_attr, - dnnl::prop_kind::forward_training, - dnnl::algorithm::convolution_direct, - src_md, - weights_md, - bias_md, - dst_md, - stride_dims, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } else { - this->AcquireForwardPrimitiveDescriptor( - conv_attr, - dnnl::prop_kind::forward_training, - dnnl::algorithm::convolution_direct, - src_md, - weights_md, - dst_md, - stride_dims, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } - - this->AcquireBackwardPrimitiveDescriptor( - dnnl::algorithm::convolution_direct, - diff_src_md, - weights_md, - diff_dst_md, - strides, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - - this->AcquireBackwardWeightsPrimitiveDescriptor( - dnnl::algorithm::convolution_direct, - src_md, - diff_weights_md, - diff_dst_md, - strides, - dilations_dims, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } - } - - std::shared_ptr>> get_int8_bias_scales( - const framework::ExecutionContext& ctx) { - // Get scales int8 bias key - const std::string key_bs = this->key_ + "@bs"; - - // Scales for int8 bias are to be cached to avoid - // computing them each iteration - auto bias_scale_tuple = - std::static_pointer_cast>>( - this->dev_ctx_.GetBlob(key_bs)); - if (bias_scale_tuple) return bias_scale_tuple; - - const auto* filter = ctx.Input("Filter"); - const auto& weights_tz = phi::vectorize(filter->dims()); - const int groups = std::max(ctx.Attr("groups"), 1); - - const auto& scale_weights_data = - ctx.Attr>("Scale_weights"); - const auto& scale_in_data = ctx.Attr("Scale_in"); - - bool is_multi_channel = scale_weights_data.size() > 1; - int mask_reorder = is_multi_channel ? 1 << 0 : 1; - - int count = 1; - if (is_multi_channel) { - count *= weights_tz[0]; - if (groups > 1) { - count *= weights_tz[1]; - } - } - - bias_scale_tuple = - std::make_shared>>(std::make_tuple( - static_cast(mask_reorder), std::vector(count))); - for (int i = 0; i < count; i++) { - std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i]; - } - - this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple); - - return bias_scale_tuple; - } - - std::tuple, float> get_int8_scales( - const framework::ExecutionContext& ctx) const { - const auto* filter = ctx.Input("Filter"); - const auto& weights_tz = phi::vectorize(filter->dims()); - - const bool& force_fp32_output = ctx.Attr("force_fp32_output"); - const bool& fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - const int groups = std::max(ctx.Attr("groups"), 1); - - const auto& scale_in_data = ctx.Attr("Scale_in"); - const auto& scale_in_eltwise_data = ctx.Attr("Scale_in_eltwise"); - auto scale_weights_data = ctx.Attr>("Scale_weights"); - bool is_multi_channel = scale_weights_data.size() > 1; - bool has_activation = !ctx.Attr("fuse_activation").empty(); - float activation_scale = (!force_fp32_output && has_activation) - ? ctx.Attr("Scale_out") - : 1.0f; - - float scale_out_data = (force_fp32_output || has_activation) - ? 1.0f - : ctx.Attr("Scale_out"); - float sum_scale = - fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; - int count = - is_multi_channel - ? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0]) - : 1; - std::vector output_shift_scale(count); - -#pragma omp parallel for if (count > 50) - for (int i = 0; i < count; i++) { - if (scale_weights_data[i] == 0.0) - // weights data will contain 0 in some models, then weights - // scale couldn't be calculated - output_shift_scale[i] = scale_out_data; - else - output_shift_scale[i] = - static_cast(static_cast(scale_out_data) / - (static_cast(scale_in_data) * - static_cast(scale_weights_data[i]))); - } - - return std::make_tuple(sum_scale, output_shift_scale, activation_scale); - } - - dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) { - dnnl::primitive_attr conv_attr; - dnnl::post_ops post_operations; - - const bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - - float sum_scale = 1.0f; - float activation_scale = 1.0f; - std::vector output_shift_scale; - if (platform::is_int8()) { - if (ctx.HasAttr("Sum_scale")) { - sum_scale = ctx.Attr("Sum_scale"); - activation_scale = ctx.Attr("Activation_scale"); - output_shift_scale = ctx.Attr>("Output_shift_scale"); - } else { - std::tie(sum_scale, output_shift_scale, activation_scale) = - get_int8_scales(ctx); - } - - if (output_shift_scale.size() > 0) { - int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; - conv_attr.set_output_scales(mask, output_shift_scale); - } - } - - // Fusion with Elementwise layer relies on adding a sum post-operation with - // the scale parameter. It is assumed that when fuse_residual_connection is - // true, the output tensor contains the data coming from residual - // connection. The result of this post_op is: - // Output = scale * Output + Conv_Out. - if (fuse_residual_conn) { - post_operations.append_sum(sum_scale); - } - - platform::AppendActivation(ctx, post_operations, activation_scale); - - conv_attr.set_post_ops(post_operations); - return conv_attr; - } - - std::shared_ptr - AcquireWeightsMemoryWithReorderFromDataPrimitive( - const phi::DenseTensor* filter, const int groups, const bool is_conv3d) { - const K* filter_data = filter->data(); - auto weights_tz = phi::vectorize(filter->dims()); - platform::GetGroupConvWeightsTz(weights_tz, groups); - - auto user_src_md = - platform::MKLDNNMemDesc(weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(groups, is_conv3d)); - - return this->AcquireMemoryWithReorder( - user_src_md, - this->bwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), - "@weights_mem_d_p", - false); - } - - std::shared_ptr AcquireSrcMemoryWithReorder( - const phi::DenseTensor* input) { - return this->AcquireMemoryWithReorderPrimitive(input, - "@src_mem_p_user", - "@src_mem_p_target", - "@src_mem_p", - this->fwd_pd_->src_desc()); - } - - std::shared_ptr AcquireSrcMemoryWithReorderFromWeightsPrimitive( - const phi::DenseTensor* input) { - return this->AcquireMemoryWithReorderPrimitive(input, - "@src_mem_w_p_user", - "@src_mem_w_p_target", - "@src_mem_w_p", - this->bwd_w_pd_->src_desc()); - } - - std::shared_ptr - AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( - const phi::DenseTensor* out_grad) { - return this->AcquireMemoryWithReorderPrimitive( - out_grad, - "@diff_dst_mem_w_p_user", - "@diff_dst_mem_w_p_target", - "@diff_dst_mem_w_p", - this->bwd_w_pd_->diff_dst_desc()); - } - - std::shared_ptr - AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( - const phi::DenseTensor* out_grad) { - return this->AcquireMemoryWithReorderPrimitive( - out_grad, - "@diff_dst_mem_p_user", - "@diff_dst_mem_p_target", - "@diff_dst_mem_p", - this->bwd_pd_->diff_dst_desc()); - } - - std::shared_ptr AcquireMemoryWithReorderPrimitive( - const phi::DenseTensor* in_mem, - const char* key_mem_user, - const char* key_mem_target, - const char* key_mem, - const dnnl::memory::desc& mem_md) { - const T* in_mem_data = in_mem->data(); - const std::string user_key_suffix{key_mem_user}; - auto user_mem_p = this->AcquireMemory(user_key_suffix); - - if (!user_mem_p) { - return this->AcquireMemoryWithReorder( - in_mem->mem_desc(), - mem_md, - platform::to_void_cast(in_mem_data), - key_mem); - } else { - const std::string target_key_suffix{key_mem_target}; - const auto target_mem_p = this->AcquireMemory(target_key_suffix); - user_mem_p->set_data_handle(platform::to_void_cast(in_mem_data)); - if (user_mem_p != target_mem_p) { - this->AcquireReorder(user_mem_p, target_mem_p); - } - return target_mem_p; - } - } - - std::shared_ptr AcquireWeightsMemoryWithReorder( - const phi::DenseTensor* filter, - const int groups, - const bool is_conv3d, - const bool is_test, - const std::vector& scale_data = {1.0f}, - int mask = 0) { - // This is workaround to make execution faster, delete - // if statement after including md inside Tensor - auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); - if (is_test && weights_mem_p) { - return weights_mem_p; - } else if (is_test) { - const K* filter_data = filter->data(); - auto weights_tz = phi::vectorize(filter->dims()); - platform::GetGroupConvWeightsTz(weights_tz, groups); - - auto user_src_md = - platform::MKLDNNMemDesc(weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(groups, is_conv3d)); - - return this->AcquireMemoryWithReorder( - user_src_md, - this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), - "@weights_mem_p", - is_test, - {}, - scale_data, - mask); - } else { - const T* filter_data = filter->data(); - auto weights_tz = phi::vectorize(filter->dims()); - platform::GetGroupConvWeightsTz(weights_tz, groups); - - auto user_src_md = - platform::MKLDNNMemDesc(weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(groups, is_conv3d)); - - return this->AcquireMemoryWithReorder( - user_src_md, - this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), - "@weights_mem_p", - is_test, - {}, - scale_data, - mask); - } - } - - std::shared_ptr AcquireBiasMemoryWithReorder( - const phi::DenseTensor* bias, - const bool is_test, - const std::vector& scale_data = {1.0f}, - int mask = 0) { - auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); - if (is_test && bias_mem_p) { - return bias_mem_p; - } else { - // if K is int8 (weights are int8) then biases are int32 - using K_Bias = typename std:: - conditional::value, int32_t, K>::type; - if (std::is_same::value && - bias->dtype() != phi::DataType::INT32) { - LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype(); - } - const K_Bias* bias_data = bias->data(); - - return this->AcquireMemoryWithReorder( - bias->mem_desc(), - this->fwd_pd_->bias_desc(), - platform::to_void_cast(bias_data), - "@bias_mem_p", - is_test, - {}, - scale_data, - mask); - } - } - - std::shared_ptr AcquireResidualMemory( - const phi::DenseTensor* residual_param) { - void* residual_data = - framework::TransToProtoVarType(residual_param->dtype()) == - framework::DataTypeTrait::DataType() - ? platform::to_void_cast(residual_param->data()) - : platform::to_void_cast(residual_param->data()); - auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p"); - if (residual_mem_p) { - residual_mem_p->set_data_handle(residual_data); - return residual_mem_p; - } else { - return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(), - residual_data, - "@user_residual_data_mem_p"); - } - } - - std::shared_ptr AcquireDstMemoryWithResidual( - phi::DenseTensor* output, const phi::DenseTensor* residual_param) { - std::shared_ptr dst_memory_p; - if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) { - auto residual_memory_p = this->AcquireResidualMemory(residual_param); - dst_memory_p = this->template AcquireDstMemory(output); - this->AcquireReorder(residual_memory_p, dst_memory_p); - } else { - // Changing ShareDataWith to TensorCopy results in performance drop - // on ResNet architectures - // (https://github.com/PaddlePaddle/Paddle/issues/22964) - output->ShareDataWith(*residual_param); - dst_memory_p = this->template AcquireDstMemory(output); - } - return dst_memory_p; - } -}; - -} // anonymous namespace - -#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::BFLOAT16, \ - ::phi::dtype::bfloat16, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() - -template -class ConvMKLDNNGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet( - "Operator DNNL ConvGrad must use CPUPlace")); - auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - const phi::DenseTensor* input = ctx.Input("Input"); - const phi::DenseTensor* filter = ctx.Input("Filter"); - const phi::DenseTensor* bias = - ctx.HasInput("Bias") ? ctx.Input("Bias") : nullptr; - const phi::DenseTensor* output_grad = - ctx.Input(framework::GradVarName("Output")); - phi::DenseTensor* input_grad = - ctx.Output(framework::GradVarName("Input")); - phi::DenseTensor* filter_grad = - ctx.Output(framework::GradVarName("Filter")); - - if (!input_grad && !filter_grad) return; - - PD_VISIT_FLOAT_AND_BF16_TYPES( - filter->dtype(), "ConvMKLDNNHandlerT", ([&] { - // TODO(jczaja): Are all tensors really needed? - ConvMKLDNNHandlerT handler( - ctx, - dev_ctx, - ctx.GetPlace(), - input, - filter, - bias, - output_grad, - filter_grad, - input_grad, - ctx.InputName("Input") + ctx.InputName("Filter")); - - // create mkldnn memory from input tensors (data/weights) - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - - if (filter_grad) { - auto src_memory_p = - handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input); - auto diff_dst_memory_p = - handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( - output_grad); - - // For convoluition with groups write filter grad into - // oneDNN buffer and then we reorder it into filter_grad tensor - int g = std::max(ctx.Attr("groups"), 1); - auto diff_weights_memory_p = - g > 1 ? handler.AcquireDiffWeightsMemory() - : handler.AcquireDiffWeightsMemory(filter_grad); - - auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); - - conv_bwd_weights_p->execute( - astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); - astream.wait(); - - // For convolution with groups convert from blocked to NCHW - // otherwise there will be problems in next operators working on - // this data - if (g > 1) { - // in OneDNN groups in convolution are treated as separate - // dimension which is not the case in paddlepaddle - - dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(filter->dtype())); - // for 3d conv with groups (six dimensional data reorder to - // goidhw) for 2d conv with groups (five dimensional data reorder - // to goihw) auto weights_tz = phi::vectorize(filter->dims()); - - auto weights_tz = diff_weights_memory_p->get_desc().dims(); - dnnl::memory::format_tag out_format = - weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw - : dnnl::memory::format_tag::goihw; - platform::ReorderMKLDNNHandler handler( - weights_tz, - framework::TransToProtoVarType(filter->dtype()), - in_type, - mkldnn_engine); - auto reorder_dst_memory_p = handler.AcquireDstMemory( - filter_grad, out_format, ctx.GetPlace()); - - auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p, - diff_weights_memory_p); - - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder_p->execute( - astream, *diff_weights_memory_p, *reorder_dst_memory_p); - astream.wait(); - } - - // So here we have a data in goihw , which can be interpreted as - // OIHW (OIDHW for conv3d) because filter_grad shape is set for - // OIHW (OIDHW for conv3d) - dnnl::memory::format_tag target_format = - weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw - : dnnl::memory::format_tag::oihw; - filter_grad->set_mem_desc(dnnl::memory::desc( - phi::vectorize(filter_grad->dims()), - in_type, - target_format)); - } else { - filter_grad->set_mem_desc(diff_weights_memory_p->get_desc()); - } - } - if (input_grad) { - auto weights_memory_p = - handler.AcquireWeightsMemoryWithReorderFromDataPrimitive( - filter, - ctx.Attr("groups"), - ctx.Attr>("strides").size() == 3U); - - auto diff_dst_memory_p = - handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( - output_grad); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad); - - auto conv_bwd_data_p = handler.AcquireBackwardPrimitive(); - - conv_bwd_data_p->execute(astream, - {{DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - input_grad->set_mem_desc(diff_src_memory_p->get_desc()); - } - })); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(depthwise_conv2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::ConvMKLDNNGradOpKernel, - ops::ConvMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL(conv3d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::ConvMKLDNNGradOpKernel); diff --git a/paddle/phi/kernels/onednn/conv_grad_kernel.cc b/paddle/phi/kernels/onednn/conv_grad_kernel.cc index 69c8122966..8d2ff16999 100644 --- a/paddle/phi/kernels/onednn/conv_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_grad_kernel.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/conv_grad_kernel.h" - +#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/data_layout_transform.h" @@ -54,7 +54,7 @@ void ConvGradKernel(const Context& dev_ctx, PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType(), AllocationType::CPU, phi::errors::PreconditionNotMet( - "Operator DNNL ConvGrad must use CPUPlace")); + "Operator oneDNN ConvGrad must use CPUPlace")); const auto& onednn_engine = dev_ctx.GetEngine(); const auto* bias = @@ -140,6 +140,11 @@ void ConvGradKernel(const Context& dev_ctx, diff_weights_memory_p); { + paddle::platform::RecordEvent record_reorder( + "int_reorder", + paddle::platform::TracerEventType::UserDefined, + 2, + paddle::platform::EventRole::kUniqueOp); reorder_p->execute( astream, *diff_weights_memory_p, *reorder_dst_memory_p); astream.wait(); @@ -182,6 +187,60 @@ void ConvGradKernel(const Context& dev_ctx, })); } +template +void DepthwiseConvGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad) { + ConvGradKernel(dev_ctx, + input, + filter, + out_grad, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + input_grad, + filter_grad); +} + +template +void Conv3DGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const DenseTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* input_grad, + DenseTensor* filter_grad) { + ConvGradKernel(dev_ctx, + input, + filter, + out_grad, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + input_grad, + filter_grad); +} + } // namespace phi PD_REGISTER_KERNEL(conv2d_grad, @@ -190,3 +249,12 @@ PD_REGISTER_KERNEL(conv2d_grad, phi::ConvGradKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(depthwise_conv2d_grad, + OneDNN, + ONEDNN, + phi::DepthwiseConvGradKernel, + float, + phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(conv3d_grad, OneDNN, ONEDNN, phi::Conv3DGradKernel, float) {} -- GitLab