From f91dfe1554733e3f9478dd7405bf75e39c9c62bb Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Mon, 21 Jun 2021 15:01:51 +0800
Subject: [PATCH] [NPU] optimize mul op, use BatchMatMul to realize (#33616)

* use BatchMatMul
* replace TensorCopy with ShareDataWith
* remove check fp16 grad
* fix format
* add grad_check
---
 paddle/fluid/operators/mul_op_npu.cc          | 132 ++++++++-------
 .../tests/unittests/npu/test_mul_op_npu.py    | 158 +++++++++++++-----
 2 files changed, 180 insertions(+), 110 deletions(-)

diff --git a/paddle/fluid/operators/mul_op_npu.cc b/paddle/fluid/operators/mul_op_npu.cc
index cfa75bc1ce1..9dcf012d512 100644
--- a/paddle/fluid/operators/mul_op_npu.cc
+++ b/paddle/fluid/operators/mul_op_npu.cc
@@ -46,11 +46,7 @@ class MulNPUKernel : public framework::OpKernel<T> {
         Tensor tmp_x(x->type());
         int64_t sec_dim = x->dims()[1] * x->dims()[2];
         int64_t first_dim = x->dims()[0];
-        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
-        tmp_x.mutable_data<T>(ctx.GetPlace());
-        framework::TensorCopy(
-            *x, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), &tmp_x);
+        tmp_x.ShareDataWith(*x);
         tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
         out->mutable_data<T>(ctx.GetPlace());
         // matmul
@@ -69,36 +65,39 @@ class MulNPUKernel : public framework::OpKernel<T> {
          platform::errors::InvalidArgument(
              "now only support x_num_col_dims == 2: but got %d",
              x_num_col_dims));
-      // flatten => x.shape=[6, 4]
-      Tensor tmp_x(x->type());
-      int64_t first_dim = x->dims()[0] * x->dims()[1];
-      int64_t sec_dim = x->dims()[2];
-      tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
-      tmp_x.mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopy(
-          *x, ctx.GetPlace(),
-          ctx.template device_context<platform::DeviceContext>(), &tmp_x);
-      tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
-
-      // matmul [6,4] , [4, 5] => [6, 5]
-      Tensor tmp_matmul(x->type());
-      tmp_matmul.Resize(framework::make_ddim({first_dim, y->dims()[1]}));
-      tmp_matmul.mutable_data<T>(ctx.GetPlace());
-
-      const auto& runner_matmul =
-          NpuOpRunner("MatMul", {tmp_x, *y}, {tmp_matmul},
-                      {{"transpose_x1", false}, {"transpose_x2", false}});
-
-      runner_matmul.Run(stream);
-      // reshape [6, 5] => [2, 3, 5]
-      (*out).Resize(
-          framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
-      out->mutable_data(ctx.GetPlace(), x->type());
-      framework::TensorCopy(
-          tmp_matmul, ctx.GetPlace(),
-          ctx.template device_context<platform::DeviceContext>(), out);
-      (*out).Resize(
-          framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
+      if (x->type() == framework::proto::VarType::FP16 &&
+          y->type() == framework::proto::VarType::FP16) {
+        // NOTE: When the dims of the input and output shapes are inconsistent,
+        // (Broadcast) BatchMatMul NPU OP only supports FP16.
+        out->mutable_data<T>(ctx.GetPlace());
+        const auto& runner =
+            NpuOpRunner("BatchMatMul", {*x, *y}, {*out},
+                        {{"adj_x1", false}, {"adj_x2", false}});
+
+        auto stream =
+            ctx.template device_context<paddle::platform::NPUDeviceContext>()
+                .stream();
+        runner.Run(stream);
+      } else {
+        // flatten => x.shape=[6, 4]
+        Tensor tmp_x(x->type());
+        int64_t first_dim = x->dims()[0] * x->dims()[1];
+        int64_t sec_dim = x->dims()[2];
+        tmp_x.ShareDataWith(*x);
+        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
+
+        // matmul [6,4] , [4, 5] => [6, 5]
+        out->mutable_data<T>(ctx.GetPlace());
+
+        Tensor tmp_out(x->type());
+        tmp_out.ShareDataWith(*out);
+        tmp_out.Resize(framework::make_ddim({first_dim, y->dims()[1]}));
+
+        const auto& runner_matmul =
+            NpuOpRunner("MatMul", {tmp_x, *y}, {tmp_out},
+                        {{"transpose_x1", false}, {"transpose_x2", false}});
+        runner_matmul.Run(stream);
+      }
     }
   }
 };
@@ -142,14 +141,14 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
       if (dx) {
         // matmul [2, 5] * [12, 5] => [2, 12]
         dx->mutable_data<T>(ctx.GetPlace());
-        auto dx_dims = dx->dims();
-        dx->Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
+        Tensor tmp_dx(x->type());
+        tmp_dx.ShareDataWith(*dx);
+        tmp_dx.Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
+
         const auto& runner_matmul =
-            NpuOpRunner("MatMul", {*dout, *y}, {*dx},
+            NpuOpRunner("MatMul", {*dout, *y}, {tmp_dx},
                         {{"transpose_x1", false}, {"transpose_x2", true}});
         runner_matmul.Run(stream);
-        // reshape [2, 12] => [2, 3, 4]
-        dx->Resize(dx_dims);
       }
 
       if (dy) {
@@ -157,11 +156,7 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
         // flatten
         Tensor tmp_x(x->type());
         int64_t sec_dim = x->dims()[1] * x->dims()[2];
         int64_t first_dim = x->dims()[0];
-        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
-        tmp_x.mutable_data<T>(ctx.GetPlace());
-        framework::TensorCopy(
-            *x, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), &tmp_x);
+        tmp_x.ShareDataWith(*x);
         tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
         dy->mutable_data<T>(ctx.GetPlace());
         const auto& runner_dy =
@@ -181,35 +176,42 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
       Tensor tmp_dout(x->type());
       int64_t dout_first_dim = dout->dims()[0] * dout->dims()[1];
       int64_t dout_sec_dim = dout->dims()[2];
-      tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));
-      tmp_dout.mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopy(
-          *dout, ctx.GetPlace(),
-          ctx.template device_context<platform::DeviceContext>(), &tmp_dout);
+      tmp_dout.ShareDataWith(*dout);
       tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));
 
       if (dx) {
-        // tmp_dout * y [6,5] * [4,5] => [6, 4]
-        dx->mutable_data<T>(ctx.GetPlace());
-        auto dx_dims = dx->dims();
-        dx->Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
-        const auto& runner_matmul =
-            NpuOpRunner("MatMul", {tmp_dout, *y}, {*dx},
-                        {{"transpose_x1", false}, {"transpose_x2", true}});
-        runner_matmul.Run(stream);
-        // reshape [2, 12] => [2, 3, 4]
-        dx->Resize(dx_dims);
+        // tmp_dout * y [2, 3, 5] * [4,5] => [2, 3, 4]
+        if (dout->type() == framework::proto::VarType::FP16 &&
+            y->type() == framework::proto::VarType::FP16) {
+          // NOTE: When the dims of the input and output shapes are inconsistent,
+          // (Broadcast) BatchMatMul NPU OP only supports FP16.
+          dx->mutable_data<T>(ctx.GetPlace());
+          const auto& runner =
+              NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx},
+                          {{"adj_x1", false}, {"adj_x2", true}});
+
+          auto stream =
+              ctx.template device_context<paddle::platform::NPUDeviceContext>()
+                  .stream();
+          runner.Run(stream);
+        } else {
+          dx->mutable_data<T>(ctx.GetPlace());
+          Tensor tmp_dx(x->type());
+          tmp_dx.ShareDataWith(*dx);
+          tmp_dx.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
+
+          const auto& runner_matmul =
+              NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_dx},
+                          {{"transpose_x1", false}, {"transpose_x2", true}});
+          runner_matmul.Run(stream);
+        }
       }
 
       if (dy) {
         // flatten x.shape [2,3,4] => [6, 4]
         Tensor tmp_x(x->type());
         int64_t first_dim = x->dims()[0] * x->dims()[1];
         int64_t sec_dim = x->dims()[2];
-        tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
-        tmp_x.mutable_data<T>(ctx.GetPlace());
-        framework::TensorCopy(
-            *x, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), &tmp_x);
+        tmp_x.ShareDataWith(*x);
         tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
         // matmul [6,4] [6,5] =>[4,5]
         dy->mutable_data<T>(ctx.GetPlace());
diff --git a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py
index 4fcfd33b32f..07f187a0f0d 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py
@@ -18,7 +18,7 @@ import numpy as np
 import unittest
 import sys
 sys.path.append("..")
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 import paddle
 import paddle.fluid as fluid
 
@@ -27,6 +27,7 @@ SEED = 2021
 
 
 class TestMul(OpTest):
+    # case 1: (32, 5) * (5, 100) -> (32, 100)
     def config(self):
         self.x_shape = (32, 5)
         self.y_shape = (5, 100)
@@ -46,7 +47,6 @@ class TestMul(OpTest):
 
     def set_npu(self):
         self.__class__.use_npu = True
-        self.__class__.no_need_check_grad = True
 
     def init_dtype(self):
         self.dtype = np.float32
@@ -54,25 +54,51 @@ class TestMul(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False, atol=1e-5)
-
-    #
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'],
+            'Out',
+            max_relative_error=0.0065,
+            check_dygraph=False)
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            max_relative_error=0.0065,
+            check_dygraph=False)
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            max_relative_error=0.0065,
+            check_dygraph=False)
+
+
+@skip_check_grad_ci(
+    reason="Don't support grad checking for NPU OP with FP16 data type.")
 class TestMulFP16(TestMul):
-    """
-    case 2
-    """
-
     def init_dtype(self):
         self.dtype = np.float16
 
+    def test_check_grad_normal(self):
+        pass
 
-class TestMul3(TestMul):
-    """
-    case 3
-    """
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
 
+class TestMul2(TestMul):
+    # case 2: (20, 2, 5) * (10, 50) -> (20, 50), x_num_col_dims = 1
     def config(self):
-        self.x_shape = (2, 2, 5)
-        self.y_shape = (10, 5)
+        self.x_shape = (20, 2, 5)
+        self.y_shape = (10, 50)
 
     def setUp(self):
         self.set_npu()
@@ -86,18 +112,32 @@ class TestMul3(TestMul):
             'Y': np.random.random(self.y_shape).astype(self.dtype)
         }
         self.outputs = {
-            'Out': np.dot(self.inputs['X'].reshape(2, 10), self.inputs['Y'])
+            'Out': np.dot(self.inputs['X'].reshape(20, 10), self.inputs['Y'])
         }
 
 
-class TestMul4(TestMul):
-    """
-    case 4
-    """
+@skip_check_grad_ci(
+    reason="Don't support grad checking for NPU OP with FP16 data type.")
+class TestMul2FP16(TestMul2):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
 
+    def test_check_grad_ingore_y(self):
+        pass
+
+
+class TestMul3(TestMul):
+    # case 3: (20, 3, 4) * (4, 50) -> (20, 3, 50), x_num_col_dims = 2
     def config(self):
-        self.x_shape = (2, 3, 4)
-        self.y_shape = (4, 5)
+        self.x_shape = (20, 3, 4)
+        self.y_shape = (4, 50)
 
     def setUp(self):
         self.set_npu()
@@ -114,9 +154,28 @@ class TestMul4(TestMul):
         self.outputs = {'Out': np.matmul(self.inputs['X'], self.inputs['Y'])}
 
 
+@skip_check_grad_ci(
+    reason="Don't support grad checking for NPU OP with FP16 data type.")
+class TestMul3FP16(TestMul3):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
 class TestMulNet(unittest.TestCase):
+    def init_dtype(self):
+        self.dtype = np.float32
+
     def _test(self, run_npu=True):
         main_prog = paddle.static.Program()
         startup_prog = paddle.static.Program()
@@ -124,17 +183,17 @@ class TestMulNet(unittest.TestCase):
         startup_prog.random_seed = SEED
         np.random.seed(SEED)
 
-        a_np = np.random.random(size=(2, 3)).astype('float32')
-        b_np = np.random.random(size=(2, 3)).astype('float32')
-        c_np = np.random.random(size=(3, 2)).astype('float32')
-        d_np = np.random.random(size=(3, 2)).astype('float32')
+        a_np = np.random.random(size=(2, 3)).astype(self.dtype)
+        b_np = np.random.random(size=(2, 3)).astype(self.dtype)
+        c_np = np.random.random(size=(3, 2)).astype(self.dtype)
+        d_np = np.random.random(size=(3, 2)).astype(self.dtype)
         label_np = np.random.randint(2, size=(2, 1)).astype('int64')
 
         with paddle.static.program_guard(main_prog, startup_prog):
-            a = paddle.static.data(name="a", shape=[2, 3], dtype='float32')
-            b = paddle.static.data(name="b", shape=[2, 3], dtype='float32')
-            c = paddle.static.data(name="c", shape=[3, 2], dtype='float32')
-            d = paddle.static.data(name="d", shape=[3, 2], dtype='float32')
+            a = paddle.static.data(name="a", shape=[2, 3], dtype=self.dtype)
+            b = paddle.static.data(name="b", shape=[2, 3], dtype=self.dtype)
+            c = paddle.static.data(name="c", shape=[3, 2], dtype=self.dtype)
+            d = paddle.static.data(name="d", shape=[3, 2], dtype=self.dtype)
             label = paddle.static.data(
                 name="label", shape=[2, 1], dtype='int64')
@@ -176,6 +235,7 @@ class TestMulNet(unittest.TestCase):
         return pred_res, loss_res
 
     def test_npu(self):
+        self.init_dtype()
         cpu_pred, cpu_loss = self._test(False)
         npu_pred, npu_loss = self._test(True)
 
@@ -186,6 +246,9 @@ class TestMulNet(unittest.TestCase):
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
 class TestMulNet3_2(unittest.TestCase):
+    def init_dtype(self):
+        self.dtype = np.float32
+
     def _test(self, run_npu=True):
         main_prog = paddle.static.Program()
         startup_prog = paddle.static.Program()
@@ -193,17 +256,17 @@ class TestMulNet3_2(unittest.TestCase):
         startup_prog.random_seed = SEED
         np.random.seed(SEED)
 
-        a_np = np.random.random(size=(2, 3, 4)).astype('float32')
-        b_np = np.random.random(size=(2, 3, 4)).astype('float32')
-        c_np = np.random.random(size=(12, 5)).astype('float32')
-        d_np = np.random.random(size=(12, 5)).astype('float32')
+        a_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
+        b_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
+        c_np = np.random.random(size=(12, 5)).astype(self.dtype)
+        d_np = np.random.random(size=(12, 5)).astype(self.dtype)
         label_np = np.random.randint(2, size=(2, 1)).astype('int64')
 
         with paddle.static.program_guard(main_prog, startup_prog):
-            a = paddle.static.data(name="a", shape=[2, 3, 4], dtype='float32')
-            b = paddle.static.data(name="b", shape=[2, 3, 4], dtype='float32')
-            c = paddle.static.data(name="c", shape=[12, 5], dtype='float32')
-            d = paddle.static.data(name="d", shape=[12, 5], dtype='float32')
+            a = paddle.static.data(name="a", shape=[2, 3, 4], dtype=self.dtype)
+            b = paddle.static.data(name="b", shape=[2, 3, 4], dtype=self.dtype)
+            c = paddle.static.data(name="c", shape=[12, 5], dtype=self.dtype)
+            d = paddle.static.data(name="d", shape=[12, 5], dtype=self.dtype)
             label = paddle.static.data(
                 name="label", shape=[2, 1], dtype='int64')
@@ -245,6 +308,7 @@ class TestMulNet3_2(unittest.TestCase):
         return pred_res, loss_res
 
     def test_npu(self):
+        self.init_dtype()
         cpu_pred, cpu_loss = self._test(False)
         npu_pred, npu_loss = self._test(True)
 
@@ -256,6 +320,9 @@ class TestMulNet3_2(unittest.TestCase):
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
 class TestMulNet3_2_xc2(unittest.TestCase):
+    def init_dtype(self):
+        self.dtype = np.float32
+
     def _test(self, run_npu=True):
         main_prog = paddle.static.Program()
         startup_prog = paddle.static.Program()
@@ -263,17 +330,17 @@ class TestMulNet3_2_xc2(unittest.TestCase):
         startup_prog.random_seed = SEED
         np.random.seed(SEED)
 
-        a_np = np.random.random(size=(2, 3, 4)).astype('float32')
-        b_np = np.random.random(size=(2, 3, 4)).astype('float32')
-        c_np = np.random.random(size=(4, 5)).astype('float32')
-        d_np = np.random.random(size=(4, 5)).astype('float32')
+        a_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
+        b_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
+        c_np = np.random.random(size=(4, 5)).astype(self.dtype)
+        d_np = np.random.random(size=(4, 5)).astype(self.dtype)
         label_np = np.random.randint(2, size=(2, 1)).astype('int64')
 
         with paddle.static.program_guard(main_prog, startup_prog):
-            a = paddle.static.data(name="a", shape=[2, 3, 4], dtype='float32')
-            b = paddle.static.data(name="b", shape=[2, 3, 4], dtype='float32')
-            c = paddle.static.data(name="c", shape=[4, 5], dtype='float32')
-            d = paddle.static.data(name="d", shape=[4, 5], dtype='float32')
+            a = paddle.static.data(name="a", shape=[2, 3, 4], dtype=self.dtype)
+            b = paddle.static.data(name="b", shape=[2, 3, 4], dtype=self.dtype)
+            c = paddle.static.data(name="c", shape=[4, 5], dtype=self.dtype)
+            d = paddle.static.data(name="d", shape=[4, 5], dtype=self.dtype)
             label = paddle.static.data(
                 name="label", shape=[2, 1], dtype='int64')
@@ -316,6 +383,7 @@ class TestMulNet3_2_xc2(unittest.TestCase):
         return pred_res, loss_res
 
    def test_npu(self):
+        self.init_dtype()
         cpu_pred, cpu_loss = self._test(False)
         npu_pred, npu_loss = self._test(True)
-- 
GitLab
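For reference, the shape bookkeeping the kernel does for the x_num_col_dims == 2 case can be checked with plain NumPy: flattening x to 2-D, running a single MatMul and viewing the result as 3-D (the path kept for non-FP16 inputs) matches a broadcast batched matmul (the BatchMatMul path this patch adds for FP16 inputs). The sketch below is illustrative only and is not part of the patch; it uses the [2, 3, 4] x [4, 5] shapes from the kernel comments.

import numpy as np

# Shapes taken from the comments in mul_op_npu.cc: x=[2, 3, 4], y=[4, 5].
x = np.random.random((2, 3, 4)).astype('float32')
y = np.random.random((4, 5)).astype('float32')

# Non-FP16 path: flatten x to [6, 4], MatMul to [6, 5], then view as [2, 3, 5].
flattened = x.reshape(6, 4)
out_matmul = flattened.dot(y).reshape(2, 3, 5)

# FP16 path: a broadcast batched matmul, [2, 3, 4] x [4, 5] -> [2, 3, 5].
out_batch = np.matmul(x, y)

# Both routes produce the same values, which is why the kernel can switch
# between MatMul and BatchMatMul depending on the input dtype.
np.testing.assert_allclose(out_matmul, out_batch, rtol=1e-5)

Because the two routes are numerically equivalent, the unit tests above can compare both the FP32 MatMul path and the FP16 BatchMatMul path against the same np.dot / np.matmul references.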