Commit 106e2852, authored by Yancey1989

add unittest for parallelgraph mode test=develop

Parent: 5cc83f79
@@ -300,7 +300,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
-  // int num_trainers = Get<int>(kNumTrainers);
+  int num_trainers = Get<int>(kNumTrainers);
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
@@ -387,7 +387,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
-      // if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
-      if (!is_forwarding && nccl_ctxs_->contexts_.size() > 1) {
+      // Insert synchronous ops during backpropagation, and only when the
+      // graph spans multiple places, trainers, or NCCL contexts.
+      if (!is_forwarding &&
+          (places_.size() > 1 || num_trainers > 1 ||
+           (nccl_ctxs_ && nccl_ctxs_->contexts_.size() > 1))) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
         if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
......
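For readers less familiar with the C++ graph builder, the new guard above can be restated in plain Python. The helper below is purely illustrative (the name needs_sync_op is invented here); it only mirrors when a synchronous all-reduce/broadcast op is now inserted for a freshly generated gradient:

    def needs_sync_op(is_forwarding, num_places, num_trainers, num_nccl_contexts):
        # Mirrors the new condition: collective ops are inserted only on the
        # backward pass, and only when the graph spans multiple places, multiple
        # trainers, or more than one NCCL context (0 stands for "no nccl_ctxs_").
        return (not is_forwarding) and (
            num_places > 1 or num_trainers > 1 or num_nccl_contexts > 1)

    # Single-trainer, multi-GPU local training still triggers the insertion:
    assert needs_sync_op(False, num_places=4, num_trainers=1, num_nccl_contexts=4)
    # A single place, a single trainer, and no NCCL context does not:
    assert not needs_sync_op(False, num_places=1, num_trainers=1, num_nccl_contexts=0)
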
@@ -49,18 +49,18 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
   for (size_t i = 0; i < places_.size(); ++i) {
     auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
-      return executors_[i]->Run(fetch_tensors);
+      try {
+        return executors_[i]->Run(fetch_tensors);
+      } catch (...) {
+        exception_holder_.Catch(std::current_exception());
+      }
+      return FeedFetchList();
     };
     if (pool_) {
       run_futures.emplace_back(pool_->enqueue(std::move(call)));
     } else {
-      try {
-        fetch_datas.emplace_back(std::move(call()));
-      } catch (...) {
-        exception_holder_.Catch(std::current_exception());
-        break;
-      }
+      call();
     }
   }
@@ -69,11 +69,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
       if (exception_holder_.IsCaught()) {
         f.wait();
       } else {
-        try {
-          fetch_datas.emplace_back(std::move(f.get()));
-        } catch (...) {
-          exception_holder_.Catch(std::current_exception());
-        }
+        fetch_datas.emplace_back(std::move(f.get()));
       }
     }
   }
......
@@ -87,7 +87,7 @@ ParallelExecutor::ParallelExecutor(
                    "the number of places must be greater than 1.");
     PADDLE_ENFORCE(exec_strategy.type_ != ExecutionStrategy::kParallelGraph,
                    "You should set build_strategy.reduce with 'AllReduce' for "
-                   "ParallelGraph executor type");
+                   "the ParallelGraph executor type");
   }
   // Step 1. Bcast the params to devs.
......
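The PADDLE_ENFORCE above means a caller has to pair the ParallelGraph executor type with the AllReduce reduce strategy. A minimal configuration sketch, assuming the fluid Python API exercised elsewhere in this commit (the ParallelExecutor call is left commented because it needs a built program and a loss variable):

    import paddle.fluid as fluid

    exec_strategy = fluid.ExecutionStrategy()
    # Same enum the test base below aliases as ExecutorType.
    exec_strategy.executor_type = fluid.ExecutionStrategy().ExecutorType.ParallelGraph

    build_strategy = fluid.BuildStrategy()
    # ParallelGraph is only accepted with AllReduce; ReduceStrategy.Reduce would
    # trip the PADDLE_ENFORCE shown above.
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

    # exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name,
    #                              exec_strategy=exec_strategy,
    #                              build_strategy=build_strategy)
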
@@ -48,7 +48,7 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
 class CTRReader : public framework::FileReader {
  public:
   explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
-                     int batch_size, int thread_num,
+                     int batch_size, size_t thread_num,
                      const std::vector<std::string>& slots,
                      const std::vector<std::string>& file_list)
       : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
......
@@ -26,23 +26,26 @@ import sys
 __all__ = ['TestParallelExecutorBase']

+ExecutorType = fluid.ExecutionStrategy().ExecutorType
+

 class TestParallelExecutorBase(unittest.TestCase):
-    def check_network_convergence(self,
-                                  method,
-                                  use_cuda=True,
-                                  memory_opt=True,
-                                  iter=50,
-                                  batch_size=None,
-                                  allow_op_delay=False,
-                                  feed_dict=None,
-                                  seed=None,
-                                  use_parallel_executor=True,
-                                  use_reduce=False,
-                                  fuse_elewise_add_act_ops=False,
-                                  optimizer=fluid.optimizer.Adam,
-                                  use_fast_executor=False,
-                                  enable_sequential_execution=False):
+    def check_network_convergence(
+            self,
+            method,
+            use_cuda=True,
+            memory_opt=True,
+            iter=50,
+            batch_size=None,
+            allow_op_delay=False,
+            feed_dict=None,
+            seed=None,
+            use_parallel_executor=True,
+            use_reduce=False,
+            fuse_elewise_add_act_ops=False,
+            optimizer=fluid.optimizer.Adam,
+            exec_type=fluid.ExecutionStrategy().ExecutorType.Default,
+            enable_sequential_execution=False):
         def run_executor(exe, feed, fetch_list, program=None):
             if isinstance(exe, fluid.ParallelExecutor):
                 res = exe.run(fetch_list=fetch_list, feed=feed)
@@ -58,68 +61,69 @@ class TestParallelExecutorBase(unittest.TestCase):
         startup = fluid.Program()
         startup.random_seed = 1  # Fix random seed
         main.random_seed = 1
-        with fluid.program_guard(main, startup):
-            if seed is not None:
-                startup.random_seed = seed
-                main.random_seed = seed
-
-            loss = method(use_feed=feed_dict is not None)
-            optimizer().minimize(loss)
-            if memory_opt:
-                fluid.memory_optimize(main)
-
-            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-            startup_exe = fluid.Executor(place)
-            startup_exe.run(startup)
-            exec_strategy = fluid.ExecutionStrategy()
-            exec_strategy.allow_op_delay = allow_op_delay
-            if use_fast_executor:
-                exec_strategy.use_experimental_executor = True
-
-            build_strategy = fluid.BuildStrategy()
-            build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
-                if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
-            build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
-            build_strategy.enable_sequential_execution = enable_sequential_execution
-            if use_cuda and core.is_compiled_with_cuda():
-                build_strategy.remove_unnecessary_lock = True
-
-            if use_parallel_executor:
-                exe = fluid.ParallelExecutor(
-                    use_cuda,
-                    loss_name=loss.name,
-                    exec_strategy=exec_strategy,
-                    build_strategy=build_strategy)
-            else:
-                exe = fluid.Executor(place=place)
-
-            if batch_size is not None:
-                batch_size *= fluid.core.get_cuda_device_count(
-                ) if use_cuda else int(
-                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            begin = time.time()
-            first_loss, = run_executor(
-                exe=exe, feed=feed_dict, fetch_list=[loss.name])
-
-            for i in range(iter):
-                run_executor(exe=exe, feed=feed_dict, fetch_list=[])
-
-            last_loss, = run_executor(
-                exe=exe, feed=feed_dict, fetch_list=[loss.name])
-            end = time.time()
-
-            if batch_size is not None:
-                print("%.4f Instance per second" % (
-                    (batch_size * iter + 2) / (end - begin)))
-
-            avg_last_loss_val = np.array(last_loss).mean()
-            avg_first_loss_val = np.array(first_loss).mean()
-            if math.isnan(float(avg_last_loss_val)) or math.isnan(
-                    float(avg_first_loss_val)):
-                sys.exit("got NaN loss, training failed.")
-
-            print(first_loss, last_loss)
-            # self.assertGreater(first_loss[0], last_loss[0])
-            return first_loss, last_loss
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(main, startup):
+                if seed is not None:
+                    startup.random_seed = seed
+                    main.random_seed = seed
+                loss = method(use_feed=feed_dict is not None)
+                optimizer().minimize(loss)
+
+                if memory_opt:
+                    fluid.memory_optimize(main)
+
+                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+                startup_exe = fluid.Executor(place)
+                startup_exe.run(startup)
+                exec_strategy = fluid.ExecutionStrategy()
+                exec_strategy.allow_op_delay = allow_op_delay
+                exec_strategy.executor_type = exec_type
+
+                build_strategy = fluid.BuildStrategy()
+                build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
+                    if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
+                build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+                build_strategy.enable_sequential_execution = enable_sequential_execution
+                if use_cuda and core.is_compiled_with_cuda():
+                    build_strategy.remove_unnecessary_lock = True
+
+                if use_parallel_executor:
+                    exe = fluid.ParallelExecutor(
+                        use_cuda,
+                        loss_name=loss.name,
+                        exec_strategy=exec_strategy,
+                        build_strategy=build_strategy)
+                else:
+                    exe = fluid.Executor(place=place)
+
+                if batch_size is not None:
+                    batch_size *= fluid.core.get_cuda_device_count(
+                    ) if use_cuda else int(
+                        os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+                begin = time.time()
+                first_loss, = run_executor(
+                    exe=exe, feed=feed_dict, fetch_list=[loss.name])
+
+                for i in range(iter):
+                    run_executor(exe=exe, feed=feed_dict, fetch_list=[])
+
+                last_loss, = run_executor(
+                    exe=exe, feed=feed_dict, fetch_list=[loss.name])
+                end = time.time()
+
+                if batch_size is not None:
+                    print("%.4f Instance per second" % (
+                        (batch_size * iter + 2) / (end - begin)))
+
+                avg_last_loss_val = np.array(last_loss).mean()
+                avg_first_loss_val = np.array(first_loss).mean()
+                if math.isnan(float(avg_last_loss_val)) or math.isnan(
+                        float(avg_first_loss_val)):
+                    sys.exit("got NaN loss, training failed.")
+
+                print(first_loss, last_loss)
+                # self.assertGreater(first_loss[0], last_loss[0])
+                return first_loss, last_loss
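A test built on this base class now selects the executor through the new exec_type argument instead of the removed use_fast_executor flag. A hedged, self-contained sketch (tiny_fc_net and TestTinyFC are made up here; the real tests below use simple_fc_net, transformer, and SE_ResNeXt50Small):

    import numpy as np
    import paddle.fluid as fluid
    from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType


    def tiny_fc_net(use_feed):
        # Hypothetical one-layer classifier, just to have something to converge on.
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        prediction = fluid.layers.fc(input=img, size=10, act='softmax')
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        return fluid.layers.mean(loss)


    class TestTinyFC(TestParallelExecutorBase):
        def test_parallel_graph(self):
            img = np.zeros(shape=[32, 784], dtype='float32')
            label = np.ones(shape=[32, 1], dtype='int64')
            # exec_type replaces the old use_fast_executor switch.
            self.check_network_convergence(
                method=tiny_fc_net,
                feed_dict={"image": img, "label": label},
                use_cuda=True,
                exec_type=ExecutorType.ParallelGraph)
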
@@ -181,6 +181,9 @@ class TestCRFModel(unittest.TestCase):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
                 is_sparse=True, build_strategy=build_strategy, use_cuda=True)
+            self.check_network_convergence(
+                is_sparse=True, build_strategy=build_strategy, use_cuda=True)
             self.check_network_convergence(
                 is_sparse=True, build_strategy=build_strategy, use_cuda=False)
......
@@ -20,7 +20,7 @@ import numpy as np
 import paddle.fluid.core as core
 import os
 import paddle.fluid as fluid
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType


 def simple_fc_net(use_feed):
@@ -99,7 +99,10 @@ class TestMNIST(TestParallelExecutorBase):
         self.assertAlmostEqual(loss[0], loss[1], delta=1e-4)

     # simple_fc
-    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
+    def check_simple_fc_convergence(self,
+                                    use_cuda,
+                                    use_reduce=False,
+                                    exec_type=ExecutorType.Default):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -110,19 +113,21 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_reduce=use_reduce)
+            use_reduce=use_reduce,
+            exec_type=exec_type)

     def test_simple_fc(self):
         # use_cuda
-        self.check_simple_fc_convergence(True)
+        self.check_simple_fc_convergence(True, ExecutorType.Default)
+        self.check_simple_fc_convergence(True, ExecutorType.ParallelGraph)
         self.check_simple_fc_convergence(False)

     def test_simple_fc_with_new_strategy(self):
         # use_cuda, use_reduce
         self._compare_reduce_and_allreduce(simple_fc_net, True)
         self._compare_reduce_and_allreduce(simple_fc_net, False)

-    def check_simple_fc_parallel_accuracy(self, use_cuda):
+    def check_simple_fc_parallel_accuracy(self, use_cuda, exec_type):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -134,14 +139,16 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=False)
+            use_parallel_executor=False,
+            exec_type=exec_type)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             method=simple_fc_net,
             seed=1,
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=True)
+            use_parallel_executor=True,
+            exec_type=exec_type)

         self.assertAlmostEquals(
             np.mean(parallel_first_loss),
@@ -151,10 +158,12 @@ class TestMNIST(TestParallelExecutorBase):
             np.mean(parallel_last_loss), single_last_loss, delta=1e-6)

     def test_simple_fc_parallel_accuracy(self):
-        self.check_simple_fc_parallel_accuracy(True)
-        self.check_simple_fc_parallel_accuracy(False)
+        self.check_simple_fc_parallel_accuracy(True, ExecutorType.Default)
+        self.check_simple_fc_parallel_accuracy(True, ExecutorType.ParallelGraph)
+        # FIXME(Yancey1989): make the ParallelGraph executor type support CPU mode
+        self.check_simple_fc_parallel_accuracy(False, ExecutorType.Default)

-    def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor):
+    def check_batchnorm_fc_convergence(self, use_cuda, exec_type):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -165,12 +174,13 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_fast_executor=use_fast_executor)
+            exec_type=exec_type)

     def test_batchnorm_fc(self):
         for use_cuda in (False, True):
-            for use_fast_executor in (False, True):
-                self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)
+            for exec_type in (ExecutorType.Default, ExecutorType.Experimental,
+                              ExecutorType.ParallelGraph):
+                self.check_batchnorm_fc_convergence(use_cuda, exec_type)

     def test_batchnorm_fc_with_new_strategy(self):
         # FIXME(zcd): close this test temporarily.
......
@@ -19,7 +19,7 @@ import paddle.fluid.layers.ops as ops
 from paddle.fluid.initializer import init_on_cpu
 from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
 import paddle.fluid.core as core
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType
 import unittest
 import math
 import os
@@ -167,13 +167,17 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
     return decayed_lr


-def optimizer(learning_rate=0.01):
-    optimizer = fluid.optimizer.Momentum(
-        learning_rate=cosine_decay(
-            learning_rate=learning_rate, step_each_epoch=2, epochs=1),
-        momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(1e-4))
-    return optimizer
+def optimizer(learning_rate=0.01, lr_scale=1.0):
+    def _opt():
+        return fluid.optimizer.Momentum(
+            learning_rate=cosine_decay(
+                learning_rate=learning_rate / lr_scale,
+                step_each_epoch=2,
+                epochs=1),
+            momentum=0.9,
+            regularization=fluid.regularizer.L2Decay(1e-4))
+
+    return _opt


 class TestResnet(TestParallelExecutorBase):
@@ -216,7 +220,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=False,
-            optimizer=optimizer)
+            optimizer=optimizer())
         reduce_first_loss, reduce_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -225,7 +229,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=True,
-            optimizer=optimizer)
+            optimizer=optimizer())
         for loss in zip(all_reduce_first_loss, reduce_first_loss):
             self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
@@ -243,7 +247,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=False,
-            optimizer=optimizer,
+            optimizer=optimizer(),
             enable_sequential_execution=True)

         reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence(
@@ -254,7 +258,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=True,
-            optimizer=optimizer,
+            optimizer=optimizer(),
             enable_sequential_execution=True)

         for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
@@ -277,7 +281,9 @@ class TestResnet(TestParallelExecutorBase):
                                   use_cuda=True,
                                   use_reduce=False,
                                   iter=20,
-                                  delta2=1e-6):
+                                  delta2=1e-6,
+                                  exec_type=ExecutorType.Default,
+                                  lr_scale=1.0):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -295,8 +301,9 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer,
-            use_parallel_executor=False)
+            optimizer=optimizer(),
+            use_parallel_executor=False,
+            exec_type=exec_type)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -305,7 +312,8 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer)
+            optimizer=optimizer(lr_scale=lr_scale),
+            exec_type=exec_type)

         self.assertAlmostEquals(
             np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
@@ -313,7 +321,14 @@ class TestResnet(TestParallelExecutorBase):
             np.mean(parallel_last_loss), single_last_loss[0], delta=delta2)

     def test_seresnext_with_learning_rate_decay(self):
-        self._check_resnet_convergence(model=SE_ResNeXt50Small, use_cuda=True)
+        if core.is_compiled_with_cuda():
+            self._check_resnet_convergence(
+                model=SE_ResNeXt50Small, use_cuda=True)
+            self._check_resnet_convergence(
+                model=SE_ResNeXt50Small,
+                use_cuda=True,
+                exec_type=ExecutorType.ParallelGraph,
+                lr_scale=core.get_cuda_device_count())
         self._check_resnet_convergence(
             model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
......
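The optimizer helper now returns a zero-argument factory, so check_network_convergence can keep calling optimizer() internally, and the ParallelGraph run divides the base learning rate by the GPU count. A short usage sketch, assuming the optimizer factory defined in the hunk above is in scope (the numbers are illustrative):

    import paddle.fluid.core as core

    # optimizer(...) returns the _opt factory; calling it builds the Momentum
    # optimizer with the cosine-decayed, scaled learning rate.
    opt_factory = optimizer(learning_rate=0.01, lr_scale=core.get_cuda_device_count())
    momentum_opt = opt_factory()  # this is what check_network_convergence does internally

    # With 4 visible GPUs the rate fed into cosine_decay becomes 0.01 / 4 = 0.0025,
    # presumably to keep the ParallelGraph run comparable to the single-graph
    # baseline in the accuracy assertions above.
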
@@ -17,7 +17,7 @@ from __future__ import print_function
 import paddle.fluid as fluid
 import transformer_model
 import numpy as np
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType
 import unittest
 import paddle
 import paddle.fluid.core as core
@@ -173,6 +173,10 @@ class TestTransformer(TestParallelExecutorBase):
     def test_main(self):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(transformer, use_cuda=True)
+            self.check_network_convergence(
+                transformer,
+                use_cuda=True,
+                exec_type=ExecutorType.ParallelGraph)
             self.check_network_convergence(
                 transformer, use_cuda=True, enable_sequential_execution=True)
             self.check_network_convergence(transformer, use_cuda=False, iter=5)
......