From bb6339656cec2504149db845b0ee931a5c2cd527 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Thu, 2 Sep 2021 20:16:16 +0800 Subject: [PATCH] [NPU] Support npu kernel for gather_nd op (#34800) * [NPU] Support npu kernel for gather_ng op * [NPU] Support npu kernel for gather_nd op * [NPU] Support npu kernel for gather_nd and gather_nd_grad op * update py format error. * modify gather_nd_op_npu * modify gather_nd 910 test * modify gather_nd 910 test Co-authored-by: xiaoxiaohehe001 --- paddle/fluid/operators/gather_nd_op_npu.cc | 120 ++++++++ .../unittests/npu/test_gather_nd_op_npu.py | 289 ++++++++++++++++++ 2 files changed, 409 insertions(+) create mode 100644 paddle/fluid/operators/gather_nd_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py diff --git a/paddle/fluid/operators/gather_nd_op_npu.cc b/paddle/fluid/operators/gather_nd_op_npu.cc new file mode 100644 index 00000000000..d04e0bce36f --- /dev/null +++ b/paddle/fluid/operators/gather_nd_op_npu.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_nd_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class GatherNdNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *out = ctx.Output("Out"); + + out->template mutable_data(ctx.GetPlace()); + + if (x->numel() == 0) return; + + if (index->numel() == 0) { + framework::TensorCopy(*x, ctx.GetPlace(), ctx.device_context(), out); + return; + } + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + const auto &runner = NpuOpRunner("GatherNd", {*x, *index}, {*out}, {}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class GatherNdGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *index = ctx.Input("Index"); + auto *x = ctx.Input("X"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *p = dx->mutable_data(ctx.GetPlace()); + + if (dx->numel() == 0) return; + + if (index->numel() == 0) { + framework::TensorCopy(*dout, ctx.GetPlace(), ctx.device_context(), dx); + return; + } + + framework::Tensor tmp_tensor(index->type()); + framework::Tensor tmp_tensor2(dout->type()); + const auto index_dims = index->dims(); + if (index_dims.size() == 1) { + tmp_tensor.ShareDataWith(*index); + std::vector new_dim = {1, index_dims[0]}; + tmp_tensor.Resize(framework::make_ddim(new_dim)); + index = &tmp_tensor; + + tmp_tensor2.ShareDataWith(*dout); + std::vector new_dim2{1}; + for (int i = index->numel(); i < x->dims().size(); i++) { + new_dim2.push_back(x->dims()[i]); + } + tmp_tensor2.Resize(framework::make_ddim(new_dim2)); + dout = &tmp_tensor2; + } + + auto stream = + ctx.template device_context() + .stream(); + + platform::NPUMemsetAsync(static_cast(p), 0, dx->numel() * sizeof(T), + stream); + + const auto &runner_scatter = NpuOpRunner( + "ScatterNdAdd", {*dx, *index, *dout}, {*dx}, {{"use_locking", false}}); + runner_scatter.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + gather_nd, ops::GatherNdNPUKernel, + ops::GatherNdNPUKernel); + +REGISTER_OP_NPU_KERNEL( + gather_nd_grad, + ops::GatherNdGradNPUKernel, + ops::GatherNdGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py new file mode 100644 index 00000000000..b124a546241 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py @@ -0,0 +1,289 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +import paddle + + +def gather_nd_grad(x, index): + dout_shape = index.shape[:-1] + x.shape[index.shape[-1]:] + numel = 1 + for i in dout_shape: + numel = numel * i + dout = np.full(dout_shape, 1. / numel) + dx = np.full_like(x, 0) + + index = tuple(index.reshape(-1, index.shape[-1]).T) + np.add.at(dx, index, dout) + + return dx + + +def test_class1(op_type, typename): + class TestGatherNdOpWithEmptyIndex(OpTest): + #Index has empty element, which means copy entire tensor + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.random((5, 20)).astype(typename) + self.inputs = { + 'X': xnp, + 'Index': np.array([[], []]).astype("int32") + } + self.outputs = { + 'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) + } + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_1".format(op_type, typename) + TestGatherNdOpWithEmptyIndex.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithEmptyIndex + + +def test_class2(op_type, typename): + class TestGatherNdOpWithIndex1(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.random((5, 20)).astype(typename) + self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")} + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_2".format(op_type, typename) + TestGatherNdOpWithIndex1.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithIndex1 + + +def test_class3(op_type, typename): + class TestGatherNdOpWithLowIndex(OpTest): + #Index has low rank, X has high rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([[1], [2]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': xnp[tuple(index.T)]} + self.x_grad = gather_nd_grad(xnp, index) + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place( + self.place, ['X'], 'Out', user_defined_grads=[self.x_grad]) + + cls_name = "{0}_{1}_3".format(op_type, typename) + TestGatherNdOpWithLowIndex.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithLowIndex + + +def test_class4(op_type, typename): + class TestGatherNdOpIndex1(OpTest): + #Index has low rank, X has high rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([1, 2]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + + self.outputs = {'Out': xnp[tuple(index.T)]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_4".format(op_type, typename) + TestGatherNdOpIndex1.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpIndex1 + + +def test_class5(op_type, typename): + class TestGatherNdOpWithSameIndexAsX(OpTest): + #Index has same rank as X's rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([[1, 1], [2, 1]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22] + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_5".format(op_type, typename) + TestGatherNdOpWithSameIndexAsX.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithSameIndexAsX + + +def test_class6(op_type, typename): + class TestGatherNdOpWithHighRankSame(OpTest): + #Both Index and X have high rank, and Rank(Index) = Rank(X) + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + shape = (5, 2, 3, 1, 10) + xnp = np.random.rand(*shape).astype(typename) + index = np.vstack([np.random.randint( + 0, s, size=2) for s in shape]).T + + self.inputs = {'X': xnp, 'Index': index.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_6".format(op_type, typename) + TestGatherNdOpWithHighRankSame.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithHighRankSame + + +def test_class7(op_type, typename): + class TestGatherNdOpWithHighRankDiff(OpTest): + #Both Index and X have high rank, Rank(Index) < Rank(X) + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + shape = (2, 3, 4, 1, 10) + xnp = np.random.rand(*shape).astype(typename) + index = np.vstack( + [np.random.randint( + 0, s, size=200) for s in shape]).T + index_re = index.reshape([20, 5, 2, 5]) + + self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_7".format(op_type, typename) + TestGatherNdOpWithHighRankDiff.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithHighRankDiff + + +class TestGatherNdAPI(unittest.TestCase): + def test_imperative(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([[1]]) + input = fluid.dygraph.to_variable(input_1) + index = fluid.dygraph.to_variable(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([3, 4]) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + +for _typename in {'float16', 'float32'}: + test_class1('gather_nd', _typename) + test_class2('gather_nd', _typename) + test_class3('gather_nd', _typename) + test_class4('gather_nd', _typename) + test_class5('gather_nd', _typename) + test_class6('gather_nd', _typename) + test_class7('gather_nd', _typename) + +if __name__ == "__main__": + unittest.main() -- GitLab