未验证 提交 852c78b3 编写于 作者: D danleifeng 提交者: GitHub

[cherry-pick] fix flatten infershape (#35398)

* fix flatten infershape; test=develop

* fix flatten infershape; test=develop
上级 e04b66f2
...@@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel { ...@@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel {
int64_t outer = 1, inner = 1; int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) { for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) { if (i < axis) {
outer *= in_dims[i]; if (in_dims[i] == -1 || outer == -1) {
outer = -1;
} else {
outer *= in_dims[i];
}
} else { } else {
inner *= in_dims[i]; if (in_dims[i] == -1 || inner == -1) {
inner = -1;
} else {
inner *= in_dims[i];
}
} }
} }
std::vector<int32_t> out_shape(2); std::vector<int32_t> out_shape(2);
...@@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { ...@@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
out_shape.push_back(in_dims[i]); out_shape.push_back(in_dims[i]);
} }
for (int i = start_axis; i <= stop_axis; i++) { for (int i = start_axis; i <= stop_axis; i++) {
outer *= in_dims[i]; if (in_dims[i] == -1 || outer == -1) {
outer = -1;
} else {
outer *= in_dims[i];
}
} }
out_shape.push_back(outer); out_shape.push_back(outer);
for (int i = stop_axis + 1; i < in_dims_size; i++) { for (int i = stop_axis + 1; i < in_dims_size; i++) {
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
from op_test import OpTest from op_test import OpTest
...@@ -69,6 +70,20 @@ class TestFlattenOpSixDims(TestFlattenOp): ...@@ -69,6 +70,20 @@ class TestFlattenOpSixDims(TestFlattenOp):
self.new_shape = (36, 16) self.new_shape = (36, 16)
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
def execute_api(self, x, axis=1):
return fluid.layers.flatten(x, axis=axis)
def test_static_api(self):
paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[-1, 3, -1, -1], dtype='float32')
out = self.execute_api(x, axis=2)
self.assertTrue((-1, -1) == out.shape)
class TestFlatten2OpError(unittest.TestCase): class TestFlatten2OpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
......
...@@ -201,6 +201,20 @@ class TestStaticFlattenPythonAPI(unittest.TestCase): ...@@ -201,6 +201,20 @@ class TestStaticFlattenPythonAPI(unittest.TestCase):
self.assertTrue((2, 3, 16) == fetch_out[0].shape) self.assertTrue((2, 3, 16) == fetch_out[0].shape)
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return paddle.flatten(x, start_axis, stop_axis)
def test_static_api(self):
paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[-1, 3, -1, -1], dtype='float32')
out = self.execute_api(x, start_axis=2, stop_axis=3)
self.assertTrue((-1, 3, -1) == out.shape)
class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI): class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
def execute_api(self, x, start_axis=0, stop_axis=-1): def execute_api(self, x, start_axis=0, stop_axis=-1):
return x.flatten_(start_axis, stop_axis) return x.flatten_(start_axis, stop_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册