From 91ba86b1f3d087a3347efee2009573ff0a6f7ed5 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Wed, 25 Aug 2021 10:54:29 +0800 Subject: [PATCH] [NPU] Fix the performance problem when 'axis' is not specified (#35116) --- .../elementwise/elementwise_add_op_npu.cc | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc index 1ba6c4cb19..cd1d50a017 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc @@ -42,27 +42,22 @@ class ElementwiseAddNPUKernel : public framework::OpKernel { auto y_dims = y->dims(); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); if (x_dims.size() >= y_dims.size()) { - direct_compute = - y_dims == framework::slice_ddim(x_dims, axis, x_dims.size()); + direct_compute = x_dims.size() == (y_dims.size() + axis); } else { - direct_compute = - x_dims == framework::slice_ddim(y_dims, axis, y_dims.size()); + direct_compute = y_dims.size() == (x_dims.size() + axis); } - Tensor transformed_x, transformed_y; if (direct_compute) { - transformed_x.ShareDataWith(*x); - transformed_y.ShareDataWith(*y); + const auto& runner = NpuOpRunner("Add", {*x, *y}, {*out}, {}); + runner.Run(dev_ctx.stream()); } else { + Tensor transformed_x, transformed_y; NpuElementWiseOpBroadcast(dev_ctx, x, y, axis, &transformed_x, &transformed_y); + const auto& runner = + NpuOpRunner("Add", {transformed_x, transformed_y}, {*out}, {}); + runner.Run(dev_ctx.stream()); } - const auto& runner = - NpuOpRunner("Add", {transformed_x, transformed_y}, {*out}, {}); - auto stream = - ctx.template device_context() - .stream(); - runner.Run(stream); } }; -- GitLab