未验证 提交 d422a1ed 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Handled special sum_grad_op code gen in Eager Dygraph (#38573)

* Handled special sum_grad_op code gen in Eager Dygraph

* Fixed merge issues
上级 89c0877e
...@@ -321,7 +321,7 @@ class TestImperative(unittest.TestCase): ...@@ -321,7 +321,7 @@ class TestImperative(unittest.TestCase):
with paddle.set_grad_enabled(True): with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled()) self.assertTrue(paddle.is_grad_enabled())
def test_sum_op(self): def func_sum_op(self):
x = np.ones([2, 2], np.float32) x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
inputs = [] inputs = []
...@@ -338,7 +338,7 @@ class TestImperative(unittest.TestCase): ...@@ -338,7 +338,7 @@ class TestImperative(unittest.TestCase):
tmp = paddle.to_tensor(x) tmp = paddle.to_tensor(x)
tmp.stop_gradient = False tmp.stop_gradient = False
inputs2.append(tmp) inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2) ret2 = paddle.add_n(inputs2)
loss2 = fluid.layers.reduce_sum(ret2) loss2 = fluid.layers.reduce_sum(ret2)
fluid.set_flags({'FLAGS_sort_sum_gradient': True}) fluid.set_flags({'FLAGS_sort_sum_gradient': True})
loss2.backward() loss2.backward()
...@@ -349,6 +349,11 @@ class TestImperative(unittest.TestCase): ...@@ -349,6 +349,11 @@ class TestImperative(unittest.TestCase):
a = inputs2[0].gradient() a = inputs2[0].gradient()
self.assertTrue(np.allclose(inputs2[0].gradient(), x)) self.assertTrue(np.allclose(inputs2[0].gradient(), x))
def test_sum_op(self):
with _test_eager_guard():
self.func_sum_op()
self.func_sum_op()
def func_empty_var(self): def func_empty_var(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
cur_program = fluid.Program() cur_program = fluid.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册