提交 0b392187 编写于 作者: D Double_V 提交者: lvmengsi

memory optimizer for reshape op,test=develop (#20569)

上级 01209b51
...@@ -423,7 +423,6 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker { ...@@ -423,7 +423,6 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDesc(); auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad"); grad_op->SetType("reshape2_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("XShape", Output("XShape")); grad_op->SetInput("XShape", Output("XShape"));
grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
...@@ -441,13 +440,10 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -441,13 +440,10 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpDescMaker {
auto *grad_op = new framework::OpDesc(); auto *grad_op = new framework::OpDesc();
grad_op->SetType("reshape2_grad_grad"); grad_op->SetType("reshape2_grad_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("ShapeTensor", Input("ShapeTensor")); grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
grad_op->SetInput("DOut", Input(framework::GradVarName("Out"))); grad_op->SetInput("DOut", Input(framework::GradVarName("Out")));
grad_op->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); grad_op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
auto ddx = OutputGrad(framework::GradVarName("X"));
grad_op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); grad_op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(Attrs()); grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op); return std::unique_ptr<framework::OpDesc>(grad_op);
...@@ -501,7 +497,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { ...@@ -501,7 +497,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true, PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
"Input(X@GRAD_GRAD) shouldn't be null."); "Input(X@GRAD_GRAD) shouldn't be null.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册