From 1de6daff82fc32e7480399359b599599e97331a3 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 16 Mar 2021 22:41:27 +0800 Subject: [PATCH] [NPU] fix shape of dx in mul_grad (#31675) * fix shape of dx * refine code --- paddle/fluid/operators/mul_op_npu.cc | 29 +++++++++++----------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/mul_op_npu.cc b/paddle/fluid/operators/mul_op_npu.cc index cf057cc339c..b52cba9cb05 100644 --- a/paddle/fluid/operators/mul_op_npu.cc +++ b/paddle/fluid/operators/mul_op_npu.cc @@ -140,19 +140,15 @@ class MulGradNPUKernel : public framework::OpKernel { // matmul if (dx) { // matmul [2, 5] * [12, 5] => [2, 12] - Tensor tmp_matmul(y->type()); - tmp_matmul.Resize( - framework::make_ddim({dout->dims()[0], y->dims()[0]})); - tmp_matmul.mutable_data(ctx.GetPlace()); + dx->mutable_data(ctx.GetPlace()); + auto dx_dims = dx->dims(); + dx->Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]})); auto runner_matmul = - NpuOpRunner("MatMul", {*dout, *y}, {tmp_matmul}, + NpuOpRunner("MatMul", {*dout, *y}, {*dx}, {{"transpose_x1", false}, {"transpose_x2", true}}); runner_matmul.Run(stream); // reshape [2, 12] => [2, 3, 4] - dx->mutable_data(ctx.GetPlace(), x->type()); - framework::TensorCopy( - tmp_matmul, ctx.GetPlace(), - ctx.template device_context(), dx); + dx->Resize(dx_dims); } if (dy) { @@ -193,18 +189,15 @@ class MulGradNPUKernel : public framework::OpKernel { if (dx) { // tmp_dout * y [6,5] * [4,5] => [6, 4] - Tensor tmp_matmul(y->type()); - tmp_matmul.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]})); - tmp_matmul.mutable_data(ctx.GetPlace()); + dx->mutable_data(ctx.GetPlace()); + auto dx_dims = dx->dims(); + dx->Resize(framework::make_ddim({dout_first_dim, y->dims()[0]})); auto runner_matmul = - NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_matmul}, + NpuOpRunner("MatMul", {tmp_dout, *y}, {*dx}, {{"transpose_x1", false}, {"transpose_x2", true}}); runner_matmul.Run(stream); - // reshape [6,4] => [2, 3, 4] - dx->mutable_data(ctx.GetPlace(), x->type()); - framework::TensorCopy( - tmp_matmul, ctx.GetPlace(), - ctx.template device_context(), dx); + // reshape [2, 12] => [2, 3, 4] + dx->Resize(dx_dims); } if (dy) { // flatten x.shape [2,3,4] => [6, 4] -- GitLab