提交 f2ccc11f 编写于 作者: C chengduoZH

fix pool doc (pool_op.cc)

上级 c2c2d610
...@@ -32,10 +32,7 @@ class PoolOp : public framework::OperatorWithKernel { ...@@ -32,10 +32,7 @@ class PoolOp : public framework::OperatorWithKernel {
"X(Input) of Pooling should not be null."); "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Out(Output) of Pooling should not be null."); "Out(Output) of Pooling should not be null.");
// PADDLE_ENFORCE_NOT_NULL(Attr<std::string>("poolingType"),
// "pooling_type should not be null.");
// PADDLE_ENFORCE_NOT_NULL(Attr<std::vector<int>>("ksize"), "ksize should
// not be null.");
auto in_X = ctx.Input<Tensor>("X"); auto in_X = ctx.Input<Tensor>("X");
auto out = ctx.Output<Tensor>("Out"); auto out = ctx.Output<Tensor>("Out");
int global_pooling = Attr<int>("globalPooling"); int global_pooling = Attr<int>("globalPooling");
...@@ -56,11 +53,15 @@ class PoolOp : public framework::OperatorWithKernel { ...@@ -56,11 +53,15 @@ class PoolOp : public framework::OperatorWithKernel {
} }
if (ksize.size() == 2) { if (ksize.size() == 2) {
PADDLE_ENFORCE_EQ(strides.size(), 2, "Pool2DOp strides should be 2-D."); PADDLE_ENFORCE_EQ(strides.size(), 2,
PADDLE_ENFORCE_EQ(paddings.size(), 2, "Pool2DOp paddings should be 2-D."); "Pool2DOp strides size should be 2 elements.");
PADDLE_ENFORCE_EQ(paddings.size(), 2,
"Pool2DOp paddings size should be 2 elements");
} else { } else {
PADDLE_ENFORCE_EQ(strides.size(), 3, "Pool3DOp strides should be 3-D."); PADDLE_ENFORCE_EQ(strides.size(), 3,
PADDLE_ENFORCE_EQ(paddings.size(), 3, "Pool3DOp paddings should be 3-D."); "Pool3DOp strides should be 3 elements.");
PADDLE_ENFORCE_EQ(paddings.size(), 3,
"Pool3DOp paddings should be 3 elements.");
} }
std::vector<int64_t> output_shape({in_X->dims()[0], in_X->dims()[1]}); std::vector<int64_t> output_shape({in_X->dims()[0], in_X->dims()[1]});
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
...@@ -83,76 +84,84 @@ class PoolOpGrad : public framework::OperatorWithKernel { ...@@ -83,76 +84,84 @@ class PoolOpGrad : public framework::OperatorWithKernel {
} }
}; };
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"The input tensor of pooling operator. " "The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is the " "The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, D, H and W is the depth, height and width of " "number of channels, H and W is the height and width of image.");
"image.");
AddOutput("Out", AddOutput("Out",
"The output tensor of pooling operator." "The output tensor of pooling operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCHW.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"poolingType of pooling operator.['max' or 'ave']"); "poolingType of pooling operator."
"str constant equal to 'max' or 'ave'");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator."); "ksize", "pooling size(height, width) of pooling operator.");
AddAttr<int>("globalPooling", AddAttr<int>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to 0 or 1"
"default 0" "default 0"
"whether to use the globalPooling.") "If globalPooling = 1, ksize is ignored and need not be specified.")
.SetDefault(0); .SetDefault(0);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("strides",
"strides", "strides(height, width) of pooling operator."
"default {1,1,1}" "default {1,1}")
"strides(depth, height, width) of pooling operator.") .SetDefault({1, 1});
.SetDefault({1, 1, 1}); AddAttr<std::vector<int>>("paddings",
AddAttr<std::vector<int>>( "paddings(height, width) of pooling operator."
"paddings", "default {0,0}")
"default {0,0,0}" .SetDefault({0, 0});
"paddings(depth, height, width) of pooling operator.")
.SetDefault({0, 0, 0});
AddComment(R"DOC( AddComment(R"DOC(
The pooling3d operation calculates the output based on The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
}; };
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput("X",
"X",
"The input tensor of pooling operator. " "The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the " "The format of input tensor is NCDHW. Where N is batch size, C is "
"number of channels, H and W is the height and width of image."); "the "
"number of channels, D, H and W is the depth, height and width of "
"image.");
AddOutput("Out", AddOutput("Out",
"The output tensor of pooling operator." "The output tensor of pooling operator."
"The format of output tensor is also NCHW."); "The format of output tensor is also NCDHW.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"poolingType of pooling operator.['max' or 'ave']"); "poolingType of pooling operator."
"str constant equal to 'max' or 'ave'");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator."); "ksize", "pooling size(depth, height, width) of pooling operator.");
AddAttr<int>("globalPooling", AddAttr<int>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to 0 or 1"
"default 0" "default 0"
"whether to use the globalPooling.[0 or 1]") "If globalPooling = 1, ksize is ignored and need not be specified.")
.SetDefault(0); .SetDefault(0);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>(
"default {1, 1}" "strides",
"strides(height, width) of pooling operator.") "strides(depth, height, width) of pooling operator."
.SetDefault({1, 1}); "default {1,1,1}")
AddAttr<std::vector<int>>("paddings", .SetDefault({1, 1, 1});
"default {0, 0}" AddAttr<std::vector<int>>(
"paddings(height, width) of pooling operator.") "paddings",
.SetDefault({0, 0}); "paddings(depth, height, width) of pooling operator."
"default {0,0,0}")
.SetDefault({0, 0, 0});
AddComment(R"DOC( AddComment(R"DOC(
The pooling2d operation calculates the output based on The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册