提交 9aecb4bd 编写于 作者: J jingqinghe

renew reduce sum double grad test=develop

上级 dc4fc5cf
......@@ -51,6 +51,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
op->SetType("reduce_sum");
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public:
......@@ -63,50 +77,6 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
}
};
class ReduceSumDoubleGradDescMaker : public framework::GradOpDescMakerBase {
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
std::vector<std::unique_ptr<framework::OpDesc>> ops;
auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx
auto out_grads = InputGrad(framework::GradVarName("Out"));
if (!out_grads.empty()) {
auto* out_grad_op = new framework::OpDesc();
out_grad_op->SetType("reduce_sum");
out_grad_op->SetInput("X", x_gg);
out_grad_op->SetAttrMap(Attrs());
out_grad_op->SetOutput("Out", out_grads);
ops.emplace_back(out_grad_op);
}
return ops;
}
};
class ReduceSumDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
public:
using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;
std::shared_ptr<imperative::GradOpNode> operator()() const override {
auto out_grads = InputGrad(framework::GradVarName("Out"));
if (!out_grads.empty()) {
auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx
auto node = this->NewGradNode();
{
imperative::TracedGradOp op(node);
op.SetType("reduce_sum");
op.SetInput("X", x_gg);
op.SetAttrMap(Attrs());
op.SetOutput("Out", out_grads);
}
return node;
} else {
return nullptr;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -121,8 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumDoubleGradDescMaker,
ops::ReduceSumDoubleGradOpBaseMaker,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册