From 79be842762abba3d00ff6c412a64707a7787e30d Mon Sep 17 00:00:00 2001
From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com>
Date: Tue, 10 Aug 2021 19:34:31 +0800
Subject: [PATCH] [NPU] Support npu kernel for flatten_contiguous_range op,
 test=develop (#34642)

* fix npu compile error, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* [NPU] Support npu kernel for flatten_contiguous_range op, test=develop

* Update flatten_op_npu.cc

* Update flatten_op_npu.cc

Co-authored-by: qili93
---
 paddle/fluid/operators/flatten_op_npu.cc      |  39 +++
 paddle/fluid/operators/npu_op_runner.cc       |   1 +
 .../test_flatten_contiguous_range_op_npu.py   | 318 ++++++++++++++++++
 .../test_flatten_contiguous_range_op.py       |   2 +-
 4 files changed, 359 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py

diff --git a/paddle/fluid/operators/flatten_op_npu.cc b/paddle/fluid/operators/flatten_op_npu.cc
index 385dad530d9..1569760fe3b 100644
--- a/paddle/fluid/operators/flatten_op_npu.cc
+++ b/paddle/fluid/operators/flatten_op_npu.cc
@@ -54,6 +54,30 @@ class Flatten2GradNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *X = ctx.Input<framework::Tensor>("X");
+    auto *Out = ctx.Output<framework::Tensor>("Out");
+    int start_axis = ctx.Attr<int>("start_axis");
+    int stop_axis = ctx.Attr<int>("stop_axis");
+
+    Out->mutable_data<T>(ctx.GetPlace());
+
+    const auto &runner =
+        NpuOpRunner("FlattenV2", {*X}, {*Out},
+                    {{"axis", static_cast<int32_t>(start_axis)},
+                     {"end_axis", static_cast<int32_t>(stop_axis)}});
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -71,3 +95,18 @@ REGISTER_OP_NPU_KERNEL(flatten2_grad, ops::Flatten2GradNPUKernel<float>,
                        ops::Flatten2GradNPUKernel<double>,
                        ops::Flatten2GradNPUKernel<int>,
                        ops::Flatten2GradNPUKernel<int64_t>);
+
+REGISTER_OP_NPU_KERNEL(
+    flatten_contiguous_range,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         float>,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         double>,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         uint8_t>,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         int8_t>,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         int>,
+    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
+                                         int64_t>);
diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc
index a134b542c24..bb6549c1119 100644
--- a/paddle/fluid/operators/npu_op_runner.cc
+++ b/paddle/fluid/operators/npu_op_runner.cc
@@ -33,6 +33,7 @@ static std::map<framework::proto::VarType::Type, aclDataType>
     DTYPE_2_ACL_DTYPE = {
         {framework::proto::VarType::BOOL, ACL_BOOL},
         {framework::proto::VarType::UINT8, ACL_UINT8},
+        {framework::proto::VarType::INT8, ACL_INT8},
         {framework::proto::VarType::INT16, ACL_INT16},
         {framework::proto::VarType::INT32, ACL_INT32},
         {framework::proto::VarType::INT64, ACL_INT64},
diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
new file mode 100644
index 00000000000..88e711dcf06
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
@@ -0,0 +1,318 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+class TestFlattenOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "flatten_contiguous_range"
+        self.place = paddle.NPUPlace(0)
+
+        self.start_axis = 0
+        self.stop_axis = -1
+        self.dtype = np.float64
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype(self.dtype)}
+        self.init_attrs()
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.new_shape),
+            "XShape": np.random.random(self.in_shape).astype("float32")
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, no_check_set=["XShape"])
+
+    def test_check_grad(self):
+        pass
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = -1
+        self.new_shape = (120)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_1(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 1
+        self.stop_axis = 2
+        self.new_shape = (3, 10, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_2(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_3(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 2
+        self.new_shape = (30, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_4(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = -2
+        self.stop_axis = -1
+        self.new_shape = (3, 2, 20)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_5(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 2
+        self.stop_axis = 2
+        self.new_shape = (3, 2, 5, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOpSixDims(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.start_axis = 3
+        self.stop_axis = 5
+        self.new_shape = (3, 2, 3, 32)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_Float32(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.float32
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_int(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_uint8(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.uint8
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_int8(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int8
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_int64(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int64
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlatten2OpError(unittest.TestCase):
+    def test_errors(self):
+        image_shape = (2, 3, 4, 4)
+        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                      image_shape[3]).reshape(image_shape) / 100.
+        x = x.astype('float32')
+
+        def test_ValueError1():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            out = paddle.flatten(x_var, start_axis=2, stop_axis=1)
+
+        self.assertRaises(ValueError, test_ValueError1)
+
+        def test_ValueError2():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            paddle.flatten(x_var, start_axis=10, stop_axis=1)
+
+        self.assertRaises(ValueError, test_ValueError2)
+
+        def test_ValueError3():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            paddle.flatten(x_var, start_axis=2, stop_axis=10)
+
+        self.assertRaises(ValueError, test_ValueError3)
+
+        def test_type():
+            # dtype must be float32, float64, int8, int32, int64, uint8.
+            x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                           image_shape[3]).reshape(image_shape) / 100.
+            x2 = x2.astype('float16')
+            x2_var = paddle.fluid.data(
+                name='x2', shape=[3, 2, 4, 5], dtype='float16')
+            paddle.flatten(x2_var)
+
+        self.assertRaises(TypeError, test_type)
+
+        def test_InputError():
+            out = paddle.flatten(x)
+
+        self.assertRaises(ValueError, test_InputError)
+
+
+class TestStaticFlattenPythonAPI(unittest.TestCase):
+    def execute_api(self, x, start_axis=0, stop_axis=-1):
+        return paddle.flatten(x, start_axis, stop_axis)
+
+    def test_static_api(self):
+        paddle.enable_static()
+        np_x = np.random.rand(2, 3, 4, 4).astype('float32')
+
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, paddle.static.Program()):
+            x = paddle.static.data(
+                name="x", shape=[2, 3, 4, 4], dtype='float32')
+            out = self.execute_api(x, start_axis=-2, stop_axis=-1)
+
+        exe = paddle.static.Executor(place=paddle.NPUPlace(0))
+        fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
+        self.assertTrue((2, 3, 16) == fetch_out[0].shape)
+
+
+class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
+    def execute_api(self, x, start_axis=0, stop_axis=-1):
+        return x.flatten_(start_axis, stop_axis)
+
+
+class TestFlattenPython(unittest.TestCase):
+    def test_python_api(self):
+        image_shape = (2, 3, 4, 4)
+        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                      image_shape[3]).reshape(image_shape) / 100.
+        x = x.astype('float32')
+
+        def test_InputError():
+            out = paddle.flatten(x)
+
+        self.assertRaises(ValueError, test_InputError)
+
+        def test_Negative():
+            paddle.disable_static(paddle.NPUPlace(0))
+            img = paddle.to_tensor(x)
+            out = paddle.flatten(img, start_axis=-2, stop_axis=-1)
+            return out.numpy().shape
+
+        res_shape = test_Negative()
+        self.assertTrue((2, 3, 16) == res_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py
index bc9ff369771..f87b732d1b2 100644
--- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py
+++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
-- 
GitLab
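
A minimal usage sketch of the op this patch wires up, mirroring the dynamic-graph path in TestFlattenPython above. It assumes a PaddlePaddle build with NPU support, so that paddle.NPUPlace(0) is available and the flatten_contiguous_range NPU kernel registered above is the one dispatched; the shape and dtype are illustrative only.

    import numpy as np
    import paddle

    # Assumption: an NPU build of PaddlePaddle. On such a build, paddle.flatten
    # runs the flatten_contiguous_range op, which this patch backs with
    # FlattenContiguousRangeNPUKernel (the CANN "FlattenV2" operator).
    paddle.disable_static(paddle.NPUPlace(0))

    # float32 is one of the six dtypes registered for the NPU kernel.
    x = paddle.to_tensor(np.random.rand(3, 2, 5, 4).astype('float32'))

    # Collapse axes 1..2: (3, 2, 5, 4) -> (3, 10, 4), as in TestFlattenOp_1.
    out = paddle.flatten(x, start_axis=1, stop_axis=2)
    print(out.shape)  # [3, 10, 4]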