未验证 提交 65c06141 编写于 作者: L Leo Chen 提交者: GitHub

disable_fuse_all_reduce (#27746)

* disable_fuse_all_reduce

* fix test

* fix ut
上级 f6ad2375
......@@ -34,6 +34,7 @@ class InplaceTestBase(unittest.TestCase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = False
self.fuse_all_reduce_ops = False
def setUp(self):
paddle.enable_static()
......@@ -93,6 +94,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = self.fuse_all_reduce_ops
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
......@@ -146,6 +148,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = self.fuse_all_reduce_ops
compiled_program = fluid.CompiledProgram(
prog).with_data_parallel(
loss_name=loss.name,
......@@ -175,6 +178,7 @@ class CUDAInplaceTest(InplaceTestBase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = False
self.fuse_all_reduce_ops = False
def test_multi_card_fetch_var(self):
self.check_multi_card_fetch_var()
......@@ -187,6 +191,7 @@ class CPUInplaceTest(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
self.fuse_all_optimizer_ops = False
self.fuse_all_reduce_ops = False
def test_multi_card_fetch_var(self):
self.check_multi_card_fetch_var()
......
......@@ -20,6 +20,7 @@ class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = True
self.fuse_all_reduce_ops = False
def test_multi_card_fetch_var(self):
self.check_multi_card_fetch_var()
......@@ -32,6 +33,7 @@ class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
self.fuse_all_optimizer_ops = True
self.fuse_all_reduce_ops = False
def test_multi_card_fetch_var(self):
self.check_multi_card_fetch_var()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册