未验证 提交 ba891437 编写于 作者: Q Qinghe JING 提交者: GitHub

Add double grad in reduce sum release 1.8 (#27164)

* add double grad to reduce sum
上级 448d5544
......@@ -52,6 +52,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInference, "X");
template <typename T>
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
op->SetType("reduce_sum");
}
};
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
public:
void operator()(paddle::framework::InferVarTypeContext* ctx) const override {
......@@ -77,6 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
......
......@@ -101,6 +101,29 @@ class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestReduceSumWithDimDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [7, 11]
eps = 0.05
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.reduce_sum(x, dim=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMulDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册