提交 869cef6d 编写于 作者: L liym27 提交者: Aurelius84

fix bug of infer shape in pool op. test=develop (#20213)

上级 acb02fd6
...@@ -93,7 +93,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,7 +93,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
} else { } else {
for (size_t i = 0; i < data_dims.size(); ++i) { for (size_t i = 0; i < data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) && (data_dims[i] < 0)) { if ((!ctx->IsRuntime()) && (data_dims[i] < 0)) {
output_shape.push_back(in_x_dims[i]); output_shape.push_back(data_dims[i]);
} else { } else {
output_shape.push_back( output_shape.push_back(
PoolOutputSize(data_dims[i], ksize[i], paddings[2 * i], PoolOutputSize(data_dims[i], ksize[i], paddings[2 * i],
...@@ -118,8 +118,6 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -118,8 +118,6 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
// std::string data_format = ctx.Attr<std::string>("data_format"); // change:
// delete
std::string data_format = "AnyLayout"; std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
...@@ -150,8 +148,6 @@ void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -150,8 +148,6 @@ void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
// std::string data_format = ctx.Attr<std::string>("data_format"); //
// change:delete
std::string data_format = "AnyLayout"; std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
......
...@@ -968,6 +968,18 @@ class TestPool2dAPI(OpTest): ...@@ -968,6 +968,18 @@ class TestPool2dAPI(OpTest):
append_batch_size=False, append_batch_size=False,
dtype="float32") dtype="float32")
input_NHWC_negetive = fluid.layers.data(
name="input_NHWC_negetive",
shape=[2, -1, 5, 3],
append_batch_size=False,
dtype="float32")
input_NCHW_negetive = fluid.layers.data(
name="input_NCHW_negetive",
shape=[2, 3, -1, -1],
append_batch_size=False,
dtype="float32")
ksize = [3, 3] ksize = [3, 3]
out_1 = fluid.layers.pool2d( out_1 = fluid.layers.pool2d(
input=input_NHWC, input=input_NHWC,
...@@ -1034,11 +1046,34 @@ class TestPool2dAPI(OpTest): ...@@ -1034,11 +1046,34 @@ class TestPool2dAPI(OpTest):
use_cudnn=False, use_cudnn=False,
data_format="NHWC") data_format="NHWC")
# test negetive
out_9 = fluid.layers.pool2d(
input=input_NHWC_negetive,
pool_size=ksize,
pool_type="avg",
pool_padding=[0, 0],
use_cudnn=False,
data_format="NHWC")
assert out_9.shape == (2, -1, 3, 3)
out_10 = fluid.layers.pool2d(
input=input_NCHW_negetive,
pool_size=ksize,
pool_type="avg",
pool_padding=[0, 0],
use_cudnn=False,
data_format="NCHW")
assert out_10.shape == (2, 3, -1, -1)
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
[res_1, res_2, res_3, res_4, res_5, res_6, res_7, res_8] = exe.run( [res_1, res_2, res_3, res_4, res_5, res_6, res_7, res_8] = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={"input_NHWC": x_NHWC, feed={
"input_NCHW": x_NCHW}, "input_NHWC": x_NHWC,
"input_NCHW": x_NCHW,
"input_NHWC_negetive": x_NHWC,
"input_NCHW_negetive": x_NCHW
},
fetch_list=[ fetch_list=[
out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8 out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8
]) ])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册