From 5d29a27c2ef4d7398959a9e36feca42599e4ce38 Mon Sep 17 00:00:00 2001 From: oyjxer <1728722986@qq.com> Date: Fri, 12 Mar 2021 17:07:45 +0800 Subject: [PATCH] [NPU] fix npu op elementwise_mul_grad (#31592) --- .../elementwise/elementwise_mul_op_npu.cc | 20 ++++++++++++------- .../npu/test_elementwise_mul_op_npu.py | 1 + 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc index cecb62d3580..08df6d4e27a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef PADDLE_WITH_ASCEND_CL #include #include @@ -58,17 +59,21 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel { auto place = ctx.GetPlace(); - dx->mutable_data(place); - dy->mutable_data(place); - auto stream = ctx.template device_context() .stream(); - auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {}); - dx_runner.Run(stream); - auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {}); - dy_runner.Run(stream); + if (dx) { + dx->mutable_data(place); + auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {}); + dx_runner.Run(stream); + } + + if (dy) { + dy->mutable_data(place); + auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {}); + dy_runner.Run(stream); + } } }; @@ -88,3 +93,4 @@ REGISTER_OP_NPU_KERNEL( ops::ElementwiseMulGradNPUKernel, ops::ElementwiseMulGradNPUKernel); +#endif diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_mul_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_mul_op_npu.py index d0baa5877f2..9bfb7e033e7 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_mul_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_mul_op_npu.py @@ -122,6 +122,7 @@ class TestElementwiseMulNet(unittest.TestCase): e = paddle.multiply(a, b) f = paddle.multiply(c, d) + f.stop_gradient = True g = paddle.multiply(e, f) fc_1 = fluid.layers.fc(input=g, size=128) -- GitLab