diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index c94ce4174f2be32beae0547f6a8366fd2896e027..778bab9f4dd26829a6fa9b76b4381a41c75a3280 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel { int64_t outer = 1, inner = 1; for (int i = 0; i < in_dims.size(); ++i) { if (i < axis) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } else { - inner *= in_dims[i]; + if (in_dims[i] == -1 || inner == -1) { + inner = -1; + } else { + inner *= in_dims[i]; + } } } std::vector out_shape(2); @@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { out_shape.push_back(in_dims[i]); } for (int i = start_axis; i <= stop_axis; i++) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } out_shape.push_back(outer); for (int i = stop_axis + 1; i < in_dims_size; i++) { diff --git a/python/paddle/fluid/tests/unittests/test_flatten2_op.py b/python/paddle/fluid/tests/unittests/test_flatten2_op.py index a3c12a5fc01c39336f0476056a606c120f35f067..42b43cc46a69b2e5334f9f68cc60ca5cf78db12a 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten2_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten2_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid as fluid +import paddle from op_test import OpTest @@ -69,6 +70,20 @@ class TestFlattenOpSixDims(TestFlattenOp): self.new_shape = (36, 16) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, axis=1): + return fluid.layers.flatten(x, axis=axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, axis=2) + self.assertTrue((-1, -1) == out.shape) + + class TestFlatten2OpError(unittest.TestCase): def test_errors(self): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index f87b732d1b2cc0622ce429acf12e87c43ac0cfbf..9093050d6d5c6c072d386b53d4099e39fc76a2ec 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -201,6 +201,20 @@ class TestStaticFlattenPythonAPI(unittest.TestCase): self.assertTrue((2, 3, 16) == fetch_out[0].shape) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, start_axis=0, stop_axis=-1): + return paddle.flatten(x, start_axis, stop_axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, start_axis=2, stop_axis=3) + self.assertTrue((-1, 3, -1) == out.shape) + + class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI): def execute_api(self, x, start_axis=0, stop_axis=-1): return x.flatten_(start_axis, stop_axis)