From 7bf2aa3883066cb880e4bca8f8691dcdaf470c51 Mon Sep 17 00:00:00 2001
From: TTerror
Date: Thu, 21 Oct 2021 14:28:24 +0800
Subject: [PATCH] add fill_any_like/flatten ops to train ssd on kunlun (#36550)

* add some ops to train ssd on kunlun

* update test_fill_any_like_op_xpu.py
---
 .../fluid/operators/fill_any_like_op_xpu.cc   |  79 +++++
 paddle/fluid/operators/flatten_op_xpu.cc      |  67 ++++
 paddle/fluid/platform/xpu/xpu2_op_list.h      |  36 ++
 .../fluid/tests/unittests/op_test_xpu.py      |  24 +-
 .../xpu/test_fill_any_like_op_xpu.py          |  77 +++++
 .../unittests/xpu/test_flatten2_op_xpu.py     |  83 +++++
 .../test_flatten_contiguous_range_op_xpu.py   | 320 ++++++++++++++++++
 .../unittests/xpu/test_flatten_op_xpu.py      |  77 +++++
 8 files changed, 761 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/fill_any_like_op_xpu.cc
 create mode 100644 paddle/fluid/operators/flatten_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_fill_any_like_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_flatten2_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_flatten_op_xpu.py

diff --git a/paddle/fluid/operators/fill_any_like_op_xpu.cc b/paddle/fluid/operators/fill_any_like_op_xpu.cc
new file mode 100644
index 0000000000..76cf339fbf
--- /dev/null
+++ b/paddle/fluid/operators/fill_any_like_op_xpu.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/fill_any_like_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
+ public:
+  using CommonType = typename std::common_type<
+      float,
+      typename std::conditional<std::is_same<T, platform::float16>::value,
+                                float, T>::type>::type;
+  using XPUInTDType = typename XPUTypeTrait<T>::Type;
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    float value = context.Attr<float>("value");
+
+    auto common_type_value = static_cast<CommonType>(value);
+
+    PADDLE_ENFORCE_EQ(
+        (common_type_value >=
+         static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+            (common_type_value <=
+             static_cast<CommonType>(std::numeric_limits<T>::max())),
+        true,
+        platform::errors::InvalidArgument(
+            "The filled value is out of range for the target type, "
+            "current kernel type is %s, the range should be between %f "
+            "and %f, but the value is %f.",
+            typeid(T).name(),
+            static_cast<CommonType>(std::numeric_limits<T>::lowest()),
+            static_cast<CommonType>(std::numeric_limits<T>::max()), value));
+
+    PADDLE_ENFORCE_EQ(
+        std::isnan(value), false,
+        platform::errors::InvalidArgument("The filled value is NaN."));
+
+    auto& dev_ctx =
+        context.template device_context<paddle::platform::XPUDeviceContext>();
+    auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
+    int ret = xpu::constant(dev_ctx.x_context(), out_data, out->numel(),
+                            static_cast<XPUInTDType>(value));
+    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
+                      platform::errors::External(
+                          "The XPU constant API returned an error[%d %s].",
+                          ret, XPUAPIErrorMsg[ret]));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_XPU_KERNEL(
+    fill_any_like,
+    ops::FillAnyLikeXPUKernel<paddle::platform::XPUDeviceContext, int>,
+    ops::FillAnyLikeXPUKernel<paddle::platform::XPUDeviceContext, int64_t>,
+    ops::FillAnyLikeXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::FillAnyLikeXPUKernel<paddle::platform::XPUDeviceContext,
+                              paddle::platform::float16>);
+
+#endif
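
The kernel above is reached from Python through paddle.full_like (and its zeros_like/ones_like relatives), which lower to the fill_any_like op. A minimal dygraph sketch for exercising it by hand, assuming an XPU build with a Kunlun board visible as device 0:

import numpy as np
import paddle

# full_like dispatches to fill_any_like; on an XPU build this runs
# through the kernel registered above.
paddle.set_device("xpu:0")
x = paddle.to_tensor(np.random.random((219, 232)).astype("float32"))
y = paddle.full_like(x, fill_value=1.5)
np.testing.assert_allclose(y.numpy(), np.full((219, 232), 1.5, dtype="float32"))
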
diff --git a/paddle/fluid/operators/flatten_op_xpu.cc b/paddle/fluid/operators/flatten_op_xpu.cc
new file mode 100644
index 0000000000..53c0c688fd
--- /dev/null
+++ b/paddle/fluid/operators/flatten_op_xpu.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/flatten_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(
+    flatten, ops::FlattenKernel<plat::XPUDeviceContext, float>,
+    ops::FlattenKernel<plat::XPUDeviceContext, int>,
+    ops::FlattenKernel<plat::XPUDeviceContext, int8_t>,
+    ops::FlattenKernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(
+    flatten_grad,
+    ops::FlattenGradKernel<plat::XPUDeviceContext, float>,
+    ops::FlattenGradKernel<plat::XPUDeviceContext, int>,
+    ops::FlattenGradKernel<plat::XPUDeviceContext, int8_t>,
+    ops::FlattenGradKernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(
+    flatten2, ops::Flatten2Kernel<plat::XPUDeviceContext, float>,
+    ops::Flatten2Kernel<plat::XPUDeviceContext, int>,
+    ops::Flatten2Kernel<plat::XPUDeviceContext, int8_t>,
+    ops::Flatten2Kernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(
+    flatten2_grad,
+    ops::Flatten2GradKernel<plat::XPUDeviceContext, float>,
+    ops::Flatten2GradKernel<plat::XPUDeviceContext, int>,
+    ops::Flatten2GradKernel<plat::XPUDeviceContext, int8_t>,
+    ops::Flatten2GradKernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(
+    flatten_contiguous_range,
+    ops::FlattenContiguousRangeKernel<plat::XPUDeviceContext, float>,
+    ops::FlattenContiguousRangeKernel<plat::XPUDeviceContext, plat::float16>,
+    ops::FlattenContiguousRangeKernel<plat::XPUDeviceContext, int>,
+    ops::FlattenContiguousRangeKernel<plat::XPUDeviceContext, int8_t>,
+    ops::FlattenContiguousRangeKernel<plat::XPUDeviceContext, int64_t>);
+REGISTER_OP_XPU_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradKernel<plat::XPUDeviceContext, float>,
+    ops::FlattenContiguousRangeGradKernel<plat::XPUDeviceContext,
+                                          plat::float16>,
+    ops::FlattenContiguousRangeGradKernel<plat::XPUDeviceContext, int>,
+    ops::FlattenContiguousRangeGradKernel<plat::XPUDeviceContext, int8_t>,
+    ops::FlattenContiguousRangeGradKernel<plat::XPUDeviceContext, int64_t>);
+#endif
diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h
index 5d45e5d9d5..0a9a9453b5 100644
--- a/paddle/fluid/platform/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/xpu/xpu2_op_list.h
@@ -119,6 +119,42 @@ XPUOpMap& get_kl2_ops() {
       {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                                    pOpKernelType(vartype::FP16, XPUPlace()),
                                    pOpKernelType(vartype::INT32, XPUPlace())})},
+      {"fill_any_like",
+       XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                     pOpKernelType(vartype::INT32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace()),
+                     pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                                pOpKernelType(vartype::INT32, XPUPlace()),
+                                pOpKernelType(vartype::INT8, XPUPlace()),
+                                pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"flatten_grad",
+       XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                     pOpKernelType(vartype::INT32, XPUPlace()),
+                     pOpKernelType(vartype::INT8, XPUPlace()),
+                     pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"flatten2", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                                 pOpKernelType(vartype::INT32, XPUPlace()),
+                                 pOpKernelType(vartype::INT8, XPUPlace()),
+                                 pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"flatten2_grad",
+       XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                     pOpKernelType(vartype::INT32, XPUPlace()),
+                     pOpKernelType(vartype::INT8, XPUPlace()),
+                     pOpKernelType(vartype::FP32, XPUPlace())})},
+
+      {"flatten_contiguous_range",
+       XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                     pOpKernelType(vartype::INT32, XPUPlace()),
+                     pOpKernelType(vartype::INT8, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace()),
+                     pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"flatten_contiguous_range_grad",
+       XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                     pOpKernelType(vartype::INT32, XPUPlace()),
+                     pOpKernelType(vartype::INT8, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace()),
+                     pOpKernelType(vartype::FP32, XPUPlace())})},
       // AddMore
   };
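
xpu2_op_list.h is what placement consults when deciding whether an op may stay on a KL2 (Kunlun 2) device, so the entries above are what actually switch these kernels on. A quick parity sketch for the newly listed flatten path, assuming an XPU build and device:

import numpy as np
import paddle

# paddle.flatten lowers to flatten_contiguous_range, one of the ops added
# to the KL2 op list above; the result should match a plain numpy reshape.
paddle.set_device("xpu:0")
x_np = np.random.random((3, 2, 5, 4)).astype("float32")
out = paddle.flatten(paddle.to_tensor(x_np), start_axis=1, stop_axis=2)
np.testing.assert_allclose(out.numpy(), x_np.reshape(3, 10, 4))
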
diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py
index 133367a5f3..239708cc17 100644
--- a/python/paddle/fluid/tests/unittests/op_test_xpu.py
+++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py
@@ -91,11 +91,31 @@ class XPUOpTest(OpTest):
         # case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
         if not hasattr(cls, "no_need_check_grad") \
                 and not is_empty_grad_op(cls.op_type):
-            if cls.dtype is not None and \
-                cls.dtype != np.float32:
+            if cls.dtype is None or \
+                (cls.dtype == np.float16 \
+                and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST \
+                and not hasattr(cls, "exist_check_grad")):
                 raise AssertionError("This test of %s op needs check_grad." %
                                      cls.op_type)
+
+            # check op tests with fp64 precision; mkldnn op tests are not checked for now
+            if cls.dtype in [np.float32, np.float64] \
+                and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
+                and not hasattr(cls, 'exist_fp64_check_grad') \
+                and not is_xpu_op_test() \
+                and not is_mkldnn_op_test() \
+                and not is_rocm_op_test() \
+                and not is_npu_op_test():
+                raise AssertionError(
+                    "This test of %s op needs check_grad with fp64 precision." %
+                    cls.op_type)
+
+            if not cls.input_shape_is_large \
+                and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
+                raise AssertionError(
+                    "Input's shape should be larger than or equal to 100 for " +
+                    cls.op_type + " Op.")
 
     def try_call_once(self, data_type):
         if not self.call_once:
             self.call_once = True
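
The tightened setUpClass checks above make a test class fail fast when it neither sets a dtype nor opts out of gradient checking. The first condition can be read in isolation; a self-contained sketch that mirrors it (the white list is stubbed here, the real one lives in op_accuracy_white_list):

import numpy as np

NO_FP16_CHECK_GRAD_OP_LIST = ["some_whitelisted_op"]  # stub of the real list


def needs_check_grad_error(dtype, op_type, has_exist_check_grad):
    # Mirrors the new condition in XPUOpTest: no dtype at all is an error,
    # and fp16 needs either a white-list entry or an explicit opt-out.
    return dtype is None or (
        dtype == np.float16 and
        op_type not in NO_FP16_CHECK_GRAD_OP_LIST and
        not has_exist_check_grad)


assert needs_check_grad_error(None, "flatten2", False)             # dtype unset
assert not needs_check_grad_error(np.float32, "flatten2", False)   # fp32 passes
assert not needs_check_grad_error(np.float16, "flatten2", True)    # opted out
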
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_fill_any_like_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_fill_any_like_op_xpu.py
new file mode 100644
index 0000000000..27c101b20f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_fill_any_like_op_xpu.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+sys.path.append("..")
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid import Program, program_guard
+import paddle.compat as cpt
+import unittest
+import numpy as np
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+
+paddle.enable_static()
+
+
+class TestFillAnyLikeOp(OpTest):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.float32
+        self.use_xpu = True
+        self.use_mkldnn = False
+        self.value = 0.0
+        self.init()
+        self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
+        self.attrs = {'value': self.value, 'use_xpu': True}
+        self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
+
+    def init(self):
+        pass
+
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+
+class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
+    def init(self):
+        self.dtype = np.float32
+        self.value = 0.0
+
+
+class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
+    def init(self):
+        self.value = 1.0
+
+
+class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
+    def init(self):
+        self.value = 1e-9
+
+
+class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
+    def init(self):
+        self.dtype = np.float16
+        self.value = 0.05
+
+
+if __name__ == "__main__":
+    unittest.main()
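
One path these tests leave uncovered is the range guard in the C++ kernel: filling a float16 tensor with a value beyond its representable range should be rejected. A hand-run sketch, assuming an XPU device (the exact exception class surfaced by PADDLE_ENFORCE can vary, hence the broad except):

import paddle

paddle.set_device("xpu:0")  # assumes a Kunlun device is available
x = paddle.ones([4, 4], dtype="float16")
try:
    paddle.full_like(x, fill_value=1e30)  # well past float16's max of 65504
except Exception as e:
    print("rejected as expected:", type(e).__name__)
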
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_flatten2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_flatten2_op_xpu.py
new file mode 100644
index 0000000000..9cbc83950d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_flatten2_op_xpu.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import sys
+sys.path.append("..")
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+paddle.enable_static()
+
+
+class TestFlatten2Op(XPUOpTest):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "flatten2"
+        self.place = paddle.XPUPlace(0)
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
+        self.init_attrs()
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.new_shape),
+            "XShape": np.random.random(self.in_shape).astype("float32")
+        }
+
+    def set_xpu(self):
+        self.__class__.use_xpu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, no_check_set=["XShape"])
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 4, 5)
+        self.axis = 1
+        self.new_shape = (3, 40)
+
+    def init_attrs(self):
+        self.attrs = {"axis": self.axis}
+
+
+class TestFlatten2OpWithCornerAxis(TestFlatten2Op):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.axis = 0
+        self.new_shape = (1, 120)
+
+
+class TestFlatten2OpWithDefaultAxis(TestFlatten2Op):
+    def init_test_case(self):
+        self.in_shape = (10, 2, 2, 3)
+        self.new_shape = (10, 12)
+
+    def init_attrs(self):
+        self.attrs = {}
+
+
+class TestFlatten2OpSixDims(TestFlatten2Op):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.axis = 4
+        self.new_shape = (36, 16)
+
+
+if __name__ == "__main__":
+    unittest.main()
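
flatten2 carries an extra XShape output purely as shape metadata for the backward pass, which is why the tests above exclude it with no_check_set. It is the op behind the fluid 1.x flatten layer; a static-graph sketch of the same path, assuming an XPU build:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[3, 2, 4, 5], dtype="float32")
    out = fluid.layers.flatten(x, axis=1)  # lowers to the flatten2 op

exe = paddle.static.Executor(paddle.XPUPlace(0))
res = exe.run(main_prog,
              feed={"x": np.random.random((3, 2, 4, 5)).astype("float32")},
              fetch_list=[out])
print(res[0].shape)  # (3, 40)
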
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py
new file mode 100644
index 0000000000..dcad3c479f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+sys.path.append("..")
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+class TestFlattenOp(XPUOpTest):
+    def setUp(self):
+        self.set_xpu()
+        self.op_type = "flatten_contiguous_range"
+        self.place = paddle.XPUPlace(0)
+        self.use_xpu = True
+        self.use_mkldnn = False
+
+        self.start_axis = 0
+        self.stop_axis = -1
+        self.dtype = np.float32
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype(self.dtype)}
+        self.init_attrs()
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.new_shape),
+            "XShape": np.random.random(self.in_shape).astype("float32")
+        }
+
+    def set_xpu(self):
+        self.__class__.use_xpu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, no_check_set=["XShape"])
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = -1
+        self.new_shape = (120)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis,
+            'use_xpu': True,
+        }
+
+
+class TestFlattenOp_1(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 1
+        self.stop_axis = 2
+        self.new_shape = (3, 10, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_2(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_3(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 2
+        self.new_shape = (30, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_4(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = -2
+        self.stop_axis = -1
+        self.new_shape = (3, 2, 20)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_5(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 2
+        self.stop_axis = 2
+        self.new_shape = (3, 2, 5, 4)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOpSixDims(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.start_axis = 3
+        self.stop_axis = 5
+        self.new_shape = (3, 2, 3, 32)
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_Float32(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.float32
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+
+class TestFlattenOp_int32(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int32
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis,
+            'use_xpu': True
+        }
+
+    def test_check_grad(self):
+        pass
+
+
+class TestFlattenOp_int8(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int8
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+    def test_check_grad(self):
+        pass
+
+
+class TestFlattenOp_int64(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 5, 4)
+        self.start_axis = 0
+        self.stop_axis = 1
+        self.new_shape = (6, 5, 4)
+        self.dtype = np.int64
+
+    def init_attrs(self):
+        self.attrs = {
+            "start_axis": self.start_axis,
+            "stop_axis": self.stop_axis
+        }
+
+    def test_check_grad(self):
+        pass
+
+
+class TestFlatten2OpError(unittest.TestCase):
+    def test_errors(self):
+        image_shape = (2, 3, 4, 4)
+        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                      image_shape[3]).reshape(image_shape) / 100.
+        x = x.astype('float32')
+
+        def test_ValueError1():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            out = paddle.flatten(x_var, start_axis=2, stop_axis=1)
+
+        self.assertRaises(ValueError, test_ValueError1)
+
+        def test_ValueError2():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            paddle.flatten(x_var, start_axis=10, stop_axis=1)
+
+        self.assertRaises(ValueError, test_ValueError2)
+
+        def test_ValueError3():
+            x_var = paddle.static.data(
+                name="x", shape=image_shape, dtype='float32')
+            paddle.flatten(x_var, start_axis=2, stop_axis=10)
+
+        self.assertRaises(ValueError, test_ValueError3)
+
+        def test_type():
+            # dtype must be float32, float64, int8, int32, int64
+            x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                           image_shape[3]).reshape(image_shape) / 100.
+            x2 = x2.astype('float16')
+            x2_var = paddle.fluid.data(
+                name='x2', shape=[3, 2, 4, 5], dtype='float16')
+            paddle.flatten(x2_var)
+
+        self.assertRaises(TypeError, test_type)
+
+        def test_InputError():
+            out = paddle.flatten(x)
+
+        self.assertRaises(ValueError, test_InputError)
+
+
+class TestStaticFlattenPythonAPI(unittest.TestCase):
+    def execute_api(self, x, start_axis=0, stop_axis=-1):
+        return paddle.flatten(x, start_axis, stop_axis)
+
+    def test_static_api(self):
+        paddle.enable_static()
+        np_x = np.random.rand(2, 3, 4, 4).astype('float32')
+
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, paddle.static.Program()):
+            x = paddle.static.data(
+                name="x", shape=[2, 3, 4, 4], dtype='float32')
+            out = self.execute_api(x, start_axis=-2, stop_axis=-1)
+
+        exe = paddle.static.Executor(place=paddle.XPUPlace(0))
+        fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
+        self.assertTrue((2, 3, 16) == fetch_out[0].shape)
+
+
+class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
+    def execute_api(self, x, start_axis=0, stop_axis=-1):
+        return x.flatten_(start_axis, stop_axis)
+
+
+class TestFlattenPython(unittest.TestCase):
+    def test_python_api(self):
+        image_shape = (2, 3, 4, 4)
+        x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
+                      image_shape[3]).reshape(image_shape) / 100.
+        x = x.astype('float32')
+
+        def test_InputError():
+            out = paddle.flatten(x)
+
+        self.assertRaises(ValueError, test_InputError)
+
+        def test_Negative():
+            paddle.disable_static(paddle.XPUPlace(0))
+            img = paddle.to_tensor(x)
+            out = paddle.flatten(img, start_axis=-2, stop_axis=-1)
+            return out.numpy().shape
+
+        res_shape = test_Negative()
+        self.assertTrue((2, 3, 16) == res_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_flatten_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_flatten_op_xpu.py
new file mode 100644
index 0000000000..ed43519835
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_flatten_op_xpu.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import sys
+sys.path.append("..")
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+paddle.enable_static()
+
+
+class TestFlattenOp(XPUOpTest):
+    def setUp(self):
+        self.op_type = "flatten"
+        self.use_xpu = True
+        self.place = paddle.XPUPlace(0)
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
+        self.init_attrs()
+        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 10)
+        self.axis = 1
+        self.new_shape = (3, 40)
+
+    def init_attrs(self):
+        self.attrs = {"axis": self.axis}
+
+
+class TestFlattenOp1(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 10)
+        self.axis = 0
+        self.new_shape = (1, 120)
+
+
+class TestFlattenOpWithDefaultAxis(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (10, 2, 2, 3)
+        self.new_shape = (10, 12)
+
+    def init_attrs(self):
+        self.attrs = {}
+
+
+class TestFlattenOpSixDims(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.axis = 4
+        self.new_shape = (36, 16)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
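
Taken together, a short end-to-end smoke test of the two ops this patch enables, in the spirit of the SSD use case that motivated it (a sketch; assumes an XPU build with a Kunlun device):

import numpy as np
import paddle

paddle.set_device("xpu:0")

# A mock SSD feature map: flatten the per-location predictions, then build
# a same-shaped constant tensor; both ops now run on the Kunlun device.
feat = paddle.to_tensor(np.random.random((8, 6, 19, 19)).astype("float32"))
flat = paddle.flatten(feat, start_axis=1, stop_axis=-1)  # flatten_contiguous_range
mask = paddle.full_like(flat, fill_value=0.0)            # fill_any_like
print(flat.shape, mask.shape)  # [8, 2166] [8, 2166]
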