未验证 提交 476470da 编写于 作者: C Charles-hit 提交者: GitHub

add unit test for sum higher level op (#45961)

上级 ff1da188
...@@ -663,6 +663,78 @@ class TestAddNTripleGradCheck(unittest.TestCase): ...@@ -663,6 +663,78 @@ class TestAddNTripleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestSumDoubleGradCheck(unittest.TestCase):
def sum_wrapper(self, x):
return paddle.sum(x[0], axis=1, keepdim=True)
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
eps = 0.005
dtype = np.float32
data = layers.data('data', [2, 4], False, dtype)
data.persistable = True
out = paddle.sum(data, axis=1, keepdim=True)
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.double_grad_check([data],
out,
x_init=[data_arr],
place=place,
eps=eps)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.double_grad_check_for_dygraph(self.sum_wrapper, [data],
out,
x_init=[data_arr],
place=place)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestSumTripleGradCheck(unittest.TestCase):
def sum_wrapper(self, x):
return paddle.sum(x[0], axis=1, keepdim=True)
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
eps = 0.005
dtype = np.float32
data = layers.data('data', [2, 4], False, dtype)
data.persistable = True
out = paddle.sum(data, axis=1, keepdim=True)
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.triple_grad_check([data],
out,
x_init=[data_arr],
place=place,
eps=eps)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(self.sum_wrapper, [data],
out,
x_init=[data_arr],
place=place)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__": if __name__ == "__main__":
enable_static() enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册