Commit 4f304eaa authored by Yancey1989

fix unittest test=develop

Parent c722b1dc
@@ -197,9 +197,17 @@ ParallelExecutor::ParallelExecutor(
PADDLE_ENFORCE(places.size() > 1,
"If you set build_strategy.reduce with 'Reduce',"
"the number of places must be greater than 1.");
PADDLE_ENFORCE(exec_strategy.type_ != ExecutionStrategy::kParallelGraph,
"You should set build_strategy.reduce with 'AllReduce' for "
"the ParallelGraph executor type");
}
if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
PADDLE_ENFORCE(
member_->use_all_reduce_,
"build_strategy.reduce should be `AllReduce` if you want to use"
"ParallelGraph executor.");
PADDLE_ENFORCE(
member_->use_cuda_,
"execution_strategy.use_cuda should be True if you want to use"
"ParallelGraph executor.");
}
// Step 1. Bcast the params to devs.
......
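The new checks above constrain which strategy combinations ParallelExecutor accepts: selecting the ParallelGraph executor type now requires build_strategy.reduce to be AllReduce and use_cuda to be True. Below is a minimal sketch, not part of this commit, of a configuration that satisfies both checks; it assumes the paddle.fluid Python API of this release (BuildStrategy, ExecutionStrategy, ParallelExecutor), and the Python attribute that selects the ParallelGraph executor is not shown in this diff, so it is deliberately left out.

```python
import paddle.fluid as fluid

# Minimal sketch (not part of this commit): build a tiny network and pick a
# strategy combination that satisfies the new checks when the ParallelGraph
# executor type is requested.
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=img, size=10))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())

build_strategy = fluid.BuildStrategy()
# ParallelGraph requires AllReduce; ReduceStrategy.Reduce would trip the
# first enforce above.
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

exec_strategy = fluid.ExecutionStrategy()
# The Python attribute that maps to ExecutionStrategy::kParallelGraph is not
# shown in this diff, so it is omitted here; the tests select it through
# their own ExecutorType enum.

pe = fluid.ParallelExecutor(
    use_cuda=True,  # use_cuda_ must also be True for ParallelGraph
    loss_name=loss.name,
    build_strategy=build_strategy,
    exec_strategy=exec_strategy)
```

The unit-test changes below follow the same rule: the MNIST test skips the CPU + ParallelGraph combination, and the CUDA-only ParallelGraph transformer convergence check is dropped.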
@@ -166,6 +166,8 @@ class TestMNIST(TestParallelExecutorBase):
def check_batchnorm_fc_convergence(self, use_cuda, exec_type):
if use_cuda and not core.is_compiled_with_cuda():
return
if not use_cuda and exec_type == ExecutorType.ParallelGraph:
return
img, label = self._init_data()
......
@@ -173,10 +173,6 @@ class TestTransformer(TestParallelExecutorBase):
def test_main(self):
if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(
transformer,
use_cuda=True,
exec_type=ExecutorType.ParallelGraph)
self.check_network_convergence(
transformer, use_cuda=True, enable_sequential_execution=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5)
......