diff --git a/paddle/fluid/operators/argsort_op_npu.cc b/paddle/fluid/operators/argsort_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..e36dd322e0ea1d1f018564473dd9a3f6b7a7734c --- /dev/null +++ b/paddle/fluid/operators/argsort_op_npu.cc @@ -0,0 +1,261 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/argsort_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class ArgsortNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + output->mutable_data(ctx.GetPlace()); + auto* indices = ctx.Output("Indices"); + indices->mutable_data(ctx.GetPlace()); + + int32_t axis = ctx.Attr("axis"); + auto in_dims = indices->dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + bool descending = ctx.Attr("descending"); + auto stream = + ctx.template device_context() + .stream(); + framework::NPUAttributeMap sort_attr_input = { + {"axis", static_cast(-1)}, {"descending", descending}}; + + if (axis == -1 || axis + 1 == in_dims.size()) { + const auto& sort_runner = + NpuOpRunner("Sort", {*input}, {*output, *indices}, sort_attr_input); + sort_runner.Run(stream); + } else { + // transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + framework::NPUAttributeMap trans_attr_input = {{"perm", trans}}; + Tensor trans_input; + trans_input.mutable_data(trans_dims, ctx.GetPlace()); + const auto& trans_input_runner = + NpuOpRunner("TransposeD", {*input}, {trans_input}, trans_attr_input); + trans_input_runner.Run(stream); + Tensor trans_indices; + trans_indices.mutable_data(trans_dims, ctx.GetPlace()); + const auto& trans_indice_runner = NpuOpRunner( + "TransposeD", {*indices}, {trans_indices}, trans_attr_input); + trans_indice_runner.Run(stream); + Tensor trans_output; + trans_output.mutable_data(trans_dims, ctx.GetPlace()); + const auto& trans_output_runner = NpuOpRunner( + "TransposeD", {*output}, {trans_output}, trans_attr_input); + trans_output_runner.Run(stream); + const auto& sort_runner = + NpuOpRunner("Sort", {trans_input}, {trans_output, trans_indices}, + sort_attr_input); + sort_runner.Run(stream); + // transpose back + const auto& trans_indices_back_runner = NpuOpRunner( + "TransposeD", {trans_indices}, {*indices}, trans_attr_input); + trans_indices_back_runner.Run(stream); + const auto& trans_output_back_runner = NpuOpRunner( + "TransposeD", {trans_output}, {*output}, trans_attr_input); + trans_output_back_runner.Run(stream); + } + } +}; + +template +static void ReshapeNPU(const framework::Tensor* input, + const std::vector& input_shapes, + framework::Tensor* output) { + output->ShareDataWith(*input); + output->Resize(framework::make_ddim(std::move(input_shapes))); +} + +template +static void FullAssignNPU(const framework::ExecutionContext& ctx, + Type ind_lastdim, Type outer_dim, + const framework::DDim& trans_dims, + const framework::Tensor* input, + const framework::Tensor* indices, + framework::Tensor* t_out) { + // reshape input + Type input_shape = ind_lastdim * outer_dim; + std::vector input_shapes = {input_shape}; + Tensor input_reshape_tensor(input->type()); + ReshapeNPU(input, input_shapes, &input_reshape_tensor); + // reshape index + std::vector index_shapes = {outer_dim, ind_lastdim}; + framework::DDim ind_2d = framework::make_ddim({outer_dim, ind_lastdim}); + Tensor ind_2d_tensor(indices->type()); + ReshapeNPU(indices, index_shapes, &ind_2d_tensor); + // range_flatten_index + std::vector range_flatten_index; + for (Type i = 0; i < input_shape; i += ind_lastdim) { + range_flatten_index.push_back(static_cast(i)); + } + Tensor range_flatten_index_tensor(framework::proto::VarType::INT32); + range_flatten_index_tensor.Resize(framework::make_ddim({outer_dim})); + range_flatten_index_tensor.mutable_data( + {static_cast(range_flatten_index.size())}, ctx.GetPlace()); + TensorFromVector(range_flatten_index, ctx.device_context(), + &range_flatten_index_tensor); + Tensor range_flatten_index_expand_tensor(range_flatten_index_tensor.type()); + std::vector flatten_shape = {outer_dim, 1}; + ReshapeNPU(&range_flatten_index_tensor, flatten_shape, + &range_flatten_index_expand_tensor); + auto stream = + ctx.template device_context() + .stream(); + Tensor ind_2d_add_tensor; + ind_2d_add_tensor.mutable_data(ind_2d, ctx.GetPlace()); + const auto& runner_ind_2d_tensor = NpuOpRunner( + std::string("Add"), {ind_2d_tensor, range_flatten_index_expand_tensor}, + {ind_2d_add_tensor}, {}); + runner_ind_2d_tensor.Run(stream); + Tensor ind_reshape_tensor(ind_2d_add_tensor.type()); + ReshapeNPU(&ind_2d_add_tensor, input_shapes, &ind_reshape_tensor); + Tensor ind_reshape_expand_tensor(ind_reshape_tensor.type()); + std::vector ind_shape = {input_shape, 1}; + ReshapeNPU(&ind_reshape_tensor, ind_shape, &ind_reshape_expand_tensor); + // expand_index + Tensor input_scatter_tensor; + input_scatter_tensor.Resize({input_shape}); + input_scatter_tensor.mutable_data(ctx.GetPlace()); + Tensor input_scatter_tensor_ori; + input_scatter_tensor_ori.Resize({input_shape}); + input_scatter_tensor_ori.mutable_data(ctx.GetPlace()); + std::vector trans_shapes; + + for (int i = 0; i < trans_dims.size(); i++) { + trans_shapes.push_back(trans_dims[i]); + } + NpuOpRunner runner_scatter; + runner_scatter.SetType("TensorScatterUpdate") + .AddInput(input_scatter_tensor_ori) + .AddInput(ind_reshape_expand_tensor) + .AddInput(input_reshape_tensor) + .AddOutput(input_scatter_tensor); + runner_scatter.Run(stream); + framework::TensorCopy(input_scatter_tensor, ctx.GetPlace(), + ctx.template device_context(), + t_out); + t_out->Resize(framework::make_ddim(trans_shapes)); +} + +template +class ArgsortGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* indices = ctx.Input("Indices"); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); + auto in_dims = indices->dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + auto place = ctx.GetPlace(); + + auto stream = + ctx.template device_context() + .stream(); + dX->mutable_data(ctx.GetPlace()); + Tensor dxt; + dxt.mutable_data(dX->dims(), place); + const auto& runner_flatten = + NpuOpRunner(std::string("Flatten"), {*dX}, {dxt}, {}); + runner_flatten.Run(stream); + FillNpuTensorWithConstant(&dxt, static_cast(0)); + if (dO->numel() == 0) return; + // Do full assig n + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t outer_dim = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t ind_lastdim = in_dims[in_dims.size() - 1]; + FullAssignNPU(ctx, ind_lastdim, outer_dim, in_dims, dO, + indices, dX); + + } else { + // If not full assign do transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + std::vector axis; + for (size_t i = 0; i < trans.size(); i++) { + axis.push_back(in_dims[trans[i]]); + } + framework::NPUAttributeMap attr_input = {{"perm", trans}}; + Tensor trans_dO; + trans_dO.mutable_data(trans_dims, ctx.GetPlace()); + Tensor trans_ind; + trans_ind.mutable_data(trans_dims, ctx.GetPlace()); + // Do transpose + const auto& runner_transpose_dx = NpuOpRunner( + std::string("TransposeD"), {*dO}, {trans_dO}, {attr_input}); + runner_transpose_dx.Run(stream); + const auto& runner_transpose_ind = NpuOpRunner( + std::string("TransposeD"), {*indices}, {trans_ind}, {attr_input}); + runner_transpose_ind.Run(stream); + + const int64_t outer_dim = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t ind_lastdim = trans_dims[trans_dims.size() - 1]; + + Tensor tmp_out; + tmp_out.mutable_data(trans_dims, ctx.GetPlace()); + + FullAssignNPU(ctx, ind_lastdim, outer_dim, trans_dims, + &trans_dO, &trans_ind, &tmp_out); + + // transpose back + const auto& runner_transpose_out = NpuOpRunner( + std::string("TransposeD"), {tmp_out}, {*dX}, {attr_input}); + runner_transpose_out.Run(stream); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + argsort, ops::ArgsortNPUKernel, + ops::ArgsortNPUKernel); + +REGISTER_OP_NPU_KERNEL(argsort_grad, + ops::ArgsortGradNPUKernel, + ops::ArgsortGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..824266578b9e571ace99db01b8ecc95827e1afe3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_argsort_op_npu.py @@ -0,0 +1,215 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest, _set_use_system_allocator +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + +from paddle.fluid import ParamAttr +from paddle.fluid.framework import Program, grad_var_name +from paddle.fluid.executor import Executor +from paddle.fluid.backward import append_backward + +paddle.enable_static() + + +class TestArgsortOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "argsort" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.init_inputshape() + self.init_axis() + self.init_direction() + + self.x = np.random.random(self.input_shape).astype(self.dtype) + self.inputs = {"X": self.x} + self.attrs = {"axis": self.axis, "descending": self.descending} + self.get_output() + self.outputs = {"Out": self.sorted_x, "Indices": self.indices} + + def get_output(self): + if self.descending: + self.indices = np.flip( + np.argsort( + self.x, kind='heapsort', axis=self.axis), self.axis) + self.sorted_x = np.flip( + np.sort( + self.x, kind='heapsort', axis=self.axis), self.axis) + else: + self.indices = np.argsort(self.x, kind='heapsort', axis=self.axis) + self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis) + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def init_kernel_type(self): + self.use_mkldnn = False + + def init_inputshape(self): + self.input_shape = (2, 2, 2, 3, 3) + + def init_dtype(self): + self.dtype = np.float16 + + def init_axis(self): + self.axis = -1 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def init_direction(self): + self.descending = False + + +class TestArgsortOpAxis0NPU(TestArgsortOp): + def init_axis(self): + self.axis = 0 + + +class TestArgsortOpAxis1NPU(TestArgsortOp): + def init_axis(self): + self.axis = 1 + + +class TestArgsortOpAxis2NPU(TestArgsortOp): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpAxisNeg1NPU(TestArgsortOp): + def init_axis(self): + self.axis = -1 + + +class TestArgsortOpAxisNeg2NPU(TestArgsortOp): + def init_axis(self): + self.axis = -2 + + +class TestArgsortOpDescendingAxisNPU(TestArgsortOp): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis0NPU(TestArgsortOpAxis0NPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis1NPU(TestArgsortOpAxis1NPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis2NPU(TestArgsortOpAxis2NPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg1NPU(TestArgsortOpAxisNeg1NPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg2NPU(TestArgsortOpAxisNeg2NPU): + def init_direction(self): + self.descending = True + + +# liurui25: argsort of npu has bug with type fp32, +# it will change the type from fp32 to fp16, +# so the check_output_with_place add thw atol +# this test is only used to test the grad +# issue: https://gitee.com/ascend/modelzoo/issues/I44I7K + + +class TestArgsortOpAxis0NPUFP32(TestArgsortOp): + def init_axis(self): + self.axis = 0 + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-2) + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_grad(self): + self.check_grad_with_place(self.place, ["X"], "Out") + + +class TestArgsortOpAxis1NPUFP32(TestArgsortOpAxis0NPUFP32): + def init_axis(self): + self.axis = 1 + + +class TestArgsortOpAxis2NPUFP32(TestArgsortOpAxis0NPUFP32): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpAxisNeg1NPUFP32(TestArgsortOpAxis0NPUFP32): + def init_axis(self): + self.axis = -1 + + +class TestArgsortOpAxisNeg2NPUFP32(TestArgsortOpAxis0NPUFP32): + def init_axis(self): + self.axis = -2 + + +class TestArgsortOpDescendingAxisNPUFP32(TestArgsortOpAxis0NPUFP32): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis0NPUFP32(TestArgsortOpAxis0NPUFP32): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis1NPUFP32(TestArgsortOpAxis1NPUFP32): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis2NPUFP32(TestArgsortOpAxis2NPUFP32): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg1NPUFP32(TestArgsortOpAxisNeg1NPUFP32): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg2NPUFP32(TestArgsortOpAxisNeg2NPUFP32): + def init_direction(self): + self.descending = True + + +if __name__ == '__main__': + unittest.main()