未验证 提交 ccd42db7 编写于 作者: D danleifeng 提交者: GitHub

fix flatten infershape (#35321)

上级 a6cc567f
......@@ -55,11 +55,19 @@ class FlattenOp : public framework::OperatorWithKernel {
int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) {
if (in_dims[i] == -1 || outer == -1) {
outer = -1;
} else {
outer *= in_dims[i];
}
} else {
if (in_dims[i] == -1 || inner == -1) {
inner = -1;
} else {
inner *= in_dims[i];
}
}
}
std::vector<int32_t> out_shape(2);
out_shape[0] = outer;
out_shape[1] = inner;
......@@ -296,8 +304,12 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
out_shape.push_back(in_dims[i]);
}
for (int i = start_axis; i <= stop_axis; i++) {
if (in_dims[i] == -1 || outer == -1) {
outer = -1;
} else {
outer *= in_dims[i];
}
}
out_shape.push_back(outer);
for (int i = stop_axis + 1; i < in_dims_size; i++) {
out_shape.push_back(in_dims[i]);
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle
from op_test import OpTest
......@@ -69,6 +70,20 @@ class TestFlattenOpSixDims(TestFlattenOp):
self.new_shape = (36, 16)
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
def execute_api(self, x, axis=1):
return fluid.layers.flatten(x, axis=axis)
def test_static_api(self):
paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[-1, 3, -1, -1], dtype='float32')
out = self.execute_api(x, axis=2)
self.assertTrue((-1, -1) == out.shape)
class TestFlatten2OpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
......
......@@ -201,6 +201,20 @@ class TestStaticFlattenPythonAPI(unittest.TestCase):
self.assertTrue((2, 3, 16) == fetch_out[0].shape)
class TestStaticFlattenInferShapePythonAPI(unittest.TestCase):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return paddle.flatten(x, start_axis, stop_axis)
def test_static_api(self):
paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[-1, 3, -1, -1], dtype='float32')
out = self.execute_api(x, start_axis=2, stop_axis=3)
self.assertTrue((-1, 3, -1) == out.shape)
class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return x.flatten_(start_axis, stop_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册