From 1de6daff82fc32e7480399359b599599e97331a3 Mon Sep 17 00:00:00 2001
From: Leo Chen <chenqiuliang@baidu.com>
Date: Tue, 16 Mar 2021 22:41:27 +0800
Subject: [PATCH] [NPU] fix shape of dx in mul_grad (#31675)

* fix shape of dx

* refine code
---
 paddle/fluid/operators/mul_op_npu.cc | 29 +++++++++++-----------------
 1 file changed, 11 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/operators/mul_op_npu.cc b/paddle/fluid/operators/mul_op_npu.cc
index cf057cc339..b52cba9cb0 100644
--- a/paddle/fluid/operators/mul_op_npu.cc
+++ b/paddle/fluid/operators/mul_op_npu.cc
@@ -140,19 +140,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
         // matmul
         if (dx) {
           // matmul [2, 5] * [12, 5] => [2, 12]
-          Tensor tmp_matmul(y->type());
-          tmp_matmul.Resize(
-              framework::make_ddim({dout->dims()[0], y->dims()[0]}));
-          tmp_matmul.mutable_data<T>(ctx.GetPlace());
+          dx->mutable_data<T>(ctx.GetPlace());
+          auto dx_dims = dx->dims();
+          dx->Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
           auto runner_matmul =
-              NpuOpRunner("MatMul", {*dout, *y}, {tmp_matmul},
+              NpuOpRunner("MatMul", {*dout, *y}, {*dx},
                           {{"transpose_x1", false}, {"transpose_x2", true}});
           runner_matmul.Run(stream);
           // reshape [2, 12] => [2, 3, 4]
-          dx->mutable_data(ctx.GetPlace(), x->type());
-          framework::TensorCopy(
-              tmp_matmul, ctx.GetPlace(),
-              ctx.template device_context<platform::DeviceContext>(), dx);
+          dx->Resize(dx_dims);
         }
 
         if (dy) {
@@ -193,18 +189,15 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
 
       if (dx) {
         // tmp_dout * y [6,5] * [4,5] => [6, 4]
-        Tensor tmp_matmul(y->type());
-        tmp_matmul.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
-        tmp_matmul.mutable_data<T>(ctx.GetPlace());
+        dx->mutable_data<T>(ctx.GetPlace());
+        auto dx_dims = dx->dims();
+        dx->Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
         auto runner_matmul =
-            NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_matmul},
+            NpuOpRunner("MatMul", {tmp_dout, *y}, {*dx},
                         {{"transpose_x1", false}, {"transpose_x2", true}});
         runner_matmul.Run(stream);
-        // reshape [6,4] => [2, 3, 4]
-        dx->mutable_data(ctx.GetPlace(), x->type());
-        framework::TensorCopy(
-            tmp_matmul, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), dx);
+        // reshape [2, 12] => [2, 3, 4]
+        dx->Resize(dx_dims);
       }
       if (dy) {
         // flatten x.shape [2,3,4] => [6, 4]
-- 
GitLab