diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index ed153e7722f8b82df6374c08e0a7580386621599..c817b356939a716e243ea6314af235e523ea44b3 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -17,6 +17,82 @@ limitations under the License. */ namespace paddle { namespace operators { +class ReshapeOp : public framework::OperatorWithKernel { + public: + ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReshapeOp should not be null."); + + const std::vector &shape = ctx->Attrs().Get>("shape"); + PADDLE_ENFORCE(!shape.empty(), + "The shape information must be set by Attr(shape)."); + + std::vector output_shape; + auto x_dims = ctx->GetInputDim("X"); + auto out_dims = ValidateShape(shape, x_dims); + ctx->SetOutputDim("Out", out_dims); + // NOTE: Reshape op cannot reshape an input sequence batch into an + // output sequence batch that has a different number of time steps. Here + // output always shares the LoD information with input. But if + // Attr(shape) contains 0 or -1, the actual output shape can only be + // determined during runtime. The check for wheather it is a valid + // output sequence batch is performed in runtime. + ctx->ShareLoD("X", /*->*/ "Out"); + } + + private: + framework::DDim ValidateShape(const std::vector shape, + const framework::DDim &in_dims) const { + const int64_t in_size = framework::product(in_dims); + // only one dimension canbe set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PADDLE_ENFORCE( + unk_dim_idx == -1, + "Only one input dimension of Attr(shape) can be unknown."); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PADDLE_ENFORCE( + static_cast(i) < in_dims.size(), + "The index of dimension to copy from input shape must be less " + "than the size of input shape."); + } else { + PADDLE_ENFORCE( + shape[i] > 0, + "Each input dimension of Attr(shape) must not be negtive except " + "one unknown dimension."); + } + + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = + (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } + + if (unk_dim_idx != -1) { + output_shape[unk_dim_idx] = -in_size / capacity; + PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, + "Invalid shape is given."); + } else { + PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); + } + return framework::make_ddim(output_shape); + } +}; + class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { public: ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index db632577d74834796d9d314425eed2c9638b7404..59adb5e87c1f729a15c1a97d4d2d756fc22bf00e 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -20,81 +20,6 @@ limitations under the License. */ namespace paddle { namespace operators { -class ReshapeOp : public framework::OperatorWithKernel { - public: - ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ReshapeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ReshapeOp should not be null."); - - const std::vector &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE(!shape.empty(), - "The shape information must be set by Attr(shape)."); - - std::vector output_shape; - auto x_dims = ctx->GetInputDim("X"); - auto out_dims = ValidateShape(shape, x_dims); - ctx->SetOutputDim("Out", out_dims); - // NOTE: Reshape op cannot reshape an input sequence batch into an - // output sequence batch that has a different number of time steps. Here - // output always shares the LoD information with input. But if - // Attr(shape) contains 0 or -1, the actual output shape can only be - // determined during runtime. The check for wheather it is a valid - // output sequence batch is performed in runtime. - ctx->ShareLoD("X", /*->*/ "Out"); - } - - static framework::DDim ValidateShape(const std::vector shape, - const framework::DDim &in_dims) { - const int64_t in_size = framework::product(in_dims); - // only one dimension canbe set to -1, whose size will be automatically - // infered. - const int64_t unk_dim_val = -1; - const int64_t copy_dim_val = 0; - - std::vector output_shape(shape.size(), 0); - int64_t capacity = 1; - int unk_dim_idx = -1; - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == unk_dim_val) { - PADDLE_ENFORCE( - unk_dim_idx == -1, - "Only one input dimension of Attr(shape) can be unknown."); - unk_dim_idx = i; - } else if (shape[i] == copy_dim_val) { - PADDLE_ENFORCE( - static_cast(i) < in_dims.size(), - "The index of dimension to copy from input shape must be less " - "than the size of input shape."); - } else { - PADDLE_ENFORCE( - shape[i] > 0, - "Each input dimension of Attr(shape) must not be negtive except " - "one unknown dimension."); - } - - capacity *= (shape[i] ? shape[i] : in_dims[i]); - output_shape[i] = - (shape[i] ? static_cast(shape[i]) : in_dims[i]); - } - - if (unk_dim_idx != -1) { - output_shape[unk_dim_idx] = -in_size / capacity; - PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, - "Invalid shape is given."); - } else { - PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); - } - return framework::make_ddim(output_shape); - } -}; - template class ReshapeKernel : public framework::OpKernel { public: @@ -102,8 +27,7 @@ class ReshapeKernel : public framework::OpKernel { auto *out = ctx.Output("Out"); auto *in = ctx.Input("X"); - auto out_dims = ReshapeOp::ValidateShape( - ctx.Attr>("shape"), in->dims()); + auto out_dims = out->dims(); if (!in->lod().empty()) { PADDLE_ENFORCE_EQ(