From 0289a0091f094c75190698df7e450d8e1a70bbaa Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 11 Sep 2017 22:15:29 -0700 Subject: [PATCH] follow comments to cleanup code --- paddle/operators/reshape_op.cc | 35 ++++++++++++++++++++++++++-------- paddle/operators/reshape_op.h | 9 ++++----- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 37cbecbf2..da29c8915 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -27,21 +27,26 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto *in = ctx.Input("X"); + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); auto shape = ctx.Attr>("shape"); - int64_t capacity = -1; + PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); - if (capacity < 0) { - capacity = dim; - } else { - capacity *= dim; - } } + // capacity check + int64_t capacity = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto *in = ctx.Input("X"); int64_t in_size = framework::product(in->dims()); PADDLE_ENFORCE_EQ(capacity, in_size, "The size of Input(X) mismatches with Attr(shape)."); - ctx.Output("Out")->Resize(in->dims()); + // resize output + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); + ctx.Output("Out")->Resize(out_dims); } }; @@ -56,6 +61,17 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(Reshape operator Reshape Input(X) into the shape specified by Attr(shape). + +An example: +Given a 2-D tensor X with 2 rows and 2 columns + + [[1, 2], [3, 4]] + +with target shape = [1, 4], the reshape operator will tansform +the tensor X into a 1-D tensor: + + [1, 2, 3, 4] + )DOC"); } }; @@ -70,6 +86,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); auto dims = ctx.Input("X")->dims(); auto *d_in = ctx.Output(framework::GradVarName("X")); d_in->Resize(dims); diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 0e920329d..26708e72d 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -30,11 +30,10 @@ class ReshapeKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto shape = ctx.Attr>("shape"); - std::vector tmp; - for (auto dim : shape) { - tmp.push_back(dim); - } - auto out_dims = framework::make_ddim(tmp); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); out->CopyFrom(*in, ctx.GetPlace()); out->Resize(out_dims); } -- GitLab