Unverified · Commit a66b9fba authored by Aganlengzi, committed by GitHub

[NPU] modify transpose2 and index_select_grad kernels for model xlnet (#36214)

* [NPU] modify transpose2 and index_select_grad kernels for model xlnet

* add transpose2 int64_t unit test

* add more transpose2 unit tests

* update test_transpose_op_npu.py
Parent 5e0f199a
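For readers skimming the hunks: the kernels previously invoked the TransposeD NPU op, passing the permutation through a "perm" attribute in an NPUAttributeMap; they now invoke Transpose and pass the permutation as a second runtime input via NpuOpRunner's builder-style API. Below is a minimal before/after sketch of that call-shape change; the helper name TransposeSketch, the include path, and the aclrtStream parameter are illustrative assumptions, not part of the diff.

// A minimal sketch, assuming the Paddle NpuOpRunner API shown in the
// hunks below. TransposeSketch, the include path, and the aclrtStream
// type are illustrative, not taken from the commit.
#include <utility>
#include <vector>

#include "paddle/fluid/operators/npu_op_runner.h"  // assumed header location

namespace paddle {
namespace operators {

void TransposeSketch(const framework::Tensor& x, std::vector<int> axis,
                     framework::Tensor* out, aclrtStream stream) {
  // Old pattern (removed by this commit): "perm" rides along as an
  // operator attribute of TransposeD.
  //
  //   framework::NPUAttributeMap attr = {{"perm", axis}};
  //   const auto& runner = NpuOpRunner("TransposeD", {x}, {*out}, attr);
  //   runner.Run(stream);

  // New pattern: "perm" becomes a second runtime input of Transpose,
  // assembled with the builder-style SetType/AddInput/AddOutput calls.
  NpuOpRunner runner;
  runner.SetType("Transpose")
      .AddInput(x)                // data tensor
      .AddInput(std::move(axis))  // permutation passed as an input
      .AddOutput(*out);
  runner.Run(stream);
}

}  // namespace operators
}  // namespace paddle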
@@ -99,10 +99,11 @@ class IndexSelectGradNPUKernel : public framework::OpKernel<T> {
       transed_out_dims[i] = out_dims[in_trans_perm[i]];
     }
     transed_out_grad.mutable_data<T>(transed_out_dims, ctx.GetPlace());
-    framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}};
-    const auto& in_trans_runner = NpuOpRunner(
-        "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr);
+    NpuOpRunner in_trans_runner;
+    in_trans_runner.SetType("Transpose")
+        .AddInput(*out_grad)
+        .AddInput(std::move(in_trans_perm))
+        .AddOutput(transed_out_grad);
     in_trans_runner.Run(stream);
 
     Tensor sum_out;
@@ -133,10 +134,12 @@ class IndexSelectGradNPUKernel : public framework::OpKernel<T> {
       for (int i = 1 + dim; i < x_dims.size(); ++i) {
         out_trans_perm.push_back(i);
       }
-      framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}};
       x_grad->mutable_data<T>(ctx.GetPlace());
-      const auto& out_trans_runner =
-          NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr);
+      NpuOpRunner out_trans_runner;
+      out_trans_runner.SetType("Transpose")
+          .AddInput(sum_out)
+          .AddInput(std::move(out_trans_perm))
+          .AddOutput(*x_grad);
       out_trans_runner.Run(stream);
     }
   }
@@ -27,9 +27,12 @@ class TransposeNPUKernel : public framework::OpKernel<T> {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* out = ctx.Output<framework::LoDTensor>("Out");
     std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
-    framework::NPUAttributeMap attr_input = {{"perm", axis}};
     out->mutable_data<T>(ctx.device_context().GetPlace());
-    const auto& runner = NpuOpRunner("TransposeD", {*x}, {*out}, attr_input);
+    NpuOpRunner runner;
+    runner.SetType("Transpose")
+        .AddInput(*x)
+        .AddInput(std::move(axis))
+        .AddOutput(*out);
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -51,9 +54,11 @@ class TransposeGradNPUKernel : public framework::OpKernel<T> {
       reversed_axis[axis[i]] = i;
     }
     x_grad->mutable_data<T>(ctx.GetPlace());
-    framework::NPUAttributeMap attr_input = {{"perm", reversed_axis}};
-    const auto& runner =
-        NpuOpRunner("TransposeD", {*out_grad}, {*x_grad}, attr_input);
+    NpuOpRunner runner;
+    runner.SetType("Transpose")
+        .AddInput(*out_grad)
+        .AddInput(std::move(reversed_axis))
+        .AddOutput(*x_grad);
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -72,11 +77,17 @@ REGISTER_OP_NPU_KERNEL(
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>,
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+#endif
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>,
     ops::TransposeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>);
 
 REGISTER_OP_NPU_KERNEL(transpose2_grad, ops::TransposeGradNPUKernel<float>,
                        ops::TransposeGradNPUKernel<paddle::platform::float16>,
                        ops::TransposeGradNPUKernel<int>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+                       ops::TransposeGradNPUKernel<int64_t>,
+#endif
                        ops::TransposeGradNPUKernel<uint8_t>,
                        ops::TransposeGradNPUKernel<int8_t>);
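Note that the new int64_t registrations are wrapped in #ifdef PADDLE_WITH_ASCEND_INT64, so they are compiled only when the build opts into Ascend int64 support; other builds keep the previous kernel set. A generic sketch of that conditional-registration pattern follows; the macro and kernel names come from the hunk above, but the trimmed dtype list and surrounding layout are illustrative, not a literal copy of the file.

// Sketch: dtype registration gated on a build flag, mirroring the hunk
// above. The dtype list here is trimmed for brevity.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    transpose2,
    ops::TransposeNPUKernel<plat::NPUDeviceContext, float>,
    ops::TransposeNPUKernel<plat::NPUDeviceContext, int>,
#ifdef PADDLE_WITH_ASCEND_INT64
    // Compiled in only when the build defines PADDLE_WITH_ASCEND_INT64.
    ops::TransposeNPUKernel<plat::NPUDeviceContext, int64_t>,
#endif
    ops::TransposeNPUKernel<plat::NPUDeviceContext, int8_t>);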
@@ -31,40 +31,104 @@ class TestTransposeOp(OpTest):
         self.op_type = "transpose2"
         self.place = paddle.NPUPlace(0)
         self.init_dtype()
-        self.init_input_output()
-        self.init_kernel_type()
-        self.init_axis()
+        self.init_shape_axis()
 
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)}
-        self.attrs = {'axis': [0, 2, 1, 3], 'data_format': 'AnyLayout'}
-        self.outputs = {'Out': self.out}
+        self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
+        self.attrs = {'axis': self.axis, 'data_format': 'AnyLayout'}
+        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}
 
     def set_npu(self):
         self.__class__.use_npu = True
 
-    def init_kernel_type(self):
-        self.use_mkldnn = False
-
-    def init_input_output(self):
-        self.x = np.random.uniform(0.1, 1, [8, 512, 12, 64]).astype(self.dtype)
-        self.out = np.transpose(self.x, [0, 2, 1, 3])
-
     def init_dtype(self):
         self.dtype = np.float32
 
-    def init_axis(self):
-        self.axis = -1
+    def init_shape_axis(self):
+        self.shape = (3, 40)
+        self.axis = (1, 0)
 
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
     def test_check_grad(self):
         self.check_grad_with_place(self.place, ['X'], 'Out')
 
 
+class TestCase0(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (100, )
+        self.axis = (0, )
+
+
+class TestCase1(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (3, 4, 10)
+        self.axis = (0, 2, 1)
+
+
+class TestCase2(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 3, 4, 5)
+        self.axis = (0, 2, 3, 1)
+
+
+class TestCase3(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.axis = (4, 2, 3, 1, 0)
+
+
+class TestCase4(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 3, 4, 5, 6, 1)
+        self.axis = (4, 2, 3, 1, 0, 5)
+
+
+class TestCase5(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 16, 96)
+        self.axis = (0, 2, 1)
+
+
-class TestTransposeOpFP16(TestTransposeOp):
-    no_need_check_grad = True
+class TestCase6(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 10, 12, 16)
+        self.axis = (3, 1, 2, 0)
+
+
+class TestCase7(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 10, 2, 16)
+        self.axis = (0, 1, 3, 2)
+
+
+class TestCase8(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
+        self.axis = (0, 1, 3, 2, 4, 5, 6, 7)
+
+
+class TestCase9(TestTransposeOp):
+    def init_shape_axis(self):
+        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
+        self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
+
+
+class TestTransposeOpFP16(TestTransposeOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad(self):
+        pass
+
+
+class TestTransposeOpInt64(TestTransposeOp):
+    def init_dtype(self):
+        self.dtype = np.int64
+
+    def test_check_grad(self):
+        pass
+
+
 if __name__ == '__main__':
     unittest.main()