From 9adf4738e98b38a7b5ae4895bb34f279f4282f0b Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 27 Sep 2020 14:43:37 +0000 Subject: [PATCH] fix failed unittests --- .../unittests/test_imperative_parallel_coalesce_split.py | 7 ++++--- python/paddle/optimizer/adam.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py b/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py index e5c32d00038..480df7482e3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_parallel_coalesce_split.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.dygraph.parallel import DataParallel from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dygraph.parallel import _coalesce_tensors, _split_tensors, _reshape_inplace class MyLayer(fluid.Layer): @@ -57,8 +58,8 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): orig_var_shapes.append(var.shape) # execute interface - coalesced_vars = test_layer._coalesce_tensors(var_groups) - test_layer._split_tensors(coalesced_vars) + coalesced_vars = _coalesce_tensors(var_groups) + _split_tensors(coalesced_vars) # compare for orig_var_shape, var in zip(orig_var_shapes, vars): @@ -74,7 +75,7 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): new_shape = [5, 10] x_data = np.random.random(ori_shape).astype("float32") x = to_variable(x_data) - test_layer._reshape_inplace(x, new_shape) + _reshape_inplace(x, new_shape) self.assertEqual(x.shape, new_shape) diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 568f0b9d8f1..9cbb45ce60d 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -17,6 +17,7 @@ from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable +import paddle from paddle.fluid.dygraph.parallel import apply_collective_grads __all__ = ["Adam"] -- GitLab