Commit d2fda532 authored by gongweibao

add expand comment

Parent 48556ba3
@@ -23,12 +23,18 @@ class BlockExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("block"),
"Input(block) of BlockExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("padding"),
"Input(padding) of BlockExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("stride"),
"Input(stride) of BlockExpandOp should not be null.");
using namespace framework;
PADDLE_ENFORCE(ctx->HasInput("input"),
"Input of BlockExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of BlockExpandOp op should not be null.");
auto in_dim = ctx->GetInputDim("input");
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW.");
PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");
ctx->ShareLoD("X", /*->*/ "Out");
// ctx->SetOutputDim("Out", {1});
}
};
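The new InferShape validates the NCHW input and shares LoD, but the output dimension is still left as a commented-out placeholder. A minimal sketch, assuming a standalone helper (the name ExpandedSize is illustrative and not part of this commit), of how the expanded extent along one axis could be derived from the attributes, using the same rounding the kernel below applies:

// Illustrative helper, not in this commit: output extent along one axis,
// computed with the same formula as the kernel's output_height/output_width.
inline int ExpandedSize(int in_size, int block, int stride, int padding) {
  return 1 + (in_size + 2 * padding - block + stride - 1) / stride;
}
// Example: in_height = 5, block_height = 2, stride_height = 1,
// padding_height = 0 gives 1 + (5 + 0 - 2 + 1 - 1) / 1 = 4 block rows.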
@@ -38,8 +44,26 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
BlockExpandOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("block", "The input of block_expand op");
AddOutput("stride", "The output of block_expand op");
AddInput("input", "The input of block_expand op");
AddOutput("out", "The output of block_expand op");
AddAttr<int>("block_height",
R"DOC(
)DOC");
AddAttr<int>("block_width",
R"DOC(
)DOC");
AddAttr<int>("stride_height",
R"DOC(
)DOC");
AddAttr<int>("stride_width",
R"DOC(
)DOC");
AddAttr<int>("padding_height",
R"DOC(
)DOC");
AddAttr<int>("padding_width",
R"DOC(
)DOC");
AddComment(R"DOC(
Expand feature map to minibatch matrix.
- matrix width is: blockH_ * blockW_ * channels_
......
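As a concrete reading of the comment above: with channels_ = 3 and a 2 x 2 block (blockH_ = blockW_ = 2), each row of the expanded minibatch matrix has width 2 * 2 * 3 = 12. The numbers are illustrative only, chosen to make the blockH_ * blockW_ * channels_ formula concrete.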
@@ -25,34 +25,34 @@ namespace operators {
template <typename Place, typename T>
class BlockExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework;
const Tensor* input = context.Input<Tensor>("input");
const Tensor* filter = context.Input<Tensor>("filter");
const Tensor* stride = context.Input<Tensor>("stride");
const Tensor* padding = context.Input<Tensor>("padding");
Tensor* out = context.Output<Tensor>("Out");
auto input_dim = input->dims();
size_t N = input_dim[0];
size_t C = input_dim[1];
PADDLE_ENFORCE_GE(N, 1, "Input batchsize must >= 1.");
PADDLE_ENFORCE_EQ(input_dim.size(), 4, "Input format must be NCHW.");
size_t input_height = input_dim[2];
size_t input_height = input_dim[3];
size_t filter_height = filter[0];
size_t filter_width = filter[1];
size_t output_height = 1 +
(input_height + 2 * padding_height - block_height() +
stride_height - 1) /
stride_height;
size_t output_width =
const Tensor* in = ctx.Input<Tensor>("input");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto in_dim = in->dims();
int N = in_dim[0];
int C = in_dim[1];
int in_height = in_dim[2];
int in_width = in_dim[3];
int block_height = ctx.Attr<int>("block_height");
int block_width = ctx.Attr<int>("block_width");
int stride_height = ctx.Attr<int>("stride_height");
int stride_width = ctx.Attr<int>("stride_width");
int padding_height = ctx.Attr<int>("padding_height");
int padding_width = ctx.Attr<int>("padding_width");
int output_height =
1 +
(in_height + 2 * padding_height - block_height + stride_height - 1) /
stride_height;
int output_width =
1 +
(input_width + 2 * padding_width - block_width() + stride_width - 1) /
(in_width + 2 * padding_width - block_width + stride_width - 1) /
stride_width;
Tensor col;
......
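The col tensor at the end of the hunk is where the real expansion will be assembled; for reference, below is a minimal, framework-free sketch of the computation the kernel is building toward. The function name, the flat NCHW std::vector layout, and the row layout of one row per (n, out_y, out_x) block with C * block_height * block_width values per row are assumptions made for illustration; only the output-size formula is taken directly from the kernel above.

#include <vector>

// Standalone sketch, not the Paddle kernel: expand NCHW input into a matrix
// with one row per sampled block and C * block_h * block_w values per row.
std::vector<float> BlockExpandSketch(const std::vector<float>& in, int N,
                                     int C, int H, int W, int block_h,
                                     int block_w, int stride_h, int stride_w,
                                     int pad_h, int pad_w) {
  // Same output-size formula as the kernel above.
  int out_h = 1 + (H + 2 * pad_h - block_h + stride_h - 1) / stride_h;
  int out_w = 1 + (W + 2 * pad_w - block_w + stride_w - 1) / stride_w;
  int row_width = C * block_h * block_w;
  std::vector<float> out(
      static_cast<size_t>(N) * out_h * out_w * row_width, 0.f);
  for (int n = 0; n < N; ++n) {
    for (int oy = 0; oy < out_h; ++oy) {
      for (int ox = 0; ox < out_w; ++ox) {
        int row = (n * out_h + oy) * out_w + ox;
        for (int c = 0; c < C; ++c) {
          for (int ky = 0; ky < block_h; ++ky) {
            for (int kx = 0; kx < block_w; ++kx) {
              int y = oy * stride_h - pad_h + ky;  // input row, may fall in padding
              int x = ox * stride_w - pad_w + kx;  // input col, may fall in padding
              int col = (c * block_h + ky) * block_w + kx;
              if (y >= 0 && y < H && x >= 0 && x < W) {
                out[static_cast<size_t>(row) * row_width + col] =
                    in[((n * C + c) * H + y) * W + x];
              }  // out-of-range positions stay zero (padding)
            }
          }
        }
      }
    }
  }
  return out;
}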