diff --git a/paddle/fluid/operators/stack_op_npu.cc b/paddle/fluid/operators/stack_op_npu.cc index a7e18e9c0c31b1a9c2254fe55c7e24adefde4bf4..3b685b3ab8dbb0166d50ec521b9b93c4508dab12 100644 --- a/paddle/fluid/operators/stack_op_npu.cc +++ b/paddle/fluid/operators/stack_op_npu.cc @@ -12,15 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_ASCEND_CL -#include -#include -#include - -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/stack_op.h" -#include "paddle/fluid/operators/unsqueeze_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" namespace paddle { namespace operators { @@ -32,64 +25,56 @@ class StackNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto x = ctx.MultiInput("X"); - int32_t N = x.size(); + auto* y = ctx.Output("Y"); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += (x[0]->dims().size() + 1); + int num = static_cast(x.size()); - PADDLE_ENFORCE_GT( - N, 0, platform::errors::InvalidArgument("number of input Tensor <= 0")); + PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument( + "number of input Tensor <= 0")); + + auto stream = + ctx.template device_context() + .stream(); std::vector x_list; - for (int i = 0; i < N; i++) { + for (int i = 0; i < num; i++) { x_list.push_back(*x[i]); } + y->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); + const auto& runner = + NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}}); + runner.Run(stream); + } +}; - if (axis < 0) { - axis = axis + x_list[0].dims().size() + 1; - } - auto* out = ctx.Output("Y"); +template +class StackGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dy = ctx.Input(framework::GradVarName("Y")); + auto dx = ctx.MultiOutput(framework::GradVarName("X")); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += dy->dims().size(); + int num = dy->dims()[axis]; - auto place = ctx.GetPlace(); + PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument( + "number of input Tensor <= 0")); auto stream = ctx.template device_context() .stream(); - out->mutable_data(place); - - if (axis != 0) { - auto x_dim = x_list[0].dims(); - std::vector vec_dim_tmp; - vec_dim_tmp.push_back(N); - for (auto i = 0; i < x_dim.size(); ++i) { - vec_dim_tmp.push_back(x_dim[i]); - } - - Tensor tmp_stack(out->type()); - tmp_stack.Resize(framework::make_ddim(vec_dim_tmp)); - tmp_stack.mutable_data(ctx.GetPlace()); - - const auto& runner = - NpuOpRunner("Pack", {x_list}, {tmp_stack}, {{"axis", 0}, {"N", N}}); - runner.Run(stream); - - std::vector vec_trans; - for (auto i = 1; i <= x_dim.size(); ++i) { - vec_trans.push_back(i); - if (i == axis) { - vec_trans.push_back(0); - } - } - - const auto& runner_trans_final = - NpuOpRunner("TransposeD", {tmp_stack}, {*out}, {{"perm", vec_trans}}); - runner_trans_final.Run(stream); - - } else { - const auto& runner = - NpuOpRunner("Pack", {x_list}, {*out}, {{"axis", axis}, {"N", N}}); - runner.Run(stream); + std::vector dx_list; + for (int i = 0; i < num; i++) { + dx[i]->mutable_data(ctx.GetPlace()); + dx_list.push_back(*dx[i]); } + + const auto& runner = + NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}}); + runner.Run(stream); } }; @@ -103,4 +88,8 @@ REGISTER_OP_NPU_KERNEL( ops::StackNPUKernel); -#endif +REGISTER_OP_NPU_KERNEL( + stack_grad, + ops::StackGradNPUKernel, + ops::StackGradNPUKernel); diff --git a/paddle/fluid/operators/unstack_op_npu.cc b/paddle/fluid/operators/unstack_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..eaab4ee999de73370099a38ec41fde81b6afe1d8 --- /dev/null +++ b/paddle/fluid/operators/unstack_op_npu.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unstack_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class UnStackNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *dy = ctx.Input("X"); + auto dx = ctx.MultiOutput("Y"); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += dy->dims().size(); + int num = dy->dims()[axis]; + + auto stream = + ctx.template device_context() + .stream(); + + std::vector dx_list; + for (int i = 0; i < num; i++) { + dx[i]->mutable_data(ctx.GetPlace()); + dx_list.push_back(*dx[i]); + } + + const auto &runner = + NpuOpRunner("Unpack", {*dy}, {dx_list}, {{"axis", axis}, {"num", num}}); + runner.Run(stream); + } +}; + +template +class UnStackGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto x = ctx.MultiInput(framework::GradVarName("Y")); + auto *y = ctx.Output(framework::GradVarName("X")); + int axis = ctx.Attr("axis"); + if (axis < 0) axis += (x[0]->dims().size() + 1); + int num = static_cast(x.size()); + + auto stream = + ctx.template device_context() + .stream(); + + std::vector x_list; + for (int i = 0; i < num; i++) { + x_list.push_back(*x[i]); + } + y->mutable_data(ctx.GetPlace()); + + const auto &runner = + NpuOpRunner("Pack", {x_list}, {*y}, {{"axis", axis}, {"N", num}}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + unstack, ops::UnStackNPUKernel, + ops::UnStackNPUKernel); + +REGISTER_OP_NPU_KERNEL( + unstack_grad, ops::UnStackGradNPUKernel, + ops::UnStackGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py index 6db98be9328a4316821f89ebb5d6c145c6711975..721fb95dd9b72f989746bbe1a7e27596a6b18a34 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py @@ -24,17 +24,18 @@ import paddle.fluid as fluid import paddle.fluid.core as core paddle.enable_static() -SEED = 2021 @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") -class TestStack1(OpTest): +class TestStackOpBase(OpTest): def initDefaultParameters(self): self.num_inputs = 4 self.input_dim = (5, 6, 7) self.axis = 0 - self.dtype = 'float32' + + def initParameters(self): + pass def get_x_names(self): x_names = [] @@ -44,10 +45,10 @@ class TestStack1(OpTest): def setUp(self): self.initDefaultParameters() + self.initParameters() + self.op_type = 'stack' self.set_npu() - self.op_type = "stack" - self.place = paddle.NPUPlace(0) - + self.init_dtype() self.x = [] for i in range(self.num_inputs): self.x.append( @@ -64,89 +65,191 @@ class TestStack1(OpTest): def set_npu(self): self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_dygraph=False) + self.check_output_with_place(self.place) + def test_check_grad(self): + self.check_grad_with_place(self.place, self.get_x_names(), 'Y') -class TestStack2(OpTest): - def initDefaultParameters(self): - self.num_inputs = 4 - self.input_dim = (2, 3, 4) - self.axis = -1 - self.dtype = 'float32' - def get_x_names(self): - x_names = [] - for i in range(self.num_inputs): - x_names.append('x{}'.format(i)) - return x_names +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp1(TestStackOpBase): + def initParameters(self): + self.num_inputs = 16 - def setUp(self): - self.initDefaultParameters() - self.set_npu() - self.op_type = "stack" - self.place = paddle.NPUPlace(0) - self.x = [] - for i in range(self.num_inputs): - self.x.append( - np.random.random(size=self.input_dim).astype(self.dtype)) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp2(TestStackOpBase): + def initParameters(self): + self.num_inputs = 20 - tmp = [] - x_names = self.get_x_names() - for i in range(self.num_inputs): - tmp.append((x_names[i], self.x[i])) - self.inputs = {'X': tmp} - self.outputs = {'Y': np.stack(self.x, axis=self.axis)} - self.attrs = {'axis': self.axis} +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp3(TestStackOpBase): + def initParameters(self): + self.axis = -1 - def set_npu(self): - self.__class__.use_npu = True - def test_check_output(self): - self.check_output_with_place(self.place, check_dygraph=False) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp4(TestStackOpBase): + def initParameters(self): + self.axis = -4 -class TestStack3(OpTest): - def initDefaultParameters(self): - self.num_inputs = 4 - self.input_dim = (2, 3, 4) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp5(TestStackOpBase): + def initParameters(self): self.axis = 1 - self.dtype = 'float32' - def get_x_names(self): - x_names = [] - for i in range(self.num_inputs): - x_names.append('x{}'.format(i)) - return x_names + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp6(TestStackOpBase): + def initParameters(self): + self.axis = 3 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackAPIWithLoDTensorArray(unittest.TestCase): + """ + Test stack api when the input(x) is a LoDTensorArray. + """ def setUp(self): - self.initDefaultParameters() - self.set_npu() - self.op_type = "stack" - self.place = paddle.NPUPlace(0) + self.axis = 1 + self.iter_num = 3 + self.input_shape = [2, 3] + self.x = np.random.random(self.input_shape).astype("float32") + self.place = paddle.NPUPlace(0) \ + if paddle.is_compiled_with_npu() else paddle.CPUPlace() + self.set_program() + + def set_program(self): + self.program = fluid.Program() + with fluid.program_guard(self.program): + input = fluid.layers.assign(self.x) + tensor_array = fluid.layers.create_array(dtype='float32') + zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64") + + for i in range(self.iter_num): + fluid.layers.array_write(input, zero + i, tensor_array) + + self.out_var = fluid.layers.stack(tensor_array, axis=self.axis) + + def test_case(self): + self.assertTrue(self.out_var.shape[self.axis] == -1) + exe = fluid.Executor(self.place) + res = exe.run(self.program, fetch_list=self.out_var) + self.assertTrue( + np.array_equal( + res[0], np.stack( + [self.x] * self.iter_num, axis=self.axis))) - self.x = [] - for i in range(self.num_inputs): - self.x.append( - np.random.random(size=self.input_dim).astype(self.dtype)) - tmp = [] - x_names = self.get_x_names() - for i in range(self.num_inputs): - tmp.append((x_names[i], self.x[i])) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase): + """ + Test stack api when the input(x) is a LoDTensorArray. + """ - self.inputs = {'X': tmp} - self.outputs = {'Y': np.stack(self.x, axis=self.axis)} - self.attrs = {'axis': self.axis} + def setUp(self): + self.axis = 1 + self.iter_num = 3 + self.input_shape = [2, 3] + self.x = np.random.random(self.input_shape).astype("float32") + self.place = paddle.NPUPlace(0) \ + if paddle.is_compiled_with_npu() else paddle.CPUPlace() + self.set_program() + + def set_program(self): + self.program = fluid.Program() + with fluid.program_guard(self.program): + input = fluid.layers.assign(self.x) + tensor_array = fluid.layers.create_array(dtype='float32') + zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64") + + for i in range(self.iter_num): + fluid.layers.array_write(input, zero + i, tensor_array) + + self.out_var = paddle.stack(tensor_array, axis=self.axis) + + def test_case(self): + self.assertTrue(self.out_var.shape[self.axis] == -1) + exe = fluid.Executor(self.place) + res = exe.run(self.program, fetch_list=self.out_var) + self.assertTrue( + np.array_equal( + res[0], np.stack( + [self.x] * self.iter_num, axis=self.axis))) - def set_npu(self): - self.__class__.use_npu = True - def test_check_output(self): - self.check_output_with_place(self.place, check_dygraph=False) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class API_test(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64') + data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64') + data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64') + result_stack = paddle.stack([data1, data2, data3], axis=0) + place = paddle.NPUPlace(0) + exe = fluid.Executor(place) + input1 = np.random.random([1, 2]).astype('float64') + input2 = np.random.random([1, 2]).astype('float64') + input3 = np.random.random([1, 2]).astype('float64') + result, = exe.run( + feed={"data1": input1, + "data2": input2, + "data3": input3}, + fetch_list=[result_stack]) + expected_result = np.stack([input1, input2, input3], axis=0) + self.assertTrue(np.allclose(expected_result, result)) + + def test_single_tensor_error(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = paddle.rand([2, 3]) + self.assertRaises(TypeError, paddle.stack, x) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class API_DygraphTest(unittest.TestCase): + def test_out(self): + data1 = np.array([[1.0, 2.0]]) + data2 = np.array([[3.0, 4.0]]) + data3 = np.array([[5.0, 6.0]]) + with fluid.dygraph.guard(place=paddle.NPUPlace(0)): + x1 = fluid.dygraph.to_variable(data1) + x2 = fluid.dygraph.to_variable(data2) + x3 = fluid.dygraph.to_variable(data3) + result = paddle.stack([x1, x2, x3]) + result_np = result.numpy() + expected_result = np.stack([data1, data2, data3]) + self.assertTrue(np.allclose(expected_result, result_np)) + + with fluid.dygraph.guard(place=paddle.NPUPlace(0)): + y1 = fluid.dygraph.to_variable(data1) + result = paddle.stack([y1], axis=0) + result_np_2 = result.numpy() + expected_result_2 = np.stack([data1], axis=0) + self.assertTrue(np.allclose(expected_result_2, result_np_2)) + + def test_single_tensor_error(self): + with fluid.dygraph.guard(place=paddle.NPUPlace(0)): + x = paddle.to_tensor([1, 2, 3]) + self.assertRaises(Exception, paddle.stack, x) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/npu/test_unstack_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_unstack_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd3c30c272c237b9aff31274b2274cdfa08cf8e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_unstack_op_npu.py @@ -0,0 +1,107 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import unittest +import paddle + +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestUnStackOpBase(OpTest): + def initDefaultParameters(self): + self.input_dim = (5, 6, 7) + self.axis = 0 + + def initParameters(self): + pass + + def get_y_names(self): + y_names = [] + for i in range(self.input_dim[self.axis]): + y_names.append('y{}'.format(i)) + return y_names + + def setUp(self): + self.initDefaultParameters() + self.initParameters() + self.op_type = 'unstack' + self.set_npu() + self.init_dtype() + + self.x = np.random.random(size=self.input_dim).astype(self.dtype) + + outs = np.split(self.x, self.input_dim[self.axis], self.axis) + new_shape = list(self.input_dim) + del new_shape[self.axis] + y_names = self.get_y_names() + tmp = [] + for i in range(self.input_dim[self.axis]): + tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + + self.inputs = {'X': self.x} + self.outputs = {'Y': tmp} + self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], self.get_y_names()) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp3(TestUnStackOpBase): + def initParameters(self): + self.axis = -1 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp4(TestUnStackOpBase): + def initParameters(self): + self.axis = -3 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp5(TestUnStackOpBase): + def initParameters(self): + self.axis = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestStackOp6(TestUnStackOpBase): + def initParameters(self): + self.axis = 2 + + +if __name__ == '__main__': + unittest.main()