From d2fda53217bf7c5370446f9a404b711ace9df130 Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Thu, 12 Oct 2017 09:34:28 +0000
Subject: [PATCH] add expand comment

---
 paddle/operators/block_expand_op.cc | 40 +++++++++++++++++-----
 paddle/operators/block_expand_op.h  | 52 ++++++++++++++---------------
 2 files changed, 58 insertions(+), 34 deletions(-)

diff --git a/paddle/operators/block_expand_op.cc b/paddle/operators/block_expand_op.cc
index 0b36dc1ae54..69c5e02a658 100644
--- a/paddle/operators/block_expand_op.cc
+++ b/paddle/operators/block_expand_op.cc
@@ -23,12 +23,18 @@ class BlockExpandOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("block"),
-                   "Input(block) of BlockExpandOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("padding"),
-                   "Input(padding) of BlockExpandOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("stride"),
-                   "Input(stride) of BlockExpandOp should not be null.");
+    using namespace framework;
+    PADDLE_ENFORCE(ctx->HasInput("input"),
+                   "Input of BlockExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of BlockExpandOp should not be null.");
+
+    auto in_dim = ctx->GetInputDim("input");
+    PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW.");
+    PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must be >= 1.");
+
+    ctx->ShareLoD("input", /*->*/ "Out");
+    // ctx->SetOutputDim("Out", {1});
   }
 };
 
@@ -38,8 +44,26 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   BlockExpandOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("block", "The input of block_expand op");
-    AddOutput("stride", "The output of block_expand op");
+    AddInput("input", "The input of block_expand op");
+    AddOutput("Out", "The output of block_expand op");
+    AddAttr<int>("block_height",
+                 R"DOC(Height of each expanded block.
+                 )DOC");
+    AddAttr<int>("block_width",
+                 R"DOC(Width of each expanded block.
+                 )DOC");
+    AddAttr<int>("stride_height",
+                 R"DOC(Vertical stride between adjacent blocks.
+                 )DOC");
+    AddAttr<int>("stride_width",
+                 R"DOC(Horizontal stride between adjacent blocks.
+                 )DOC");
+    AddAttr<int>("padding_height",
+                 R"DOC(Zero padding added to the top and bottom of the input.
+                 )DOC");
+    AddAttr<int>("padding_width",
+                 R"DOC(Zero padding added to the left and right of the input.
+                 )DOC");
     AddComment(R"DOC(
 Expand feature map to minibatch matrix.
 - matrix width is: blockH_ * blockW_ * channels_
diff --git a/paddle/operators/block_expand_op.h b/paddle/operators/block_expand_op.h
index 54a9c5354f1..c0521dbbadb 100644
--- a/paddle/operators/block_expand_op.h
+++ b/paddle/operators/block_expand_op.h
@@ -25,34 +25,34 @@ namespace operators {
 template <typename Place, typename T>
 class BlockExpandKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext& ctx) const override {
     using namespace framework;
-    const Tensor* input = context.Input<Tensor>("input");
-    const Tensor* filter = context.Input<Tensor>("filter");
-    const Tensor* stride = context.Input<Tensor>("stride");
-    const Tensor* padding = context.Input<Tensor>("padding");
-    Tensor* out = context.Output<Tensor>("Out");
-
-    auto input_dim = input->dims();
-    size_t N = input_dim[0];
-    size_t C = input_dim[1];
-    PADDLE_ENFORCE_GE(N, 1, "Input batchsize must >= 1.");
-    PADDLE_ENFORCE_EQ(input_dim.size(), 4, "Input format must be NCHW.");
-
-    size_t input_height = input_dim[2];
-    size_t input_height = input_dim[3];
-
-    size_t filter_height = filter[0];
-    size_t filter_width = filter[1];
-
-    size_t output_height = 1 +
-        (input_height + 2 * padding_height - block_height() +
-         stride_height - 1) /
-        stride_height;
-
-    size_t output_width =
+    const Tensor* in = ctx.Input<Tensor>("input");
+    Tensor* out = ctx.Output<Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto in_dim = in->dims();
+    int N = in_dim[0];
+    int C = in_dim[1];
+
+    int in_height = in_dim[2];
+    int in_width = in_dim[3];
+
+    int block_height = ctx.Attr<int>("block_height");
+    int block_width = ctx.Attr<int>("block_width");
+    int stride_height = ctx.Attr<int>("stride_height");
+    int stride_width = ctx.Attr<int>("stride_width");
+    int padding_height = ctx.Attr<int>("padding_height");
+    int padding_width = ctx.Attr<int>("padding_width");
+
+    int output_height =
+        1 +
+        (in_height + 2 * padding_height - block_height + stride_height - 1) /
+            stride_height;
+
+    int output_width =
         1 +
-        (input_width + 2 * padding_width - block_width() + stride_width - 1) /
+        (in_width + 2 * padding_width - block_width + stride_width - 1) /
             stride_width;
 
     Tensor col;
-- 
GitLab
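For reference, the output-size arithmetic in BlockExpandKernel::Compute can be checked in isolation. The standalone C++ sketch below is not part of the patch; the helper name ExpandedSize and the sample sizes are made up for illustration, and it simply reproduces the rounding-up formula output = 1 + (in + 2 * padding - block + stride - 1) / stride used above.

#include <cstdio>

// Rounding-up (ceil) division form of the block-expand output size:
// 1 + ceil((in + 2 * padding - block) / stride).
static int ExpandedSize(int in, int block, int stride, int padding) {
  return 1 + (in + 2 * padding - block + stride - 1) / stride;
}

int main() {
  // Made-up sample sizes: a 10x10 feature map, 3x3 blocks, stride 1, no padding.
  const int in_height = 10, in_width = 10;
  const int block_height = 3, block_width = 3;
  const int stride_height = 1, stride_width = 1;
  const int padding_height = 0, padding_width = 0;

  const int output_height =
      ExpandedSize(in_height, block_height, stride_height, padding_height);
  const int output_width =
      ExpandedSize(in_width, block_width, stride_width, padding_width);

  // Each (row, col) output position corresponds to one extracted block, so the
  // expanded matrix described in the op comment has output_height * output_width
  // rows per sample and block_height * block_width * channels columns.
  std::printf("output blocks: %d x %d\n", output_height, output_width);  // 8 x 8
  return 0;
}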