未验证 提交 b33a3c23 编写于 作者: Y YuanRisheng 提交者: GitHub

revert reshape op infershape (#39946)

上级 9a7b9eda
...@@ -476,6 +476,21 @@ class Reshape2Op : public ReshapeOp { ...@@ -476,6 +476,21 @@ class Reshape2Op : public ReshapeOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {} : ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
platform::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
ReshapeOp::InferShape(ctx);
}
}; };
class Reshape2OpMaker : public ReshapeOpMaker { class Reshape2OpMaker : public ReshapeOpMaker {
...@@ -636,13 +651,10 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ...@@ -636,13 +651,10 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
PT_INFER_META(phi::ReshapeWithXShapeInferMeta));
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>, ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>, ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer); ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>, ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>, ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册