From 43527a2b4fc627d392c7e6cc44f744b7231b6418 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Fri, 30 Apr 2021 04:05:35 +0200 Subject: [PATCH] Reduce grad fix (#32592) --- .../mkldnn/reduce_mean_mkldnn_op.cc | 3 +- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 90 ++++++++++++------- .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 3 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 25 ++---- paddle/fluid/platform/mkldnn_reuse.h | 31 +++---- 5 files changed, 79 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index 33daeea8599..dfba933940b 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -45,7 +45,8 @@ class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel { number_of_elements = input_x->numel(); } - this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, + this->RunKernel(ctx, dnnl::algorithm::binary_add, + dnnl::algorithm::reduction_mean, 0.0f, 1.0L / number_of_elements); } }; diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 58416f479c0..40cd3ba974f 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -21,6 +21,27 @@ using paddle::framework::LoDTensor; using paddle::framework::Tensor; using platform::to_void_cast; +inline std::vector CalculateReducedDims(const Tensor* input, + const Tensor* output, + std::vector& reduce_dims, + bool reduce_all, + bool keep_dim) { + if (keep_dim) return framework::vectorize(output->dims()); + + if (reduce_all) + return std::vector(framework::vectorize(input->dims()).size(), 1); + + std::vector output_dims(framework::vectorize(input->dims())); + for (size_t i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = (reduce_dims[i] >= 0) + ? reduce_dims[i] + : input->dims().size() + reduce_dims[i]; + output_dims[reduce_dims[i]] = 1; + } + + return output_dims; +} + template class ReduceMKLDNNKernel : public framework::OpKernel { public: @@ -37,9 +58,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel { bool reduce_all = ctx.Attr("reduce_all"); bool keep_dim = ctx.Attr("keep_dim"); - std::vector output_dims = - CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); - + auto output_dims = + CalculateReducedDims(input, output, reduce_dims, reduce_all, keep_dim); auto input_dims = framework::vectorize(input->dims()); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -96,53 +116,63 @@ class ReduceMKLDNNKernel : public framework::OpKernel { paddle::framework::vectorize(output->dims())))); } } - - private: - std::vector CalculateOutputDims(const Tensor* input, - const Tensor* output, - std::vector& reduce_dims, - bool reduce_all, - bool keep_dim) const { - if (keep_dim) return framework::vectorize(output->dims()); - - if (reduce_all) - return std::vector(framework::vectorize(input->dims()).size(), - 1); - - std::vector output_dims(framework::vectorize(input->dims())); - for (size_t i = 0; i < reduce_dims.size(); ++i) { - reduce_dims[i] = (reduce_dims[i] >= 0) - ? reduce_dims[i] - : input->dims().size() + reduce_dims[i]; - output_dims[reduce_dims[i]] = 1; - } - - return output_dims; - } }; template class ReduceGradMKLDNNKernel : public framework::OpKernel { public: void RunKernel(const framework::ExecutionContext& ctx, - dnnl::algorithm binary_type, float scale_x, - float scale_y) const { + dnnl::algorithm binary_type, dnnl::algorithm reduction_type, + float scale_x, float scale_y) const { const auto& dev_ctx = ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); + bool keep_dim = ctx.Attr("keep_dim"); + bool reduce_all = ctx.Attr("reduce_all"); auto dims = ctx.Attr>("dim"); auto* input_dy = ctx.Input(framework::GradVarName("Out")); auto* output_dx = ctx.Output(framework::GradVarName("X")); + mkldnn::memory::format_tag x_format_tag; + auto input_dims = + CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim); + + if (input_dims != framework::vectorize(output_dx->dims())) { + const std::string key_pd = + platform::CreateKey( + dev_ctx, framework::vectorize(output_dx->dims()), + ctx.InputName("X"), + (std::to_string(static_cast(reduction_type)))) + + "@fwd_pd"; + std::shared_ptr fwd_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_pd)); + + PADDLE_ENFORCE_NOT_NULL( + fwd_pd, platform::errors::Unavailable( + "Forward primitive descriptor is not available in %s op, " + "cannot deduce memory format tag", + ctx.Type())); + + x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc()); + + PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef, + platform::errors::InvalidArgument( + "Cannot deduce format tag for %s op", ctx.Type())); + } else { // fwd descriptor not available because reorder was used instead + // of reduction + x_format_tag = getPlainFormatTag(output_dx); + } + output_dx->mutable_data(ctx.GetPlace()); - output_dx->set_format(getPlainFormatTag(output_dx)); + output_dx->set_format(x_format_tag); output_dx->set_layout(input_dy->layout()); platform::BroadcastDataMKLDNNHandler handler( binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx, input_dy, scale_x, scale_y, - ctx.InputName(framework::GradVarName("Out"))); + ctx.InputName(framework::GradVarName("Out")), input_dims); const auto src_dx_memory = handler.AcquireSrcMemory(output_dx); const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index e62edcf5596..3f92d39ede1 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -29,7 +29,8 @@ template class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0f); + this->RunKernel(ctx, dnnl::algorithm::binary_add, + dnnl::algorithm::reduction_sum, 0.0f, 1.0f); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 913d941df88..390c4d9709a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -559,8 +559,11 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); + int in_dtype = ctx.Attr("in_dtype"); + auto input_data_type = + (in_dtype >= 0) ? static_cast(in_dtype) + : OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN auto CanMKLDNNReduceGradBeUsed = [&]() { @@ -568,18 +571,6 @@ class ReduceGradOp : public framework::OperatorWithKernel { if (dx_dims.size() > 5) return false; // max 5D tensor is supported - if (ctx.Attr("reduce_all") || - ((int)ctx.Attr>("dim").size() == dx_dims.size())) - return true; - - auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - - // Subtensor must be on rightmost part of the bigger tensor - for (int i = 0; i < dy_dims.size(); ++i) { - if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) { - return false; - } - } return true; }; if (this->CanMKLDNNBeUsed(ctx, input_data_type) && @@ -590,12 +581,6 @@ class ReduceGradOp : public framework::OperatorWithKernel { } #endif - int in_dtype = ctx.Attr("in_dtype"); - if (in_dtype >= 0) { - return framework::OpKernelType( - static_cast(in_dtype), - ctx.GetPlace()); - } return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 54efa55cc4c..f1eb1f96363 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -639,7 +639,8 @@ class BroadcastDataMKLDNNHandler const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, float scale_x, float scale_y, - const std::string& uniq_name) + const std::string& uniq_name, + std::vector& input_dims) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -659,24 +660,12 @@ class BroadcastDataMKLDNNHandler y->format(), MKLDNNMemoryFormat::undef, platform::errors::InvalidArgument("Wrong format set for Y tensor.")); - auto src1_tz = framework::vectorize(y->dims()); const auto src0_tz = framework::vectorize(x->dims()); - // GetExpectedKernelType checks if smaller vector is a subvector with all - // the dims in correct order on the rightmost part of the bigger vector, - // i.e. a correct vector for broadcasting: - // x = 5, 7, 3, 2, 4, 8 - // y = 4, 8 - src1_tz.reserve(src0_tz.size()); - - for (size_t i = src1_tz.size(); i < src0_tz.size(); ++i) { - src1_tz.insert(src1_tz.begin(), 1L); - } - const auto src0_md = dnnl::memory::desc( src0_tz, platform::MKLDNNGetDataType(), x->format()); const auto src1_md = dnnl::memory::desc( - src1_tz, platform::MKLDNNGetDataType(), x->format()); + input_dims, platform::MKLDNNGetDataType(), x->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); @@ -711,7 +700,7 @@ class ReductionMKLDNNHandler const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, const std::string& uniq_name, - std::vector output_dims) + std::vector y_tz) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -725,14 +714,14 @@ class ReductionMKLDNNHandler x->format(), MKLDNNMemoryFormat::undef, platform::errors::InvalidArgument("Wrong format set for X tensor.")); - const auto src_tz = framework::vectorize(x->dims()); + const auto x_tz = framework::vectorize(x->dims()); - const auto src_md = dnnl::memory::desc( - src_tz, platform::MKLDNNGetDataType(), x->format()); - const auto dst_md = memory::desc( - output_dims, platform::MKLDNNGetDataType(), x->format()); + const auto x_md = dnnl::memory::desc( + x_tz, platform::MKLDNNGetDataType(), x->format()); + const auto y_md = + memory::desc(y_tz, platform::MKLDNNGetDataType(), x->format()); - this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps); + this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps); } } }; -- GitLab