diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index ce4a3f8a0931104be9955ca99297eb9068582849..9e58cff01d3ccad482c7482644b25b0d5db518d4 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -141,8 +141,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { using Tensor = framework::Tensor; - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); @@ -150,13 +148,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel { if (dx != nullptr) dx->ShareDataWith(*dout); if (dy == nullptr) return; - if (x->dims() == y->dims()) { + const framework::DDim& x_dim = dout->dims(); + framework::DDim y_dim = dy->dims(); + if (x_dim == y_dim) { dy->ShareDataWith(*dout); } else { dy->mutable_data(ctx.GetPlace()); // Perform reduction to dout to calculate dy - const framework::DDim& x_dim = x->dims(); - framework::DDim y_dim = y->dims(); int axis = ctx.Attr("axis"); axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); y_dim = trim_trailing_singular_dims(y_dim);