diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index deff0ff8d5aa4656372a0d18489cd704de8c9efa..85d501f6bf7f8f856040c120d49a73a4f4d6696d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { int axis = ctx.Attr("axis"); int rankdiff = ctx.Input("X")->dims().size() - ctx.Input("Y")->dims().size(); - return (axis == -1) || (axis == rankdiff); + return (rankdiff == 0) || (axis == -1) || (axis == rankdiff); }; if (platform::CanMKLDNNBeUsed(ctx) && @@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN // If broadcasting is needed, use native implementation auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Output(framework::GradVarName("Y")); - return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()); + return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); }; if (platform::CanMKLDNNBeUsed(ctx) && diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 89face8faaeed8c306ebd482dfb5d4371a92b6a3..98b79d6bb22fcff09533c2e9325d94659b3ef0c1 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { in->set_format(out->format()); }; + // TODO(jczaja): Double check if vcopy works for blocked data auto blas = math::GetBlas(ctx); if (dx) { blas.VCOPY(dout->numel(), dout->data(),