From ca68b13f85be0451297943081420cdf211598d7b Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 22 May 2020 04:42:29 +0200 Subject: [PATCH] [oneDNN] Fix to elementwise_add grad (#24639) --- paddle/fluid/operators/elementwise/elementwise_op.h | 6 ++---- .../elementwise/mkldnn/elementwise_add_mkldnn_op.cc | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index deff0ff8d5..85d501f6bf 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { int axis = ctx.Attr("axis"); int rankdiff = ctx.Input("X")->dims().size() - ctx.Input("Y")->dims().size(); - return (axis == -1) || (axis == rankdiff); + return (rankdiff == 0) || (axis == -1) || (axis == rankdiff); }; if (platform::CanMKLDNNBeUsed(ctx) && @@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN // If broadcasting is needed, use native implementation auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Output(framework::GradVarName("Y")); - return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()); + return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); }; if (platform::CanMKLDNNBeUsed(ctx) && diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 89face8faa..98b79d6bb2 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { in->set_format(out->format()); }; + // TODO(jczaja): Double check if vcopy works for blocked data auto blas = math::GetBlas(ctx); if (dx) { blas.VCOPY(dout->numel(), dout->data(), -- GitLab