From 3e088aafe4a0c5c169cfa2cc21cc047cfc320081 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:49:53 +0800 Subject: [PATCH] [NPU] add int64 support for argsort op (#37434) * [NPU] add int64 support for argsort op * [NPU] delete debug codes --- paddle/fluid/operators/argsort_op_npu.cc | 129 +++++++++++++----- .../unittests/npu/test_argsort_op_npu.py | 82 +++++++++++ 2 files changed, 179 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/operators/argsort_op_npu.cc b/paddle/fluid/operators/argsort_op_npu.cc index f2a57b4b9bd..477b9363ebe 100644 --- a/paddle/fluid/operators/argsort_op_npu.cc +++ b/paddle/fluid/operators/argsort_op_npu.cc @@ -46,6 +46,18 @@ static void CastToInt64(const framework::ExecutionContext& ctx, .Run(stream); } +static void CastToFP32(const framework::ExecutionContext& ctx, + const aclrtStream& stream, const Tensor& in, + Tensor* out) { + out->mutable_data(ctx.GetPlace()); + NpuOpRunner runner; + runner.SetType("Cast") + .AddInput(in) + .AddOutput(*out) + .AddAttr("dst_type", ACL_FLOAT) + .Run(stream); +} + template class ArgsortNPUKernel : public framework::OpKernel { public: @@ -66,41 +78,91 @@ class ArgsortNPUKernel : public framework::OpKernel { Tensor indices_tmp(framework::proto::VarType::INT32); indices_tmp.Resize(indices->dims()); - if (axis == -1 || axis + 1 == in_dims.size()) { - output->mutable_data(ctx.GetPlace()); - indices_tmp.mutable_data(ctx.GetPlace()); - const auto& runner = - NpuOpRunner("Sort", {*input}, {*output, indices_tmp}, attr); - runner.Run(stream); - } else { - std::vector perm; - for (int64_t i = 0; i < in_dims.size(); i++) { - perm.emplace_back(i); + if (input->type() == framework::proto::VarType::INT64) { + Tensor input_fp32(framework::proto::VarType::FP32); + input_fp32.Resize(input->dims()); + CastToFP32(ctx, stream, *input, &input_fp32); + + Tensor output_fp32(framework::proto::VarType::FP32); + output_fp32.Resize(output->dims()); + + if (axis == -1 || axis + 1 == in_dims.size()) { + output_fp32.mutable_data(ctx.GetPlace()); + indices_tmp.mutable_data(ctx.GetPlace()); + const auto& runner = + NpuOpRunner("Sort", {input_fp32}, {output_fp32, indices_tmp}, attr); + runner.Run(stream); + + CastToInt64(ctx, stream, output_fp32, output); + } else { + std::vector perm; + for (int64_t i = 0; i < in_dims.size(); i++) { + perm.emplace_back(i); + } + std::swap(perm[axis], perm[in_dims.size() - 1]); + + std::vector shape; + for (size_t i = 0; i < perm.size(); i++) { + shape.emplace_back(in_dims[perm[i]]); + } + auto trans_dims = framework::make_ddim(shape); + + Tensor trans_input(input_fp32.type()); + trans_input.Resize(trans_dims); + TranposeNPU(ctx, stream, &perm, input_fp32, &trans_input); + + Tensor trans_output(input_fp32.type()); + Tensor trans_indices(framework::proto::VarType::INT32); + trans_output.mutable_data(trans_dims, ctx.GetPlace()); + trans_indices.mutable_data(trans_dims, ctx.GetPlace()); + + const auto& runner = NpuOpRunner("Sort", {trans_input}, + {trans_output, trans_indices}, attr); + runner.Run(stream); + + TranposeNPU(ctx, stream, &perm, trans_output, &output_fp32); + TranposeNPU(ctx, stream, &perm, trans_indices, &indices_tmp); + + CastToInt64(ctx, stream, output_fp32, output); } - std::swap(perm[axis], perm[in_dims.size() - 1]); - - std::vector shape; - for (size_t i = 0; i < perm.size(); i++) { - shape.emplace_back(in_dims[perm[i]]); + } else { + if (axis == -1 || axis + 1 == in_dims.size()) { + output->mutable_data(ctx.GetPlace()); + indices_tmp.mutable_data(ctx.GetPlace()); + const auto& runner = + NpuOpRunner("Sort", {*input}, {*output, indices_tmp}, attr); + runner.Run(stream); + } else { + std::vector perm; + for (int64_t i = 0; i < in_dims.size(); i++) { + perm.emplace_back(i); + } + std::swap(perm[axis], perm[in_dims.size() - 1]); + + std::vector shape; + for (size_t i = 0; i < perm.size(); i++) { + shape.emplace_back(in_dims[perm[i]]); + } + auto trans_dims = framework::make_ddim(shape); + + Tensor trans_input(input->type()); + trans_input.Resize(trans_dims); + TranposeNPU(ctx, stream, &perm, *input, &trans_input); + + Tensor trans_output(input->type()); + Tensor trans_indices(framework::proto::VarType::INT32); + trans_output.mutable_data(trans_dims, ctx.GetPlace()); + trans_indices.mutable_data(trans_dims, ctx.GetPlace()); + + const auto& runner = NpuOpRunner("Sort", {trans_input}, + {trans_output, trans_indices}, attr); + runner.Run(stream); + + TranposeNPU(ctx, stream, &perm, trans_output, output); + TranposeNPU(ctx, stream, &perm, trans_indices, &indices_tmp); } - auto trans_dims = framework::make_ddim(shape); - - Tensor trans_input(input->type()); - trans_input.Resize(trans_dims); - TranposeNPU(ctx, stream, &perm, *input, &trans_input); - - Tensor trans_output(input->type()); - Tensor trans_indices(framework::proto::VarType::INT32); - trans_output.mutable_data(trans_dims, ctx.GetPlace()); - trans_indices.mutable_data(trans_dims, ctx.GetPlace()); - - const auto& runner = NpuOpRunner("Sort", {trans_input}, - {trans_output, trans_indices}, attr); - runner.Run(stream); - - TranposeNPU(ctx, stream, &perm, trans_output, output); - TranposeNPU(ctx, stream, &perm, trans_indices, &indices_tmp); } + CastToInt64(ctx, stream, indices_tmp, indices); } }; @@ -208,6 +270,9 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL(argsort, ops::ArgsortNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::ArgsortNPUKernel, +#endif ops::ArgsortNPUKernel); REGISTER_OP_NPU_KERNEL(argsort_grad, ops::ArgsortGradNPUKernel, diff --git a/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py index 2589b2a316a..ebabea93dd0 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py @@ -209,5 +209,87 @@ class TestArgsortOpDescendingAxisNeg2NPUFP32(TestArgsortOpAxisNeg2NPUFP32): self.descending = True +# test cases for int64 +class TestArgsortOpAxis0NPUINT64(TestArgsortOp): + def setUp(self): + self.set_npu() + self.op_type = "argsort" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.init_inputshape() + self.init_axis() + self.init_direction() + + self.x = np.random.randint( + low=-100, high=100, size=self.input_shape, + dtype=self.dtype).astype(self.dtype) + self.inputs = {"X": self.x} + self.attrs = {"axis": self.axis, "descending": self.descending} + self.get_output() + self.outputs = {"Out": self.sorted_x, "Indices": self.indices} + + def init_axis(self): + self.axis = 0 + + def init_dtype(self): + self.dtype = np.int64 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-2) + + def set_npu(self): + self.__class__.use_npu = True + + +class TestArgsortOpAxis1NPUINT64(TestArgsortOpAxis0NPUINT64): + def init_axis(self): + self.axis = 1 + + +class TestArgsortOpAxis2NPUINT64(TestArgsortOpAxis0NPUINT64): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpAxisNeg1NPUINT64(TestArgsortOpAxis0NPUINT64): + def init_axis(self): + self.axis = -1 + + +class TestArgsortOpAxisNeg2NPUINT64(TestArgsortOpAxis0NPUINT64): + def init_axis(self): + self.axis = -2 + + +class TestArgsortOpDescendingAxisNPUINT64(TestArgsortOpAxis0NPUINT64): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis0NPUINT64(TestArgsortOpAxis0NPUINT64): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis1NPUINT64(TestArgsortOpAxis1NPUINT64): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis2NPUINT64(TestArgsortOpAxis2NPUINT64): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg1NPUINT64(TestArgsortOpAxisNeg1NPUINT64): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg2NPUINT64(TestArgsortOpAxisNeg2NPUINT64): + def init_direction(self): + self.descending = True + + if __name__ == '__main__': unittest.main() -- GitLab