diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
index 1ba6c4cb1932b1f06566097f20c2a753a32026b3..cd1d50a017c363f7bb0b2d48b591e21b1b2011db 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
@@ -42,27 +42,22 @@ class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
     auto y_dims = y->dims();
     axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
     if (x_dims.size() >= y_dims.size()) {
-      direct_compute =
-          y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
+      direct_compute = x_dims.size() == (y_dims.size() + axis);
     } else {
-      direct_compute =
-          x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
+      direct_compute = y_dims.size() == (x_dims.size() + axis);
     }
 
-    Tensor transformed_x, transformed_y;
     if (direct_compute) {
-      transformed_x.ShareDataWith(*x);
-      transformed_y.ShareDataWith(*y);
+      const auto& runner = NpuOpRunner("Add", {*x, *y}, {*out}, {});
+      runner.Run(dev_ctx.stream());
     } else {
+      Tensor transformed_x, transformed_y;
       NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
                                    &transformed_y);
+      const auto& runner =
+          NpuOpRunner("Add", {transformed_x, transformed_y}, {*out}, {});
+      runner.Run(dev_ctx.stream());
     }
-    const auto& runner =
-        NpuOpRunner("Add", {transformed_x, transformed_y}, {*out}, {});
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-    runner.Run(stream);
   }
 };
 
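
The sketch below is a minimal standalone illustration (not Paddle code) of the new direct_compute condition introduced by this diff: when x's rank is at least y's rank, the kernel now calls the NPU "Add" op directly if x_rank == y_rank + axis, i.e. y lines up with the trailing dimensions of x; otherwise it takes the explicit broadcast path (NpuElementWiseOpBroadcast). The helper name CanComputeDirectly is hypothetical and only mirrors the rank/axis check from the patched kernel.

// Standalone sketch of the rank/axis check from the patched kernel.
// CanComputeDirectly is a hypothetical helper, not part of Paddle.
#include <cstdlib>
#include <iostream>

bool CanComputeDirectly(int x_rank, int y_rank, int axis) {
  // axis == -1 means "align the smaller tensor with the trailing
  // dimensions of the larger one", as in the kernel.
  axis = (axis == -1 ? std::abs(x_rank - y_rank) : axis);
  return x_rank >= y_rank ? x_rank == y_rank + axis
                          : y_rank == x_rank + axis;
}

int main() {
  // x: [2, 3, 4], y: [3, 4], axis = -1 -> trailing alignment, direct path.
  std::cout << CanComputeDirectly(3, 2, -1) << "\n";  // prints 1
  // x: [2, 3, 4], y: [3], axis = 1 -> y sits in a middle dim, broadcast path.
  std::cout << CanComputeDirectly(3, 1, 1) << "\n";   // prints 0
  return 0;
}

Note that, unlike the removed slice_ddim comparison, this check only compares ranks and the axis offset; it relies on the NPU "Add" op to handle the actual dimension matching on the direct path.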