提交 6db476ed 编写于 作者: C chengduoZH

Separate the declarations and implementation of the PoolOp and PoolMaker class...

Separate the declarations and implementation of the PoolOp and PoolMaker class in order to reuse in pool_cudnn
上级 67edd04a
......@@ -22,108 +22,94 @@ int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
return output_size;
}
class PoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::string pooling_type = ctx->Attrs().Get<std::string>("poolingType");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::string pooling_type = ctx->Attrs().Get<std::string>("poolingType");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
};
class PoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings",
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings",
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
Input(X) and output(Out) are in NCHW format. Where N is batch size, C is the
......@@ -131,58 +117,55 @@ number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
)DOC");
}
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of feature.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
}
Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of feature.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
Input(X) and output(Out) are in NCDHW format. Where N is batch
......@@ -190,8 +173,7 @@ size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively.
)DOC");
}
};
}
} // namespace operators
} // namespace paddle
......
......@@ -24,6 +24,34 @@ namespace operators {
using Tensor = framework::Tensor;
class PoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class PoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
template <typename Place, typename T>
class PoolKernel : public framework::OpKernel<T> {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册