diff --git a/paddle/fluid/operators/index_select_op_npu.cc b/paddle/fluid/operators/index_select_op_npu.cc index 8df6c4e5d9ea7203dee3958545c55a33899ae231..b624d03cc8555938ee6f527b890c0575d59799e3 100644 --- a/paddle/fluid/operators/index_select_op_npu.cc +++ b/paddle/fluid/operators/index_select_op_npu.cc @@ -21,12 +21,12 @@ namespace operators { template class IndexSelectNPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* index = ctx.Input("Index"); auto dim = ctx.Attr("dim"); - auto *out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto stream = @@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel { } }; -// todo: add class 'IndexSelectGradNPUKernel' here. +template +class IndexSelectGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* index = ctx.Input("Index"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + + auto stream = + ctx.template device_context() + .stream(); + + auto x_dims = x_grad->dims(); + auto out_dims = out_grad->dims(); + + int dim = ctx.Attr("dim"); + if (dim < 0) { + dim += out_dims.size(); + } + + Tensor casted_index; + if (index->type() != framework::proto::VarType::INT32) { + casted_index.mutable_data(index->dims(), ctx.GetPlace()); + const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index}, + {{"dst_type", ACL_INT32}}); + cast_runner.Run(stream); + } else { + casted_index.ShareDataWith(*index); + } + + if (dim == 0) { + x_grad->mutable_data(ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(*out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(*x_grad); + runner.Run(stream); + } else { + Tensor transed_out_grad; + std::vector in_trans_perm; + in_trans_perm.push_back(dim); + for (int i = 0; i < out_dims.size(); ++i) { + if (i == dim) continue; + in_trans_perm.push_back(i); + } + framework::DDim transed_out_dims(out_dims); + for (size_t i = 0; i < in_trans_perm.size(); ++i) { + transed_out_dims[i] = out_dims[in_trans_perm[i]]; + } + transed_out_grad.mutable_data(transed_out_dims, ctx.GetPlace()); + framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}}; + + const auto& in_trans_runner = NpuOpRunner( + "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr); + in_trans_runner.Run(stream); + + Tensor sum_out; + framework::DDim sum_dims(x_dims); + sum_dims[0] = x_dims[dim]; + auto idx = 1; + for (int i = 0; i < x_dims.size(); ++i) { + if (i == dim) continue; + sum_dims[idx++] = x_dims[i]; + } + sum_out.mutable_data(sum_dims, ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {sum_out}, {sum_out}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(transed_out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(sum_out); + runner.Run(stream); + + std::vector out_trans_perm; + for (int i = 1; i < 1 + dim; ++i) { + out_trans_perm.push_back(i); + } + out_trans_perm.push_back(0); + for (int i = 1 + dim; i < x_dims.size(); ++i) { + out_trans_perm.push_back(i); + } + framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}}; + x_grad->mutable_data(ctx.GetPlace()); + const auto& out_trans_runner = + NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr); + out_trans_runner.Run(stream); + } + } +}; } // namespace operators } // namespace paddle @@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL( ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel); -// todo: register npu index_select_grad kernel here. +REGISTER_OP_NPU_KERNEL( + index_select_grad, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py index ff0d57d1d4da1028d0db28ee90f6a950ce33b9ea..57293ad5e56335aeb04949177b632e1e6763fefe 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py @@ -35,7 +35,10 @@ class TestNPUIndexSelect(OpTest): x_np = np.random.random(self.x_shape).astype(self.x_type) index_np = np.random.randint( - low=0, high=self.x_shape[self.dim], size=self.index_size) + low=0, + high=self.x_shape[self.dim], + size=self.index_size, + dtype=self.index_type) # compute real output as baseline. outer_loop = np.prod(self.x_shape[:self.dim]) @@ -56,18 +59,14 @@ class TestNPUIndexSelect(OpTest): self.attrs = {'dim': self.dim} self.outputs = {'Out': out} - # todo: comment second line when index_select grad npu op is ready. def set_npu(self): self.__class__.use_npu = True - self.__class__.no_need_check_grad = True def test_check_output(self): self.check_output_with_place(self.place) - # todo: replace first line with second line when index_select grad npu op is ready. def test_check_grad(self): - pass - #self.check_grad_with_place(self.place, ['X'], 'Out') + self.check_grad_with_place(self.place, ['X'], 'Out') def config(self): self.x_shape = (100, 4, 5) @@ -86,6 +85,24 @@ class TestNPUIndexSelectCase2(TestNPUIndexSelect): self.index_size = 10 +class TestNPUIndexSelectCase3(TestNPUIndexSelect): + def config(self): + self.dim = 0 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + +class TestNPUIndexSelectCase4(TestNPUIndexSelect): + def config(self): + self.dim = -1 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + class TestNPUIndexSelectAPI(unittest.TestCase): def input_data(self): self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],