diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 3eb833b6b3d35697ba571d81ceedb95185856728..51e4bd3ac41c92b1afdd216876d45c30579d8fe5 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -265,6 +265,7 @@ function(op_library TARGET) if (WITH_ASCEND_CL) if (CANN_VERSION LESS 504000) list(REMOVE_ITEM npu_cc_srcs "multinomial_op_npu.cc") + list(REMOVE_ITEM npu_cc_srcs "take_along_axis_op_npu.cc") endif() endif() # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`. diff --git a/paddle/fluid/operators/take_along_axis_op_npu.cc b/paddle/fluid/operators/take_along_axis_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d8ef0675c19e938b00b3b98c7db28b98635380c --- /dev/null +++ b/paddle/fluid/operators/take_along_axis_op_npu.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// TODO(Aganlengzi): delete this macro control and remove REMOVE_ITEM in +// cmake/operators.cmake when Paddle supports +#if (CANN_VERSION_CODE >= 504000) + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/npu/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class NPUTakeAlongAxisKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto input = ctx.Input("Input"); + auto axis = ctx.Attr("Axis"); + auto index = ctx.Input("Index"); + auto result = ctx.Output("Result"); + result->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + const auto& runner = NpuOpRunner("GatherElements", {*input, *index}, + {*result}, {{"dim", axis}}); + runner.Run(stream); + } +}; + +template +class NPUTakeAlongAxisGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto axis = ctx.Attr("Axis"); + auto index = ctx.Input("Index"); + auto result_grad = ctx.Input(framework::GradVarName("Result")); + + auto input_grad = ctx.Output(framework::GradVarName("Input")); + input_grad->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + const auto& runner = + NpuOpRunner("ScatterAddWithAxis", {*input_grad, *index, *result_grad}, + {*input_grad}, {{"axis", axis}}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_NPU_KERNEL( + take_along_axis, + ops::NPUTakeAlongAxisKernel, + ops::NPUTakeAlongAxisKernel, + ops::NPUTakeAlongAxisKernel, + ops::NPUTakeAlongAxisKernel) +REGISTER_OP_NPU_KERNEL( + take_along_axis_grad, + ops::NPUTakeAlongAxisGradKernel, + ops::NPUTakeAlongAxisGradKernel, + ops::NPUTakeAlongAxisGradKernel, + ops::NPUTakeAlongAxisGradKernel) + +#endif diff --git a/python/paddle/fluid/tests/unittests/npu/test_take_along_axis_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_take_along_axis_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad02f7df06edf4b576dac2daff487412ad3d70 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_take_along_axis_op_npu.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.framework import core + +paddle.enable_static() + + +@unittest.skip(reason="Skip unsupported ut, need paddle surpport cann 5.0.4+") +class TestTakeAlongAxisOp(OpTest): + def setUp(self): + self.set_npu() + self.init_data() + self.op_type = "take_along_axis" + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + self.target = np.take_along_axis(self.xnp, self.index, self.axis) + broadcast_shape_list = list(self.x_shape) + broadcast_shape_list[self.axis] = 1 + self.braodcast_shape = tuple(broadcast_shape_list) + self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape) + self.inputs = { + 'Input': self.xnp, + 'Index': self.index_broadcast, + } + self.attrs = {'Axis': self.axis} + self.outputs = {'Result': self.target} + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['Input'], 'Result') + + def init_data(self): + self.x_type = "float64" + self.x_shape = (5, 5, 5) + self.index_type = "int32" + self.index = np.array( + [[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(self.index_type) + self.axis = 2 + self.axis_type = "int64" + + +class TestCase1(TestTakeAlongAxisOp): + def init_data(self): + self.x_type = "float64" + self.x_shape = (5, 5, 5) + self.index_type = "int32" + self.index = np.array([[[0, 1, 2, 1, 4]]]).astype(self.index_type) + self.axis = 0 + self.axis_type = "int64" + + +@unittest.skip(reason="Skip unsupported ut, need paddle surpport cann 5.0.4+") +class TestTakeAlongAxisAPI(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [3, 3] + self.index_shape = [1, 3] + self.index_np = np.array([[0, 1, 2]]).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = paddle.NPUPlace(0) + self.axis = 0 + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.shape) + index = paddle.fluid.data('Index', self.index_shape, "int64") + out = paddle.take_along_axis(x, index, self.axis) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np, + 'Index': self.index_np}, + fetch_list=[out]) + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis)) + for out in res: + self.assertEqual(np.allclose(out, out_ref, rtol=1e-03), True) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + x_tensor = paddle.to_tensor(self.x_np) + self.index = paddle.to_tensor(self.index_np) + out = paddle.take_along_axis(x_tensor, self.index, self.axis) + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis)) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True) + paddle.enable_static() + + +@unittest.skip(reason="Skip unsupported ut, need paddle surpport cann 5.0.4+") +class TestTakeAlongAxisAPICase1(TestTakeAlongAxisAPI): + def setUp(self): + np.random.seed(0) + self.shape = [2, 2] + self.index_shape = [4, 2] + self.index_np = np.array( + [[0, 0], [1, 0], [0, 0], [1, 0]]).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = paddle.NPUPlace(0) + self.axis = 0 + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()