From b27270201a9f63a5d7395549d74727ce9f6a9969 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 4 Aug 2022 10:48:05 +0200 Subject: [PATCH] added conv and conv_tranpose support for md (#44677) --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 92 ++++++------------- .../mkldnn/conv_transpose_mkldnn_op.cc | 21 +---- 2 files changed, 29 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 8ee97c281e..fc8f299130 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -24,13 +24,13 @@ namespace paddle { namespace operators { namespace { -inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format, - const int groups, +inline MKLDNNMemoryFormat GetWeightsFormat(const int groups, const bool is_conv3d) { if (is_conv3d) { - return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw; + return (groups == 1) ? MKLDNNMemoryFormat::oidhw + : MKLDNNMemoryFormat::goidhw; } else { - return (groups == 1) ? format : MKLDNNMemoryFormat::goihw; + return (groups == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw; } } @@ -98,10 +98,6 @@ class ConvMKLDNNHandlerT "The input tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, input->layout())); - PADDLE_ENFORCE_NE(input->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Wrong format set for Input tensor")); PADDLE_ENFORCE_EQ( filter->layout(), @@ -110,10 +106,6 @@ class ConvMKLDNNHandlerT "The Filter tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, filter->layout())); - PADDLE_ENFORCE_NE(filter->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Wrong format set for Filter tensor")); PADDLE_ENFORCE_GE( input->dims().size(), @@ -153,10 +145,6 @@ class ConvMKLDNNHandlerT "The Bias tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, bias->layout())); - PADDLE_ENFORCE_NE(bias->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Bias tensor.")); PADDLE_ENFORCE_EQ(bias->dims().size(), 1, @@ -307,10 +295,6 @@ class ConvMKLDNNHandlerT "The input tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, in->layout())); - PADDLE_ENFORCE_NE(in->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Input tensor.")); PADDLE_ENFORCE_EQ( filter->layout(), @@ -319,10 +303,6 @@ class ConvMKLDNNHandlerT "The filter tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, filter->layout())); - PADDLE_ENFORCE_NE(filter->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Filter tensor.")); PADDLE_ENFORCE_EQ( out_grad->layout(), @@ -331,10 +311,6 @@ class ConvMKLDNNHandlerT "The output_grad tensor's layout should be %d, but got %d.", framework::DataLayout::kMKLDNN, out_grad->layout())); - PADDLE_ENFORCE_NE(out_grad->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Wrong format set for output_grad tensor")); PADDLE_ENFORCE_EQ( ctx.Attr("is_test"), @@ -596,10 +572,10 @@ class ConvMKLDNNHandlerT auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); - auto user_src_md = platform::MKLDNNMemDesc( - weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(filter->format(), groups, is_conv3d)); + auto user_src_md = + platform::MKLDNNMemDesc(weights_tz, + platform::MKLDNNGetDataType(), + GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, @@ -660,12 +636,11 @@ class ConvMKLDNNHandlerT auto user_mem_p = this->AcquireMemory(user_key_suffix); if (!user_mem_p) { - auto user_mem_md = - platform::MKLDNNMemDesc(phi::vectorize(in_mem->dims()), - platform::MKLDNNGetDataType(), - in_mem->format()); return this->AcquireMemoryWithReorder( - user_mem_md, mem_md, platform::to_void_cast(in_mem_data), key_mem); + in_mem->mem_desc(), + mem_md, + platform::to_void_cast(in_mem_data), + key_mem); } else { const std::string target_key_suffix{key_mem_target}; const auto target_mem_p = this->AcquireMemory(target_key_suffix); @@ -694,10 +669,10 @@ class ConvMKLDNNHandlerT auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); - auto user_src_md = platform::MKLDNNMemDesc( - weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(filter->format(), groups, is_conv3d)); + auto user_src_md = + platform::MKLDNNMemDesc(weights_tz, + platform::MKLDNNGetDataType(), + GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, @@ -713,10 +688,10 @@ class ConvMKLDNNHandlerT auto weights_tz = phi::vectorize(filter->dims()); platform::GetGroupConvWeightsTz(weights_tz, groups); - auto user_src_md = platform::MKLDNNMemDesc( - weights_tz, - platform::MKLDNNGetDataType(), - GetWeightsFormat(filter->format(), groups, is_conv3d)); + auto user_src_md = + platform::MKLDNNMemDesc(weights_tz, + platform::MKLDNNGetDataType(), + GetWeightsFormat(groups, is_conv3d)); return this->AcquireMemoryWithReorder( user_src_md, @@ -747,13 +722,9 @@ class ConvMKLDNNHandlerT LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype(); } const K_Bias* bias_data = bias->data(); - auto user_bias_md = - platform::MKLDNNMemDesc(phi::vectorize(bias->dims()), - platform::MKLDNNGetDataType(), - MKLDNNMemoryFormat::x); return this->AcquireMemoryWithReorder( - user_bias_md, + bias->mem_desc(), this->fwd_pd_->bias_desc(), platform::to_void_cast(bias_data), "@bias_mem_p", @@ -776,22 +747,16 @@ class ConvMKLDNNHandlerT residual_mem_p->set_data_handle(residual_data); return residual_mem_p; } else { - auto user_residual_md = platform::MKLDNNMemDesc( - phi::vectorize(residual_param->dims()), - framework::ToMKLDNNDataType( - framework::TransToProtoVarType(residual_param->dtype())), - residual_param->format()); - - return this->AcquireMemoryFromPrimitive( - user_residual_md, residual_data, "@user_residual_data_mem_p"); + return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(), + residual_data, + "@user_residual_data_mem_p"); } } std::shared_ptr AcquireDstMemoryWithResidual( framework::Tensor* output, const framework::Tensor* residual_param) { std::shared_ptr dst_memory_p; - if (residual_param->format() != - platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) { + if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) { auto residual_memory_p = this->AcquireResidualMemory(residual_param); dst_memory_p = this->template AcquireDstMemory(output); this->AcquireReorder(residual_memory_p, dst_memory_p); @@ -903,8 +868,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { conv_p->execute(astream, args); astream.wait(); - output->set_layout(framework::DataLayout::kMKLDNN); - output->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); + output->set_mem_desc(dst_memory_p->get_desc()); } template @@ -1018,8 +982,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); } - output->set_layout(framework::DataLayout::kMKLDNN); - output->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); + output->set_mem_desc(dst_memory_p->get_desc()); } }; @@ -1078,7 +1041,6 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel { auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); - // TODO(grygielski) why no bias_diff? conv_bwd_weights_p->execute( astream, {{DNNL_ARG_SRC, *src_memory_p}, diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index cd81168753..8016338931 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -59,11 +59,6 @@ class ConvTransposeMKLDNNHandlerT DataLayout::kMKLDNN, platform::errors::InvalidArgument( "Got wrong layout = %d for Input tensor.", input->layout())); - PADDLE_ENFORCE_NE(input->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Input tensor. The input " - "format is undefined.")); PADDLE_ENFORCE_EQ( filter->layout(), @@ -72,10 +67,6 @@ class ConvTransposeMKLDNNHandlerT "The filter tensor's layout should be %d, but got %d.", DataLayout::kMKLDNN, filter->layout())); - PADDLE_ENFORCE_NE(filter->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong formats for Filter tensor.")); PADDLE_ENFORCE_EQ( input->dims().size(), @@ -98,10 +89,6 @@ class ConvTransposeMKLDNNHandlerT "The bias tensor's laytout should be %d, but got %d.", DataLayout::kMKLDNN, bias->layout())); - PADDLE_ENFORCE_NE(bias->format(), - MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Bias tensor.")); PADDLE_ENFORCE_EQ( bias->dims().size(), @@ -233,11 +220,8 @@ class ConvTransposeMKLDNNHandlerT std::shared_ptr AcquireSrcMemoryWithReorder( const framework::Tensor* input) { const T* input_data = input->data(); - auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()), - platform::MKLDNNGetDataType(), - input->format()); return platform::MKLDNNHandlerNoCachingT:: - AcquireMemoryWithReorder(user_src_md, + AcquireMemoryWithReorder(input->mem_desc(), this->fwd_pd_->src_desc(), platform::to_void_cast(input_data)); } @@ -427,8 +411,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); conv_p->execute(astream, args); astream.wait(); - output->set_layout(DataLayout::kMKLDNN); - output->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); + output->set_mem_desc(dst_memory_p->get_desc()); } }; -- GitLab