提交 00548a16 编写于 作者: W wanghaoshuang

Remove intermediate output's gradient from inputs of grad_op.

上级 387e10c6
...@@ -194,13 +194,45 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -194,13 +194,45 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
} }
}; };
class GRUUnitGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("gru_unit_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("HiddenPrev", Input("HiddenPrev"));
op->SetInput("Weight", Input("Weight"));
op->SetInput("Bias", Input("Bias"));
op->SetInput("Hidden", Output("Hidden"));
op->SetInput("Gate", Output("Gate"));
op->SetInput("ResetHiddenPrev", Output("ResetHiddenPrev"));
op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("HiddenPrev"),
InputGrad("HiddenPrev"));
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>) ops::GRUUnitGradOpMaker);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp) REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>, gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>); ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册