diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 37cbecbf250afe910c25e2d630de1da606c14332..da29c8915010a0fd10631ca21c0fe5104a01ef7d 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -27,21 +27,26 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto *in = ctx.Input("X"); + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); auto shape = ctx.Attr>("shape"); - int64_t capacity = -1; + PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); - if (capacity < 0) { - capacity = dim; - } else { - capacity *= dim; - } } + // capacity check + int64_t capacity = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto *in = ctx.Input("X"); int64_t in_size = framework::product(in->dims()); PADDLE_ENFORCE_EQ(capacity, in_size, "The size of Input(X) mismatches with Attr(shape)."); - ctx.Output("Out")->Resize(in->dims()); + // resize output + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); + ctx.Output("Out")->Resize(out_dims); } }; @@ -56,6 +61,17 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC(Reshape operator Reshape Input(X) into the shape specified by Attr(shape). + +An example: +Given a 2-D tensor X with 2 rows and 2 columns + + [[1, 2], [3, 4]] + +with target shape = [1, 4], the reshape operator will tansform +the tensor X into a 1-D tensor: + + [1, 2, 3, 4] + )DOC"); } }; @@ -70,6 +86,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); auto dims = ctx.Input("X")->dims(); auto *d_in = ctx.Output(framework::GradVarName("X")); d_in->Resize(dims); diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 0e920329d99f7643dc2f0aed697581b0b84c17e8..26708e72dc8f80d2cff1c1ee5e8763b959320205 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -30,11 +30,10 @@ class ReshapeKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto shape = ctx.Attr>("shape"); - std::vector tmp; - for (auto dim : shape) { - tmp.push_back(dim); - } - auto out_dims = framework::make_ddim(tmp); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto out_dims = framework::make_ddim(shape_int64); out->CopyFrom(*in, ctx.GetPlace()); out->Resize(out_dims); }