提交 4f304eaa 编写于 作者: Y Yancey1989

fix unittest test=develop

上级 c722b1dc
...@@ -197,9 +197,17 @@ ParallelExecutor::ParallelExecutor( ...@@ -197,9 +197,17 @@ ParallelExecutor::ParallelExecutor(
PADDLE_ENFORCE(places.size() > 1, PADDLE_ENFORCE(places.size() > 1,
"If you set build_strategy.reduce with 'Reduce'," "If you set build_strategy.reduce with 'Reduce',"
"the number of places must be greater than 1."); "the number of places must be greater than 1.");
PADDLE_ENFORCE(exec_strategy.type_ != ExecutionStrategy::kParallelGraph, }
"You should set build_strategy.reduce with 'AllReduce' for "
"the ParallelGraph executor type"); if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
PADDLE_ENFORCE(
member_->use_all_reduce_,
"build_strategy.reduce should be `AllReduce` if you want to use"
"ParallelGraph executor.");
PADDLE_ENFORCE(
member_->use_cuda_,
"execution_strategy.use_cuda should be True if you want to use"
"ParallelGraph executor.");
} }
// Step 1. Bcast the params to devs. // Step 1. Bcast the params to devs.
......
...@@ -166,6 +166,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -166,6 +166,8 @@ class TestMNIST(TestParallelExecutorBase):
def check_batchnorm_fc_convergence(self, use_cuda, exec_type): def check_batchnorm_fc_convergence(self, use_cuda, exec_type):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
if not use_cuda and exec_type == ExecutorType.ParallelGraph:
return
img, label = self._init_data() img, label = self._init_data()
......
...@@ -173,10 +173,6 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -173,10 +173,6 @@ class TestTransformer(TestParallelExecutorBase):
def test_main(self): def test_main(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True) self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(
transformer,
use_cuda=True,
exec_type=ExecutorType.ParallelGraph)
self.check_network_convergence( self.check_network_convergence(
transformer, use_cuda=True, enable_sequential_execution=True) transformer, use_cuda=True, enable_sequential_execution=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5) self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册