提交 9adf4738 编写于 作者: C Chen Weihang

fix failed unittests

上级 5903151f
...@@ -22,6 +22,7 @@ import paddle.fluid as fluid ...@@ -22,6 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.dygraph.parallel import DataParallel from paddle.fluid.dygraph.parallel import DataParallel
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import _coalesce_tensors, _split_tensors, _reshape_inplace
class MyLayer(fluid.Layer): class MyLayer(fluid.Layer):
...@@ -57,8 +58,8 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): ...@@ -57,8 +58,8 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase):
orig_var_shapes.append(var.shape) orig_var_shapes.append(var.shape)
# execute interface # execute interface
coalesced_vars = test_layer._coalesce_tensors(var_groups) coalesced_vars = _coalesce_tensors(var_groups)
test_layer._split_tensors(coalesced_vars) _split_tensors(coalesced_vars)
# compare # compare
for orig_var_shape, var in zip(orig_var_shapes, vars): for orig_var_shape, var in zip(orig_var_shapes, vars):
...@@ -74,7 +75,7 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase): ...@@ -74,7 +75,7 @@ class TestImperativeParallelCoalesceSplit(unittest.TestCase):
new_shape = [5, 10] new_shape = [5, 10]
x_data = np.random.random(ori_shape).astype("float32") x_data = np.random.random(ori_shape).astype("float32")
x = to_variable(x_data) x = to_variable(x_data)
test_layer._reshape_inplace(x, new_shape) _reshape_inplace(x, new_shape)
self.assertEqual(x.shape, new_shape) self.assertEqual(x.shape, new_shape)
......
...@@ -17,6 +17,7 @@ from ..fluid import core ...@@ -17,6 +17,7 @@ from ..fluid import core
from ..fluid import framework from ..fluid import framework
from ..fluid.framework import Variable from ..fluid.framework import Variable
import paddle
from paddle.fluid.dygraph.parallel import apply_collective_grads from paddle.fluid.dygraph.parallel import apply_collective_grads
__all__ = ["Adam"] __all__ = ["Adam"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册