From b1cc4a4608a4ab4f5fff8570a08814eb91332223 Mon Sep 17 00:00:00 2001
From: veyron95 <87417304+veyron95@users.noreply.github.com>
Date: Mon, 16 Aug 2021 20:50:39 +0800
Subject: [PATCH] [NPU] Support npu op:(1)arg_min (2)arg_max (#34867)

* [NPU] Support npu op:(1)arg_min (2)arg_max

* Modify and add unit test cases

* Modify unit test cases
---
 paddle/fluid/operators/arg_max_op_npu.cc      |  54 ++++
 paddle/fluid/operators/arg_min_op_npu.cc      |  54 ++++
 .../unittests/npu/test_arg_max_op_npu.py      | 273 ++++++++++++++++++
 .../unittests/npu/test_arg_min_op_npu.py      | 273 ++++++++++++++++++
 4 files changed, 654 insertions(+)
 create mode 100644 paddle/fluid/operators/arg_max_op_npu.cc
 create mode 100644 paddle/fluid/operators/arg_min_op_npu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_arg_min_op_npu.py

diff --git a/paddle/fluid/operators/arg_max_op_npu.cc b/paddle/fluid/operators/arg_max_op_npu.cc
new file mode 100644
index 0000000000..38f9813ad0
--- /dev/null
+++ b/paddle/fluid/operators/arg_max_op_npu.cc
@@ -0,0 +1,54 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/arg_min_max_op_base.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class ArgMaxNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    int64_t axis = ctx.Attr<int64_t>("axis");
+    auto dtype = ctx.Attr<int>("dtype");
+
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<int32_t>(ctx.GetPlace());
+
+    NpuOpRunner runner;
+    runner.SetType("ArgMaxV2")
+        .AddInput(*x)
+        .AddInput(std::vector<int64_t>{axis})
+        .AddOutput(*out)
+        .AddAttr("dtype", dtype);
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_NPU_KERNEL(
+    arg_max, ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ArgMaxNPUKernel<paddle::platform::NPUDeviceContext,
+                         paddle::platform::float16>);
diff --git a/paddle/fluid/operators/arg_min_op_npu.cc b/paddle/fluid/operators/arg_min_op_npu.cc
new file mode 100644
index 0000000000..f776412c16
--- /dev/null
+++ b/paddle/fluid/operators/arg_min_op_npu.cc
@@ -0,0 +1,54 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/arg_min_max_op_base.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class ArgMinNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    int64_t axis = ctx.Attr<int64_t>("axis");
+    auto dtype = ctx.Attr<int>("dtype");
+
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<int32_t>(ctx.GetPlace());
+
+    NpuOpRunner runner;
+    runner.SetType("ArgMin")
+        .AddInput(*x)
+        .AddInput(std::vector<int64_t>{axis})
+        .AddOutput(*out)
+        .AddAttr("dtype", dtype);
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_NPU_KERNEL(
+    arg_min, ops::ArgMinNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ArgMinNPUKernel<paddle::platform::NPUDeviceContext,
+                         paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py
new file mode 100644
index 0000000000..9bc46697c0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_arg_max_op_npu.py
@@ -0,0 +1,273 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+
+paddle.enable_static()
+
+
+class BaseTestCase(OpTest):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+        np.random.seed(2021)
+        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.inputs = {'X': self.x}
+        self.attrs = {'axis': self.axis}
+        if self.op_type == "arg_min":
+            self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
+        else:
+            self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+# test argmax, dtype: float16
+class TestArgMaxFloat16Case1(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = -1
+
+
+class TestArgMaxFloat16Case2(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMaxFloat16Case3(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 1
+
+
+class TestArgMaxFloat16Case4(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 2
+
+
+class TestArgMaxFloat16Case5(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = -1
+
+
+class TestArgMaxFloat16Case6(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMaxFloat16Case7(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = 1
+
+
+class TestArgMaxFloat16Case8(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (1, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMaxFloat16Case9(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (2, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMaxFloat16Case10(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+# test argmax, dtype: float32
+class TestArgMaxFloat32Case1(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMaxFloat32Case2(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case3(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMaxFloat32Case4(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 2
+
+
+class TestArgMaxFloat32Case5(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMaxFloat32Case6(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case7(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMaxFloat32Case8(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (1, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case9(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (2, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case10(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxAPI(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = [paddle.NPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmax(numpy_input, axis=self.axis)
+            paddle_output = paddle.argmax(tensor_input, axis=self.axis)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+class TestArgMaxAPI_2(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+        self.keep_dims = True
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = [paddle.NPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmax(
+                numpy_input, axis=self.axis).reshape(1, 4, 5)
+            paddle_output = paddle.argmax(
+                tensor_input, axis=self.axis, keepdim=self.keep_dims)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_arg_min_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_arg_min_op_npu.py
new file mode 100644
index 0000000000..455f92b8ed
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_arg_min_op_npu.py
@@ -0,0 +1,273 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+
+paddle.enable_static()
+
+
+class BaseTestCase(OpTest):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+        np.random.seed(2021)
+        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.inputs = {'X': self.x}
+        self.attrs = {'axis': self.axis}
+        if self.op_type == "arg_min":
+            self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
+        else:
+            self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+# test argmin, dtype: float16
+class TestArgMinFloat16Case1(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = -1
+
+
+class TestArgMinFloat16Case2(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMinFloat16Case3(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 1
+
+
+class TestArgMinFloat16Case4(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float16'
+        self.axis = 2
+
+
+class TestArgMinFloat16Case5(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = -1
+
+
+class TestArgMinFloat16Case6(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMinFloat16Case7(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float16'
+        self.axis = 1
+
+
+class TestArgMinFloat16Case8(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (1, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMinFloat16Case9(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (2, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+class TestArgMinFloat16Case10(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, )
+        self.dtype = 'float16'
+        self.axis = 0
+
+
+# test argmin, dtype: float32
+class TestArgMinFloat32Case1(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMinFloat32Case2(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMinFloat32Case3(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMinFloat32Case4(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 2
+
+
+class TestArgMinFloat32Case5(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMinFloat32Case6(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMinFloat32Case7(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMinFloat32Case8(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (1, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMinFloat32Case9(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (2, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMinFloat32Case10(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMinAPI(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = [paddle.NPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmin(numpy_input, axis=self.axis)
+            paddle_output = paddle.argmin(tensor_input, axis=self.axis)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+class TestArgMinAPI_2(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+        self.keep_dims = True
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = [paddle.NPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmin(
+                numpy_input, axis=self.axis).reshape(1, 4, 5)
+            paddle_output = paddle.argmin(
+                tensor_input, axis=self.axis, keepdim=self.keep_dims)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
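
For reference, a minimal usage sketch (outside the patch itself): the kernels added above are reached through the existing paddle.argmax / paddle.argmin Python APIs, with ArgMaxV2 and ArgMin being the CANN operator names dispatched via NpuOpRunner. The snippet assumes a PaddlePaddle build with NPU support and an available device 0, mirroring NPUPlace(0) from the unit tests; on other builds it falls back to CPU.

    import numpy as np
    import paddle

    # Pick an NPU place when the build supports it (assumption: device 0 exists).
    place = paddle.NPUPlace(0) if paddle.is_compiled_with_npu() else paddle.CPUPlace()
    paddle.disable_static(place)

    x = paddle.to_tensor(np.random.random((3, 4, 5)).astype('float32'))

    # Index of the maximum along axis 1 -> shape [3, 5].
    idx_max = paddle.argmax(x, axis=1)

    # Index of the minimum with the reduced axis kept -> shape [3, 1, 5],
    # matching the keepdim cases exercised in the tests above.
    idx_min = paddle.argmin(x, axis=1, keepdim=True)

    print(idx_max.shape, idx_min.shape)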