提交 d2fda532 编写于 作者: G gongweibao

add expand comment

上级 48556ba3
...@@ -23,12 +23,18 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -23,12 +23,18 @@ class BlockExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("block"), using namespace framework;
"Input(block) of BlockExpandOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("input"),
PADDLE_ENFORCE(ctx->HasInput("padding"), "Input of BlockExpandOp should not be null.");
"Input(padding) of BlockExpandOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"),
PADDLE_ENFORCE(ctx->HasInput("stride"), "Output(Out) of BlockExpandOp op should not be null.");
"Input(stride) of BlockExpandOp should not be null.");
auto in_dim = ctx->GetInputDim("input");
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW.");
PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");
ctx->ShareLoD("X", /*->*/ "Out");
// ctx->SetOutputDim("Out", {1}); // ctx->SetOutputDim("Out", {1});
} }
}; };
...@@ -38,8 +44,26 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -38,8 +44,26 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
BlockExpandOpMaker(framework::OpProto* proto, BlockExpandOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("block", "The input of block_expand op"); AddInput("input", "The input of block_expand op");
AddOutput("stride", "The output of block_expand op"); AddOutput("out", "The output of block_expand op");
AddAttr<int>("block_height",
R"DOC(
)DOC");
AddAttr<int>("block_width",
R"DOC(
)DOC");
AddAttr<int>("stride_height",
R"DOC(
)DOC");
AddAttr<int>("stride_width",
R"DOC(
)DOC");
AddAttr<int>("padding_height",
R"DOC(
)DOC");
AddAttr<int>("padding_width",
R"DOC(
)DOC");
AddComment(R"DOC( AddComment(R"DOC(
Expand feature map to minibatch matrix. Expand feature map to minibatch matrix.
- matrix width is: blockH_ * blockW_ * channels_ - matrix width is: blockH_ * blockW_ * channels_
......
...@@ -25,34 +25,34 @@ namespace operators { ...@@ -25,34 +25,34 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class BlockExpandKernel : public framework::OpKernel<T> { class BlockExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework; using namespace framework;
const Tensor* input = context.Input<Tensor>("input"); const Tensor* in = ctx.Input<Tensor>("input");
const Tensor* filter = context.Input<Tensor>("filter"); Tensor* out = ctx.Output<Tensor>("Out");
const Tensor* stride = context.Input<Tensor>("stride"); out->mutable_data<T>(ctx.GetPlace());
const Tensor* padding = context.Input<Tensor>("padding");
Tensor* out = context.Output<Tensor>("Out"); auto in_dim = in->dims();
int N = in_dim[0];
auto input_dim = input->dims(); int C = in_dim[1];
size_t N = input_dim[0];
size_t C = input_dim[1]; int in_height = in_dim[2];
PADDLE_ENFORCE_GE(N, 1, "Input batchsize must >= 1."); int in_width = in_dim[3];
PADDLE_ENFORCE_EQ(input_dim.size(), 4, "Input format must be NCHW.");
int block_height = ctx.Attr<int>("block_height");
size_t input_height = input_dim[2]; int block_width = ctx.Attr<int>("block_width");
size_t input_height = input_dim[3]; int stride_height = ctx.Attr<int>("stride_height");
int stride_width = ctx.Attr<int>("stride_width");
size_t filter_height = filter[0]; int padding_height = ctx.Attr<int>("padding_height");
size_t filter_width = filter[1]; int padding_width = ctx.Attr<int>("padding_width");
size_t output_height = 1 + int output_height =
(input_height + 2 * padding_height - block_height() + 1 +
stride_height - 1) / (in_height + 2 * padding_height - block_height + stride_height - 1) /
stride_height; stride_height;
size_t output_width = int output_width =
1 + 1 +
(input_width + 2 * padding_width - block_width() + stride_width - 1) / (in_width + 2 * padding_width - block_width + stride_width - 1) /
stride_width; stride_width;
Tensor col; Tensor col;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册