提交 a57e8a43 编写于 作者: C chengduoZH

add cpu test

上级 1e731f59
...@@ -67,7 +67,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -67,7 +67,7 @@ void AllReduceOpHandle::RunImpl() {
if (platform::is_gpu_place(lod_tensors[0]->place())) { if (platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(nccl_ctxs_); PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
......
...@@ -119,11 +119,10 @@ class ParallelExecutor(object): ...@@ -119,11 +119,10 @@ class ParallelExecutor(object):
if use_cuda: if use_cuda:
# Experiments on se-resnext shows that too many threads hurt # Experiments on se-resnext shows that too many threads hurt
# performance. Worth tunning for other models in the future. # performance. Worth tunning for other models in the future.
exec_strategy.num_threads = len(self._places) * 2 exec_strategy.num_threads = len(self._places) * 4
else: else:
cpu_num = int( # Currently num_threads must be 1.
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) exec_strategy.num_threads = 1
exec_strategy.num_threads = min(len(self._places) * 2, cpu_num)
if build_strategy is None: if build_strategy is None:
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import multiprocessing
import os
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import time import time
...@@ -73,7 +75,9 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -73,7 +75,9 @@ class TestParallelExecutorBase(unittest.TestCase):
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time() begin = time.time()
first_loss, = run_executor( first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name]) exe=exe, feed=feed_dict, fetch_list=[loss.name])
......
...@@ -104,8 +104,9 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -104,8 +104,9 @@ class TestMNIST(TestParallelExecutorBase):
def check_simple_fc_convergence(self, def check_simple_fc_convergence(self,
balance_parameter_opt_between_cards, balance_parameter_opt_between_cards,
use_cuda=True): use_cuda=True):
self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence(simple_fc_net, allow_op_delay=True) self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
...@@ -142,6 +143,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -142,6 +143,7 @@ class TestMNIST(TestParallelExecutorBase):
seed=1000, seed=1000,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda,
use_parallel_executor=True, use_parallel_executor=True,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
...@@ -161,7 +163,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -161,7 +163,7 @@ class TestMNIST(TestParallelExecutorBase):
def check_batchnorm_fc_convergence( def check_batchnorm_fc_convergence(
self, balance_parameter_opt_between_cards, use_cuda): self, balance_parameter_opt_between_cards, use_cuda):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
......
...@@ -133,27 +133,28 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False): ...@@ -133,27 +133,28 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
def check_resnet_convergence(self, def check_resnet_convergence(self,
balance_parameter_opt_between_cards, balance_parameter_opt_between_cards,
use_cuda=True): use_cuda=True,
iter=20):
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
functools.partial( functools.partial(
SE_ResNeXt50Small, batch_size=batch_size), SE_ResNeXt50Small, batch_size=batch_size),
iter=20, iter=iter,
batch_size=batch_size, batch_size=batch_size,
use_cuda=use_cuda, use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
def test_resnet(self): def test_resnet(self):
# os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
self.check_resnet_convergence(False, use_cuda=True) self.check_resnet_convergence(False, use_cuda=True)
# self.check_resnet_convergence(False,use_cuda=False) self.check_resnet_convergence(False, use_cuda=False, iter=5)
def test_resnet_with_new_strategy(self): def test_resnet_with_new_strategy(self):
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
self.check_resnet_convergence(True, use_cuda=True) self.check_resnet_convergence(True, use_cuda=True)
self.check_resnet_convergence(True, use_cuda=False) self.check_resnet_convergence(True, use_cuda=False, iter=5)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册