From 4b06d8db9179c74a35582c85f782e8c268d361a6 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 26 Sep 2017 20:06:12 +0800 Subject: [PATCH] fix globalPooling type (int => bool) --- paddle/operators/pool_op.cc | 47 ++++++++++++++++++++----------------- paddle/operators/pool_op.h | 8 +++---- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index a5e731cc668..9959b3ec07c 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -35,7 +35,7 @@ class PoolOp : public framework::OperatorWithKernel { auto in_x = ctx.Input("X"); auto out = ctx.Output("Out"); - int global_pooling = Attr("globalPooling"); + bool global_pooling = Attr("globalPooling"); std::string pooling_type = Attr("poolingType"); std::vector ksize = Attr>("ksize"); std::vector strides = Attr>("strides"); @@ -45,6 +45,15 @@ class PoolOp : public framework::OperatorWithKernel { "pooling_type should be 'max' or 'avg'"); PADDLE_ENFORCE(in_x->dims().size() == 4 || in_x->dims().size() == 5, "Pooling intput should be 4-D or 5-D"); + + if (global_pooling) { + ksize.resize(static_cast(in_x->dims().size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + + PADDLE_ENFORCE(in_x->dims().size() == static_cast(ksize.size() + 2), + "Input size and Pooling size should be consistent."); PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, "Pooling size should be 2 elements. or 3 elements."); PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), @@ -52,12 +61,6 @@ class PoolOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), "paddings size and pooling size should be the same."); - if (global_pooling == 1) { - ksize.resize(static_cast(in_x->dims().size()) - 2); - for (size_t i = 0; i < ksize.size(); ++i) - ksize[i] = static_cast(in_x->dims()[i + 2]); - } - std::vector output_shape({in_x->dims()[0], in_x->dims()[1]}); for (size_t i = 0; i < ksize.size(); ++i) { output_shape.push_back(OutputSizePool(in_x->dims()[i + 2], ksize[i], @@ -103,15 +106,16 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(height, width) of pooling operator.") - .AddCustomChecker(GreaterThanChecker_pool({0, 0})); - AddAttr( + "ksize", + "Pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be specified."); + AddAttr( "globalPooling", "whether to use the globalPooling." - "int constant equal to 0 or 1" - "default 0" - "If globalPooling = 1, ksize is ignored and need not be specified.") - .SetDefault(0); + "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); AddAttr>("strides", "strides(height, width) of pooling operator." "default {1,1}") @@ -177,15 +181,16 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(depth, height, width) of pooling operator.") - .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0})); - AddAttr( + "ksize", + "pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be specified."); + AddAttr( "globalPooling", "whether to use the globalPooling." - "int constant equal to 0 or 1" - "default 0" - "If globalPooling = 1, ksize is ignored and need not be specified.") - .SetDefault(0); + "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); AddAttr>( "strides", "strides(depth, height, width) of pooling operator." diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index 94712822050..73c97216249 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -31,12 +31,12 @@ class PoolKernel : public framework::OpKernel { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); - int global_pooling = context.Attr("globalPooling"); + bool global_pooling = context.Attr("globalPooling"); std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (global_pooling == 1) { + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -92,13 +92,13 @@ class PoolGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); - int global_pooling = context.Attr("globalPooling"); + bool global_pooling = context.Attr("globalPooling"); std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (global_pooling == 1) { + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x->dims()[i + 2]); } -- GitLab