From 852c78b3501964f190c460a36f3f6f123aba996d Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Fri, 3 Sep 2021 14:22:03 +0800 Subject: [PATCH] [cherry-pick] fix flatten infershape (#35398) * fix flatten infershape; test=develop * fix flatten infershape; test=develop --- paddle/fluid/operators/flatten_op.cc | 18 +++++++++++++++--- .../fluid/tests/unittests/test_flatten2_op.py | 15 +++++++++++++++ .../test_flatten_contiguous_range_op.py | 14 ++++++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index c94ce4174f2..778bab9f4dd 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel { int64_t outer = 1, inner = 1; for (int i = 0; i < in_dims.size(); ++i) { if (i < axis) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } else { - inner *= in_dims[i]; + if (in_dims[i] == -1 || inner == -1) { + inner = -1; + } else { + inner *= in_dims[i]; + } } } std::vector out_shape(2); @@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { out_shape.push_back(in_dims[i]); } for (int i = start_axis; i <= stop_axis; i++) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } out_shape.push_back(outer); for (int i = stop_axis + 1; i < in_dims_size; i++) { diff --git a/python/paddle/fluid/tests/unittests/test_flatten2_op.py b/python/paddle/fluid/tests/unittests/test_flatten2_op.py index 0d50c65558a..af45ebe235a 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten2_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten2_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid as fluid +import paddle from op_test import OpTest @@ -69,6 +70,20 @@ class TestFlattenOpSixDims(TestFlattenOp): self.new_shape = (36, 16) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, axis=1): + return fluid.layers.flatten(x, axis=axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, axis=2) + self.assertTrue((-1, -1) == out.shape) + + class TestFlatten2OpError(unittest.TestCase): def test_errors(self): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index bc9ff369771..ccf340a6943 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -201,6 +201,20 @@ class TestStaticFlattenPythonAPI(unittest.TestCase): self.assertTrue((2, 3, 16) == fetch_out[0].shape) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, start_axis=0, stop_axis=-1): + return paddle.flatten(x, start_axis, stop_axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, start_axis=2, stop_axis=3) + self.assertTrue((-1, 3, -1) == out.shape) + + class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI): def execute_api(self, x, start_axis=0, stop_axis=-1): return x.flatten_(start_axis, stop_axis) -- GitLab