From 2aedf16981325e08c1760e1b88cb63974515ccf6 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Tue, 3 Aug 2021 14:03:45 +0800 Subject: [PATCH] support more dim for mul op npu (#34546) * support more dim for mul op npu * update unit test according to reviewer's comment. --- paddle/fluid/operators/mul_op_npu.cc | 16 ++- .../tests/unittests/npu/test_mul_op_npu.py | 113 ++++++++++++++++++ 2 files changed, 124 insertions(+), 5 deletions(-) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py diff --git a/paddle/fluid/operators/mul_op_npu.cc b/paddle/fluid/operators/mul_op_npu.cc index 9dcf012d512..a0cdd69515d 100644 --- a/paddle/fluid/operators/mul_op_npu.cc +++ b/paddle/fluid/operators/mul_op_npu.cc @@ -41,10 +41,13 @@ class MulNPUKernel : public framework::OpKernel { {{"transpose_x1", false}, {"transpose_x2", false}}); runner.Run(stream); - } else if (x->dims().size() == 3 && y->dims().size() == 2) { + } else if (x->dims().size() >= 3 && y->dims().size() == 2) { // reshape Tensor tmp_x(x->type()); - int64_t sec_dim = x->dims()[1] * x->dims()[2]; + int64_t sec_dim = x->dims()[1]; + for (auto i = 2; i < x->dims().size(); i++) { + sec_dim *= x->dims()[i]; + } int64_t first_dim = x->dims()[0]; tmp_x.ShareDataWith(*x); tmp_x.Resize(framework::make_ddim({first_dim, sec_dim})); @@ -56,7 +59,7 @@ class MulNPUKernel : public framework::OpKernel { runner.Run(stream); } else { PADDLE_THROW( - platform::errors::InvalidArgument("npu error: not suppert dims")); + platform::errors::InvalidArgument("npu error: not support dims")); } // to do other } else if (x->dims().size() == 3 && y->dims().size() == 2) { @@ -135,7 +138,7 @@ class MulGradNPUKernel : public framework::OpKernel { runner_dy.Run(stream); } - } else if (x->dims().size() == 3 && y->dims().size() == 2) { + } else if (x->dims().size() >= 3 && y->dims().size() == 2) { // flatten => x.shape=[6, 4] // matmul if (dx) { @@ -154,7 +157,10 @@ class MulGradNPUKernel : public framework::OpKernel { if (dy) { // flatten Tensor tmp_x(x->type()); - int64_t sec_dim = x->dims()[1] * x->dims()[2]; + int64_t sec_dim = x->dims()[1]; + for (auto i = 2; i < x->dims().size(); i++) { + sec_dim *= x->dims()[i]; + } int64_t first_dim = x->dims()[0]; tmp_x.ShareDataWith(*x); tmp_x.Resize(framework::make_ddim({first_dim, sec_dim})); diff --git a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py old mode 100644 new mode 100755 index cb58a2a8d44..b6e3134439d --- a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py @@ -170,6 +170,44 @@ class TestMul3FP16(TestMul3): pass +class TestMul4(TestMul): + # case 4: (20, 2, 2, 3) * (12, 50) -> (20, 50), x_num_col_dims = 1 + def config(self): + self.x_shape = (20, 2, 2, 3) + self.y_shape = (12, 50) + + def setUp(self): + self.set_npu() + self.op_type = "mul" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.config() + np.random.seed(SEED) + self.inputs = { + 'X': np.random.random(self.x_shape).astype(self.dtype), + 'Y': np.random.random(self.y_shape).astype(self.dtype) + } + self.outputs = { + 'Out': np.dot(self.inputs['X'].reshape(20, 12), self.inputs['Y']) + } + + +@skip_check_grad_ci( + reason="Don't support grad checking for NPU OP with FP16 data type.") +class TestMul4FP16(TestMul4): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + class TestMulNet(unittest.TestCase): def init_dtype(self): self.dtype = np.float32 @@ -385,5 +423,80 @@ class TestMulNet3_2_xc2(unittest.TestCase): self.assertTrue(np.allclose(npu_loss, cpu_loss)) +class TestMulNet4_2(unittest.TestCase): + def init_dtype(self): + self.dtype = np.float32 + + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(12, 5)).astype(self.dtype) + b_np = np.random.random(size=(12, 5)).astype(self.dtype) + c_np = np.random.random(size=(12, 5)).astype(self.dtype) + d_np = np.random.random(size=(12, 5)).astype(self.dtype) + label_np = np.random.randint(2, size=(2, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[12, 5], dtype=self.dtype) + b = paddle.static.data(name="b", shape=[12, 5], dtype=self.dtype) + c = paddle.static.data(name="c", shape=[12, 5], dtype=self.dtype) + d = paddle.static.data(name="d", shape=[12, 5], dtype=self.dtype) + label = paddle.static.data( + name="label", shape=[2, 1], dtype='int64') + + sum_1 = paddle.add(a, b) # [12, 5] + sum_2 = paddle.add(c, d) # [12, 5] + fc_1 = fluid.layers.fc(input=sum_1, size=2) # [12, 2] + fc_1_re_shape = paddle.reshape(fc_1, shape=[2, 3, 2, 2]) + fc_2 = fluid.layers.fc(input=sum_2, size=2) # [12, 2] + result = paddle.fluid.layers.mul(fc_1_re_shape, + fc_2) # [2, 3, 2, 2] * [12, 2] + + prediction = fluid.layers.fc(input=result, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("testMulNet4_2 tart run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run(main_prog, + feed={ + "a": a_np, + "b": b_np, + "c": c_np, + "d": d_np, + "label": label_np + }, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + self.init_dtype() + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose( + npu_pred, cpu_pred, atol=1e-5)) # atol needed on cann 20.3 + self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-5)) + + if __name__ == '__main__': unittest.main() -- GitLab