提交 de5dea5f 编写于 作者: J jingqinghe

Support double grad for reduce_sum. test=develop

上级 6abd05f2
...@@ -63,6 +63,50 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { ...@@ -63,6 +63,50 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
} }
}; };
// OpDesc-based grad maker registered on reduce_sum_grad (the "double grad").
//
// When a gradient is requested for the Out@GRAD input, it emits one
// "reduce_sum" op that reduces the gradient of the X@GRAD output (ddx) into
// that requested gradient, reusing this op's attributes unchanged.
class ReduceSumDoubleGradDescMaker : public framework::GradOpDescMakerBase {
 public:
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

  // Returns the generated grad ops; the vector is empty when InputGrad for
  // Out@GRAD yields nothing.
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
    std::vector<std::unique_ptr<framework::OpDesc>> ops;
    auto x_gg = OutputGrad(framework::GradVarName("X"));  // input ddx
    auto out_grads = InputGrad(framework::GradVarName("Out"));
    if (!out_grads.empty()) {
      // Give the vector ownership as soon as the OpDesc is allocated: the
      // unique_ptr temporary (and afterwards `ops` itself) releases it if the
      // vector growth or any of the setters below throws, so nothing leaks.
      ops.emplace_back(
          std::unique_ptr<framework::OpDesc>(new framework::OpDesc()));
      framework::OpDesc* out_grad_op = ops.back().get();
      out_grad_op->SetType("reduce_sum");
      out_grad_op->SetInput("X", x_gg);
      out_grad_op->SetAttrMap(Attrs());
      out_grad_op->SetOutput("Out", out_grads);
    }
    return ops;
  }
};
// Imperative-mode grad maker registered on reduce_sum_grad (the "double
// grad"). Mirrors the OpDesc-based maker: it traces a single "reduce_sum" op
// from the gradient of the X@GRAD output (ddx) to the gradient of the
// Out@GRAD input, with this op's attributes copied over.
class ReduceSumDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
 public:
  using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;

  // Returns the new grad node, or nullptr when InputGrad for Out@GRAD yields
  // nothing.
  std::shared_ptr<imperative::GradOpNode> operator()() const override {
    auto grad_outputs = InputGrad(framework::GradVarName("Out"));
    if (grad_outputs.empty()) {
      return nullptr;
    }

    auto ddx = OutputGrad(framework::GradVarName("X"));  // input ddx
    auto grad_node = this->NewGradNode();
    {
      // Scoped on purpose: the TracedGradOp's lifetime ends before the node
      // is handed back to the caller.
      imperative::TracedGradOp traced_op(grad_node);
      traced_op.SetType("reduce_sum");
      traced_op.SetInput("X", ddx);
      traced_op.SetAttrMap(Attrs());
      traced_op.SetOutput("Out", grad_outputs);
    }
    return grad_node;
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -77,6 +121,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ...@@ -77,6 +121,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>, ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>); ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
// Registers reduce_sum_grad together with both double-grad makers (the
// OpDesc-based one and the imperative one) and its no-need-buffer inferer.
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                  ops::ReduceSumDoubleGradDescMaker,
                  ops::ReduceSumDoubleGradOpBaseMaker,
                  ops::ReduceSumGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -101,6 +101,29 @@ class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase): ...@@ -101,6 +101,29 @@ class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestReduceSumWithDimDoubleGradCheck(unittest.TestCase):
    """Double-gradient check for ``layers.reduce_sum`` with ``dim=0``."""

    @prog_scope()
    def func(self, place):
        # Build y = reduce_sum(x, dim=0) on a 7x11 float64 input and run the
        # double-grad checker on the given place.
        input_shape = [7, 11]
        float_type = np.float64

        x = layers.data('x', input_shape, False, float_type)
        x.persistable = True
        y = layers.reduce_sum(x, dim=0)

        # Random initial value for x, drawn uniformly from [-1, 1).
        x_init = np.random.uniform(-1, 1, input_shape).astype(float_type)
        gradient_checker.double_grad_check(
            [x], y, x_init=x_init, place=place, eps=0.05)

    def test_grad(self):
        # Always check on CPU; also check on CUDAPlace(0) when the build
        # reports CUDA support.
        devices = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            devices.append(fluid.CUDAPlace(0))
        for device in devices:
            self.func(device)
class TestMulDoubleGradCheck(unittest.TestCase): class TestMulDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册