提交 6db476ed 编写于 作者: C chengduoZH

Separate the declarations and implementation of the PoolOp and PoolMaker class...

Separate the declarations and implementation of the PoolOp and PoolMaker class in order to reuse in pool_cudnn
上级 67edd04a
...@@ -22,14 +22,8 @@ int OutputSizePool(int input_size, int filter_size, int padding, int stride) { ...@@ -22,14 +22,8 @@ int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
return output_size; return output_size;
} }
class PoolOp : public framework::OperatorWithKernel { void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
public: PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null."); "Out(Output) of Pooling should not be null.");
...@@ -62,25 +56,17 @@ class PoolOp : public framework::OperatorWithKernel { ...@@ -62,25 +56,17 @@ class PoolOp : public framework::OperatorWithKernel {
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
} }
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
} }
};
class PoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null."); "Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
public: framework::OpAttrChecker *op_checker)
Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
...@@ -131,12 +117,10 @@ number of channels, H and W is the height and width of feature. ...@@ -131,12 +117,10 @@ number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements. Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively. These two elements represent height and width, respectively.
)DOC"); )DOC");
} }
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
public: framework::OpAttrChecker *op_checker)
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
...@@ -169,8 +153,7 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -169,8 +153,7 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"Default false." "Default false."
"If globalPooling = true, ksize is ignored and need not be specified.") "If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("strides",
"strides",
"Strides(depth, height, width) of pooling operator." "Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.") "Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
...@@ -190,8 +173,7 @@ size, C is the number of channels, D, H and W is the depth, height and ...@@ -190,8 +173,7 @@ size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements. width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively. These three elements represent depth, height and width, respectively.
)DOC"); )DOC");
} }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -24,6 +24,34 @@ namespace operators { ...@@ -24,6 +24,34 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class PoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class PoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
template <typename Place, typename T> template <typename Place, typename T>
class PoolKernel : public framework::OpKernel<T> { class PoolKernel : public framework::OpKernel<T> {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册