Unverified commit f91dfe15, authored by pangyoki, committed by GitHub

[NPU] optimize mul op, use BatchMatMul to realize (#33616)

* use BatchMatMul

* replace TensorCopy with ShareDataWith

* remove check fp16 grad

* fix format

* add grad_check

* fix grad check
Parent f88af205
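The kernel diff below handles mul with a 3-D x on NPU. As a rough NumPy reference (a sketch for illustration only, not part of the commit; shapes are taken from the comments in the kernel), the x_num_col_dims = 2 case can be computed either by flattening x and running a plain MatMul, or by letting a broadcasting BatchMatMul handle the 3-D input directly, which is the path this commit adds for FP16 inputs:

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)   # 3-D input [2, 3, 4]
y = np.random.rand(4, 5).astype(np.float32)      # 2-D weight [4, 5]

# MatMul path (non-FP16 branch): flatten x, multiply, reshape back.
out_flat = x.reshape(6, 4) @ y                   # [6, 4] x [4, 5] -> [6, 5]
out = out_flat.reshape(2, 3, 5)                  # back to [2, 3, 5]

# BatchMatMul path (FP16 branch): broadcast y against the leading dim of x.
out_bmm = np.matmul(x, y)                        # [2, 3, 4] x [4, 5] -> [2, 3, 5]

assert np.allclose(out, out_bmm)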
@@ -46,11 +46,7 @@ class MulNPUKernel : public framework::OpKernel<T> {
Tensor tmp_x(x->type());
int64_t sec_dim = x->dims()[1] * x->dims()[2];
int64_t first_dim = x->dims()[0];
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
tmp_x.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &tmp_x);
tmp_x.ShareDataWith(*x);
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
out->mutable_data<T>(ctx.GetPlace());
// matmul
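The TensorCopy-to-ShareDataWith change in the hunk above replaces an allocate-and-copy with a reshaped view of the same buffer. A minimal NumPy analogy (illustrative only, not Paddle code):

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

flat_copy = x.reshape(2, 12).copy()   # old path: TensorCopy into a fresh buffer
flat_view = x.reshape(2, 12)          # new path: ShareDataWith + Resize, same storage

flat_view[0, 0] = -1.0                # a write through the view is visible in x
assert x[0, 0, 0] == -1.0
assert np.shares_memory(flat_view, x)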
@@ -69,36 +65,39 @@ class MulNPUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"now only support x_num_col_dims == 2: but got %d",
x_num_col_dims));
if (x->type() == framework::proto::VarType::FP16 &&
y->type() == framework::proto::VarType::FP16) {
// NOTE: When the dims of the input and output shapes are inconsistent,
// (Broadcast) BatchMatMul NPU OP only supports FP16.
out->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("BatchMatMul", {*x, *y}, {*out},
{{"adj_x1", false}, {"adj_x2", false}});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
} else {
// flatten => x.shape=[6, 4]
Tensor tmp_x(x->type());
int64_t first_dim = x->dims()[0] * x->dims()[1];
int64_t sec_dim = x->dims()[2];
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
tmp_x.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &tmp_x);
tmp_x.ShareDataWith(*x);
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
// matmul [6, 4], [4, 5] => [6, 5]
Tensor tmp_matmul(x->type());
tmp_matmul.Resize(framework::make_ddim({first_dim, y->dims()[1]}));
tmp_matmul.mutable_data<T>(ctx.GetPlace());
out->mutable_data<T>(ctx.GetPlace());
Tensor tmp_out(x->type());
tmp_out.ShareDataWith(*out);
tmp_out.Resize(framework::make_ddim({first_dim, y->dims()[1]}));
const auto& runner_matmul =
NpuOpRunner("MatMul", {tmp_x, *y}, {tmp_matmul},
NpuOpRunner("MatMul", {tmp_x, *y}, {tmp_out},
{{"transpose_x1", false}, {"transpose_x2", false}});
runner_matmul.Run(stream);
// reshape [6, 5] => [2, 3, 5]
(*out).Resize(
framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
out->mutable_data(ctx.GetPlace(), x->type());
framework::TensorCopy(
tmp_matmul, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
(*out).Resize(
framework::make_ddim({x->dims()[0], x->dims()[1], y->dims()[1]}));
}
}
}
};
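In the non-FP16 branch above, the temporary tmp_matmul and the TensorCopy back into out are replaced by tmp_out, a 2-D alias of out, so MatMul writes its result directly into out's storage and only the shape is adjusted afterwards. A NumPy sketch of the same idea (illustrative only, not Paddle code):

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)
y = np.random.rand(4, 5).astype(np.float32)

out = np.empty((2, 3, 5), dtype=np.float32)
tmp_out = out.reshape(6, 5)                  # 2-D view sharing out's storage
np.matmul(x.reshape(6, 4), y, out=tmp_out)   # the product lands directly in out

assert np.allclose(out, np.matmul(x, y))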
@@ -142,14 +141,14 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
if (dx) {
// matmul [2, 5] * [12, 5] => [2, 12]
dx->mutable_data<T>(ctx.GetPlace());
auto dx_dims = dx->dims();
dx->Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
Tensor tmp_dx(x->type());
tmp_dx.ShareDataWith(*dx);
tmp_dx.Resize(framework::make_ddim({dout->dims()[0], y->dims()[0]}));
const auto& runner_matmul =
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
NpuOpRunner("MatMul", {*dout, *y}, {tmp_dx},
{{"transpose_x1", false}, {"transpose_x2", true}});
runner_matmul.Run(stream);
// reshape [2, 12] => [2, 3, 4]
dx->Resize(dx_dims);
}
if (dy) {
@@ -157,11 +156,7 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
Tensor tmp_x(x->type());
int64_t sec_dim = x->dims()[1] * x->dims()[2];
int64_t first_dim = x->dims()[0];
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
tmp_x.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &tmp_x);
tmp_x.ShareDataWith(*x);
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
dy->mutable_data<T>(ctx.GetPlace());
const auto& runner_dy =
@@ -181,35 +176,42 @@ class MulGradNPUKernel : public framework::OpKernel<T> {
Tensor tmp_dout(x->type());
int64_t dout_first_dim = dout->dims()[0] * dout->dims()[1];
int64_t dout_sec_dim = dout->dims()[2];
tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));
tmp_dout.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &tmp_dout);
tmp_dout.ShareDataWith(*dout);
tmp_dout.Resize(framework::make_ddim({dout_first_dim, dout_sec_dim}));
if (dx) {
// tmp_dout * y [6,5] * [4,5] => [6, 4]
// tmp_dout * y [2, 3, 5] * [4,5] => [2, 3, 4]
if (dout->type() == framework::proto::VarType::FP16 &&
y->type() == framework::proto::VarType::FP16) {
// NOTE: When the dims of the input and output shapes are inconsistent,
// (Broadcast) BatchMatMul NPU OP only supports FP16.
dx->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx},
{{"adj_x1", false}, {"adj_x2", true}});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
} else {
dx->mutable_data<T>(ctx.GetPlace());
auto dx_dims = dx->dims();
dx->Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
Tensor tmp_dx(x->type());
tmp_dx.ShareDataWith(*dx);
tmp_dx.Resize(framework::make_ddim({dout_first_dim, y->dims()[0]}));
const auto& runner_matmul =
NpuOpRunner("MatMul", {tmp_dout, *y}, {*dx},
NpuOpRunner("MatMul", {tmp_dout, *y}, {tmp_dx},
{{"transpose_x1", false}, {"transpose_x2", true}});
runner_matmul.Run(stream);
// reshape [2, 12] => [2, 3, 4]
dx->Resize(dx_dims);
}
}
if (dy) {
// flatten x.shape [2,3,4] => [6, 4]
Tensor tmp_x(x->type());
int64_t first_dim = x->dims()[0] * x->dims()[1];
int64_t sec_dim = x->dims()[2];
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
tmp_x.mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &tmp_x);
tmp_x.ShareDataWith(*x);
tmp_x.Resize(framework::make_ddim({first_dim, sec_dim}));
// matmul [6, 4]^T, [6, 5] => [4, 5]
dy->mutable_data<T>(ctx.GetPlace());
......
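For reference, with out = reshape(flat(x) @ y) in the x_num_col_dims = 2 case, the grad kernel above realizes dx = reshape(flat(dout) @ y^T) and dy = flat(x)^T @ flat(dout). A NumPy sketch of these formulas (illustrative only; dout is a stand-in upstream gradient, not taken from the commit):

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)
y = np.random.rand(4, 5).astype(np.float32)
dout = np.random.rand(2, 3, 5).astype(np.float32)   # assumed upstream gradient of out

dx = (dout.reshape(6, 5) @ y.T).reshape(2, 3, 4)    # matches MatMul with transpose_x2 = true
dy = x.reshape(6, 4).T @ dout.reshape(6, 5)         # flat(x)^T @ flat(dout) -> [4, 5]

assert dx.shape == x.shape and dy.shape == y.shape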
@@ -18,7 +18,7 @@ import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
@@ -27,6 +27,7 @@ SEED = 2021
class TestMul(OpTest):
# case 1: (32, 5) * (5, 100) -> (32, 100)
def config(self):
self.x_shape = (32, 5)
self.y_shape = (5, 100)
@@ -46,7 +47,6 @@ class TestMul(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float32
@@ -54,25 +54,51 @@ class TestMul(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-5)
#
def test_check_grad_normal(self):
self.check_grad_with_place(
self.place, ['X', 'Y'],
'Out',
max_relative_error=0.0065,
check_dygraph=False)
def test_check_grad_ingore_x(self):
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
max_relative_error=0.0065,
check_dygraph=False)
def test_check_grad_ingore_y(self):
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
max_relative_error=0.0065,
check_dygraph=False)
@skip_check_grad_ci(
reason="Don't support grad checking for NPU OP with FP16 data type.")
class TestMulFP16(TestMul):
"""
case 2
"""
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
pass
class TestMul3(TestMul):
"""
case 3
"""
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestMul2(TestMul):
# case 2: (20, 2, 5) * (10, 50) -> (20, 50), x_num_col_dims = 1
def config(self):
self.x_shape = (2, 2, 5)
self.y_shape = (10, 5)
self.x_shape = (20, 2, 5)
self.y_shape = (10, 50)
def setUp(self):
self.set_npu()
@@ -86,18 +112,32 @@ class TestMul3(TestMul):
'Y': np.random.random(self.y_shape).astype(self.dtype)
}
self.outputs = {
'Out': np.dot(self.inputs['X'].reshape(2, 10), self.inputs['Y'])
'Out': np.dot(self.inputs['X'].reshape(20, 10), self.inputs['Y'])
}
class TestMul4(TestMul):
"""
case 4
"""
@skip_check_grad_ci(
reason="Don't support grad checking for NPU OP with FP16 data type.")
class TestMul2FP16(TestMul2):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestMul3(TestMul):
# case 3: (20, 3, 4) * (4, 50) -> (20, 3, 50), x_num_col_dims = 2
def config(self):
self.x_shape = (2, 3, 4)
self.y_shape = (4, 5)
self.x_shape = (20, 3, 4)
self.y_shape = (4, 50)
def setUp(self):
self.set_npu()
@@ -114,9 +154,28 @@ class TestMul4(TestMul):
self.outputs = {'Out': np.matmul(self.inputs['X'], self.inputs['Y'])}
@skip_check_grad_ci(
reason="Don't support grad checking for NPU OP with FP16 data type.")
class TestMul3FP16(TestMul3):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestMulNet(unittest.TestCase):
def init_dtype(self):
self.dtype = np.float32
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
@@ -124,17 +183,17 @@ class TestMulNet(unittest.TestCase):
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(2, 3)).astype('float32')
b_np = np.random.random(size=(2, 3)).astype('float32')
c_np = np.random.random(size=(3, 2)).astype('float32')
d_np = np.random.random(size=(3, 2)).astype('float32')
a_np = np.random.random(size=(2, 3)).astype(self.dtype)
b_np = np.random.random(size=(2, 3)).astype(self.dtype)
c_np = np.random.random(size=(3, 2)).astype(self.dtype)
d_np = np.random.random(size=(3, 2)).astype(self.dtype)
label_np = np.random.randint(2, size=(2, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[2, 3], dtype='float32')
b = paddle.static.data(name="b", shape=[2, 3], dtype='float32')
c = paddle.static.data(name="c", shape=[3, 2], dtype='float32')
d = paddle.static.data(name="d", shape=[3, 2], dtype='float32')
a = paddle.static.data(name="a", shape=[2, 3], dtype=self.dtype)
b = paddle.static.data(name="b", shape=[2, 3], dtype=self.dtype)
c = paddle.static.data(name="c", shape=[3, 2], dtype=self.dtype)
d = paddle.static.data(name="d", shape=[3, 2], dtype=self.dtype)
label = paddle.static.data(
name="label", shape=[2, 1], dtype='int64')
@@ -176,6 +235,7 @@ class TestMulNet(unittest.TestCase):
return pred_res, loss_res
def test_npu(self):
self.init_dtype()
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
@@ -186,6 +246,9 @@ class TestMulNet(unittest.TestCase):
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestMulNet3_2(unittest.TestCase):
def init_dtype(self):
self.dtype = np.float32
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
@@ -193,17 +256,17 @@ class TestMulNet3_2(unittest.TestCase):
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(2, 3, 4)).astype('float32')
b_np = np.random.random(size=(2, 3, 4)).astype('float32')
c_np = np.random.random(size=(12, 5)).astype('float32')
d_np = np.random.random(size=(12, 5)).astype('float32')
a_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
b_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
c_np = np.random.random(size=(12, 5)).astype(self.dtype)
d_np = np.random.random(size=(12, 5)).astype(self.dtype)
label_np = np.random.randint(2, size=(2, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[2, 3, 4], dtype='float32')
b = paddle.static.data(name="b", shape=[2, 3, 4], dtype='float32')
c = paddle.static.data(name="c", shape=[12, 5], dtype='float32')
d = paddle.static.data(name="d", shape=[12, 5], dtype='float32')
a = paddle.static.data(name="a", shape=[2, 3, 4], dtype=self.dtype)
b = paddle.static.data(name="b", shape=[2, 3, 4], dtype=self.dtype)
c = paddle.static.data(name="c", shape=[12, 5], dtype=self.dtype)
d = paddle.static.data(name="d", shape=[12, 5], dtype=self.dtype)
label = paddle.static.data(
name="label", shape=[2, 1], dtype='int64')
@@ -245,6 +308,7 @@ class TestMulNet3_2(unittest.TestCase):
return pred_res, loss_res
def test_npu(self):
self.init_dtype()
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
@@ -256,6 +320,9 @@ class TestMulNet3_2(unittest.TestCase):
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestMulNet3_2_xc2(unittest.TestCase):
def init_dtype(self):
self.dtype = np.float32
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
@@ -263,17 +330,17 @@ class TestMulNet3_2_xc2(unittest.TestCase):
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(2, 3, 4)).astype('float32')
b_np = np.random.random(size=(2, 3, 4)).astype('float32')
c_np = np.random.random(size=(4, 5)).astype('float32')
d_np = np.random.random(size=(4, 5)).astype('float32')
a_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
b_np = np.random.random(size=(2, 3, 4)).astype(self.dtype)
c_np = np.random.random(size=(4, 5)).astype(self.dtype)
d_np = np.random.random(size=(4, 5)).astype(self.dtype)
label_np = np.random.randint(2, size=(2, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[2, 3, 4], dtype='float32')
b = paddle.static.data(name="b", shape=[2, 3, 4], dtype='float32')
c = paddle.static.data(name="c", shape=[4, 5], dtype='float32')
d = paddle.static.data(name="d", shape=[4, 5], dtype='float32')
a = paddle.static.data(name="a", shape=[2, 3, 4], dtype=self.dtype)
b = paddle.static.data(name="b", shape=[2, 3, 4], dtype=self.dtype)
c = paddle.static.data(name="c", shape=[4, 5], dtype=self.dtype)
d = paddle.static.data(name="d", shape=[4, 5], dtype=self.dtype)
label = paddle.static.data(
name="label", shape=[2, 1], dtype='int64')
@@ -316,6 +383,7 @@ class TestMulNet3_2_xc2(unittest.TestCase):
return pred_res, loss_res
def test_npu(self):
self.init_dtype()
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
......