diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index ddb598f575f6737f7c7d4336eeee866b12c12fb1..0e74a23523b7d5182fabff88d08f6cc3f56a1783 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -476,6 +476,21 @@ class Reshape2Op : public ReshapeOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : ReshapeOp(type, inputs, outputs, attrs) {} + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true, + platform::errors::InvalidArgument( + "Output(XShape) of ReshapeOp should not be null.")); + const auto &x_dims = ctx->GetInputDim("X"); + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims)); + ctx->ShareLoD("X", /*->*/ "XShape"); + + ReshapeOp::InferShape(ctx); + } }; class Reshape2OpMaker : public ReshapeOpMaker { @@ -636,13 +651,10 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel); -DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor, - PT_INFER_META(phi::ReshapeWithXShapeInferMeta)); - REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, ops::Reshape2GradMaker, ops::Reshape2GradMaker, - ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer); + ops::ReshapeOpInplaceInferer); REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, ops::Reshape2DoubleGradMaker, ops::Reshape2DoubleGradMaker,